[RELAY] Remove re-exports of tvm.transform (#5337)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 15 Apr 2020 00:03:15 +0000 (17:03 -0700)
committerGitHub <noreply@github.com>
Wed, 15 Apr 2020 00:03:15 +0000 (17:03 -0700)
38 files changed:
docs/api/python/ir.rst
docs/dev/convert_layout.rst
docs/dev/relay_pass_infra.rst
include/tvm/ir/transform.h
python/tvm/ir/json_compact.py
python/tvm/ir/transform.py
python/tvm/relay/__init__.py
python/tvm/relay/backend/interpreter.py
python/tvm/relay/qnn/transform.py
python/tvm/relay/quantize/quantize.py
python/tvm/relay/testing/__init__.py
python/tvm/relay/testing/py_converter.py
python/tvm/relay/transform/transform.py
src/ir/transform.cc
src/relay/transforms/print_ir.cc [deleted file]
tests/python/relay/test_op_level10.py
tests/python/relay/test_pass_alter_op_layout.py
tests/python/relay/test_pass_annotation.py
tests/python/relay/test_pass_canonicalize_cast.py
tests/python/relay/test_pass_combine_parallel_conv2d.py
tests/python/relay/test_pass_combine_parallel_dense.py
tests/python/relay/test_pass_convert_op_layout.py
tests/python/relay/test_pass_dead_code_elimination.py
tests/python/relay/test_pass_eliminate_common_subexpr.py
tests/python/relay/test_pass_eta_expand.py
tests/python/relay/test_pass_fold_constant.py
tests/python/relay/test_pass_fold_scale_axis.py
tests/python/relay/test_pass_lazy_gradient_init.py
tests/python/relay/test_pass_legalize.py
tests/python/relay/test_pass_mac_count.py
tests/python/relay/test_pass_manager.py
tests/python/relay/test_pass_partial_eval.py
tests/python/relay/test_pass_partition_graph.py
tests/python/relay/test_pass_qnn_legalize.py
tests/python/relay/test_pass_to_a_normal_form.py
tests/python/relay/test_pass_to_cps.py
tutorials/dev/relay_pass_infra.py
vta/python/vta/top/graphpack.py

index 1f0dc0c..c2a1a1e 100644 (file)
@@ -21,3 +21,11 @@ tvm.ir
    :members:
    :imported-members:
    :autosummary:
+
+
+tvm.transform
+-------------
+.. automodule:: tvm.transform
+   :members:
+   :imported-members:
+   :autosummary:
index 715d810..7345c15 100644 (file)
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
 
     # Convert the layout to NCHW
     # RemoveUnunsedFunctions is used to clean up the graph.
-    seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+    seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                       relay.transform.ConvertLayout('NCHW')])
     with relay.transform.PassContext(opt_level=3):
         mod = seq(mod)
index b42e128..3b443fa 100644 (file)
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
     func = relay.Function([x], z2)
 
     # Customize the optimization pipeline.
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.EliminateCommonSubexpr(),
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
 
 .. code:: python
 
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.PrintIR(),
index 4c55204..3680f6d 100644 (file)
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
 
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
+ * \param header The header to be attached to the output.
+ * \param show_meta_data Whether should we show meta data.
  * \return The pass.
  */
-TVM_DLL Pass PrintIR(std::string header);
+TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
 
 }  // namespace transform
 }  // namespace tvm
index e091cd1..9a881cf 100644 (file)
@@ -106,7 +106,7 @@ def create_updater_06_to_07():
         "relay.PassInfo": _rename("transform.PassInfo"),
         "relay.PassContext": _rename("transform.PassContext"),
         "relay.ModulePass": _rename("transform.ModulePass"),
-        "relay.Sequantial": _rename("transform.Sequantial"),
+        "relay.Sequential": _rename("transform.Sequential"),
         # TIR
         "Variable": _update_tir_var("tir.Var"),
         "SizeVar": _update_tir_var("tir.SizeVar"),
index da74fb2..614f969 100644 (file)
@@ -329,7 +329,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
     return create_module_pass
 
 
-def PrintIR(header):
+def PrintIR(header="", show_meta_data=False):
     """A special trace pass that prints the header and IR.
 
     Parameters
@@ -337,8 +337,11 @@ def PrintIR(header):
     header : str
         The header to be displayed along with the dump.
 
+    show_meta_data : bool
+        A boolean flag to indicate if meta data should be printed.
+
     Returns
     --------
     The pass
     """
-    return _ffi_transform_api.PrintIR(header)
+    return _ffi_transform_api.PrintIR(header, show_meta_data)
index 1517cf9..4e52019 100644 (file)
@@ -128,20 +128,9 @@ Prelude = prelude.Prelude
 # Scope builder
 ScopeBuilder = scope_builder.ScopeBuilder
 
-module_pass = transform.module_pass
-function_pass = transform.function_pass
-
 # Parser
 fromtext = parser.fromtext
 
 # Param Serialization
 save_param_dict = param_dict.save_param_dict
 load_param_dict = param_dict.load_param_dict
-
-# Pass manager
-PassInfo = transform.PassInfo
-PassContext = transform.PassContext
-Pass = transform.Pass
-ModulePass = transform.ModulePass
-FunctionPass = transform.FunctionPass
-Sequential = transform.Sequential
index 9c4be29..213a6c6 100644 (file)
@@ -210,10 +210,10 @@ class Interpreter(Executor):
         opt_mod : tvm.IRModule
             The optimized module.
         """
-        seq = transform.Sequential([transform.SimplifyInference(),
-                                    transform.FuseOps(0),
-                                    transform.ToANormalForm(),
-                                    transform.InferType()])
+        seq = tvm.transform.Sequential([transform.SimplifyInference(),
+                                        transform.FuseOps(0),
+                                        transform.ToANormalForm(),
+                                        transform.InferType()])
         return seq(self.mod)
 
     def _make_executor(self, expr=None):
index 6d38490..492c739 100644 (file)
@@ -60,7 +60,7 @@ def CanonicalizeOps():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that canonicalizes QNN ops to Relay ops.
     """
 
@@ -108,7 +108,7 @@ def Legalize():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that legalizes QNN ops.
     """
 
index 2ad4e18..958d0dc 100644 (file)
@@ -17,6 +17,7 @@
 #pylint: disable=unused-argument, not-context-manager
 """Automatic quantization toolkit."""
 import tvm.ir
+import tvm
 from tvm.runtime import Object
 
 from . import _quantize
@@ -240,7 +241,7 @@ def partition():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for VTA rewrite.
     """
     return _quantize.QuantizePartition()
@@ -253,7 +254,7 @@ def annotate():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for quantization annotation.
     """
     return _quantize.QuantizeAnnotate()
@@ -267,7 +268,7 @@ def realize():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for quantization realization.
     """
     return _quantize.QuantizeRealize()
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
     """ Prerequisite optimization passes for quantization. Perform
     "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
     "CanonicalizeOps" optimization before quantization. """
-    optimize = _transform.Sequential([_transform.SimplifyInference(),
-                                      _transform.FoldConstant(),
-                                      _transform.FoldScaleAxis(),
-                                      _transform.CanonicalizeOps(),
-                                      _transform.FoldConstant()])
+    optimize = tvm.transform.Sequential(
+        [_transform.SimplifyInference(),
+         _transform.FoldConstant(),
+         _transform.FoldScaleAxis(),
+         _transform.CanonicalizeOps(),
+         _transform.FoldConstant()])
 
     if params:
         mod['main'] = _bind_params(mod['main'], params)
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
     """
     mod = prerequisite_optimize(mod, params)
 
-    calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
-                                            name="QuantizeCalibrate")
+    calibrate_pass = tvm.transform.module_pass(
+        calibrate(dataset), opt_level=1,
+        name="QuantizeCalibrate")
     quant_passes = [partition(),
                     annotate(),
                     calibrate_pass]
     if not current_qconfig().do_simulation:
         quant_passes.append(realize())
     quant_passes.append(_transform.FoldConstant())
-    quantize_seq = _transform.Sequential(quant_passes)
-    with _transform.PassContext(opt_level=3,
-                                required_pass=["QuantizeAnnotate",
-                                               "QuantizeCalibrate",
-                                               "QuantizeRealize"]):
+    quantize_seq = tvm.transform.Sequential(quant_passes)
+    with tvm.transform.PassContext(opt_level=3,
+                                   required_pass=["QuantizeAnnotate",
+                                                  "QuantizeCalibrate",
+                                                  "QuantizeRealize"]):
         with quantize_context():
             mod = quantize_seq(mod)
 
index 54c9091..58c6fe8 100644 (file)
@@ -47,7 +47,7 @@ from .py_converter import to_python, run_as_python
 from ..transform import gradient
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
index eec5e16..61a04ec 100644 (file)
@@ -95,8 +95,8 @@ class PythonConverter(ExprFunctor):
 
         # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
         # and fusion (to get primitive functions)
-        opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
-                                           relay.transform.FuseOps(fuse_opt_level=0)])
+        opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
+                                         relay.transform.FuseOps(fuse_opt_level=0)])
         mod = opts(mod)
         optimized = mod['main']
         return optimized if isinstance(unwrapped, Function) else optimized.body
index 918894f..292c5fd 100644 (file)
@@ -22,10 +22,9 @@ import types
 import inspect
 import functools
 
-import tvm
+import tvm.ir
 from tvm import te
 from tvm.runtime import ndarray as _nd
-from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
 
 from tvm import relay
 from . import _ffi_api
@@ -78,12 +77,13 @@ def build_config(opt_level=2,
     pass_context: PassContext
         The pass context for optimizations.
     """
-    return PassContext(opt_level, fallback_device, required_pass,
-                       disabled_pass, trace)
+    return tvm.ir.transform.PassContext(
+        opt_level, fallback_device, required_pass,
+        disabled_pass, trace)
 
 
 @tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(Pass):
+class FunctionPass(tvm.ir.transform.Pass):
     """A pass that works on each tvm.relay.Function in a module. A function
     pass class should be created through `function_pass`.
     """
@@ -94,7 +94,7 @@ def InferType():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered type inference pass.
     """
     return _ffi_api.InferType()
@@ -106,7 +106,7 @@ def FoldScaleAxis():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to fold expressions.
 
     Note
@@ -123,7 +123,7 @@ def BackwardFoldScaleAxis():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to backward fold expressions.
 
     Note
@@ -144,7 +144,7 @@ def RemoveUnusedFunctions(entry_functions=None):
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to remove unused functions.
     """
     if entry_functions is None:
@@ -156,7 +156,7 @@ def ForwardFoldScaleAxis():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to forward fold expressions.
 
     Note
@@ -174,7 +174,7 @@ def SimplifyInference():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass to perform operator simplification.
     """
     return _ffi_api.SimplifyInference()
@@ -185,7 +185,7 @@ def FastMath():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass to perform fast math operations.
     """
     return _ffi_api.FastMath()
@@ -198,7 +198,7 @@ def CanonicalizeOps():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass performing the canonicalization.
     """
     return _ffi_api.CanonicalizeOps()
@@ -214,7 +214,7 @@ def DeadCodeElimination(inline_once=False):
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that eliminates the dead code in a Relay program.
     """
     return _ffi_api.DeadCodeElimination(inline_once)
@@ -227,7 +227,7 @@ def LazyGradientInit():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         A pass which delays and/or reduces memory allocation,
         by lazily allocating 0 or one filled tensors.
     """
@@ -238,7 +238,7 @@ def FoldConstant():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass for constant folding.
     """
     return _ffi_api.FoldConstant()
@@ -255,7 +255,7 @@ def FuseOps(fuse_opt_level=-1):
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass for operator fusion.
     """
     return _ffi_api.FuseOps(fuse_opt_level)
@@ -272,7 +272,7 @@ def CombineParallelConv2D(min_num_branches=3):
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that combines parallel conv2d operators.
     """
     return _ffi_api.CombineParallelConv2D(min_num_branches)
@@ -304,7 +304,7 @@ def CombineParallelDense(min_num_branches=3):
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that combines parallel dense operators.
     """
     return _ffi_api.CombineParallelDense(min_num_branches)
@@ -318,7 +318,7 @@ def AlterOpLayout():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that alters the layout of operators.
     """
     return _ffi_api.AlterOpLayout()
@@ -366,7 +366,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that rewrites an expr.
     """
     return _ffi_api.Legalize(legalize_map_attr_name)
@@ -387,7 +387,7 @@ def MergeComposite(pattern_table):
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that merges operators into a single composite
         relay function.
     """
@@ -413,7 +413,7 @@ def MergeCompilerRegions():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that merges compiler regions.
     """
     return _ffi_api.MergeCompilerRegions()
@@ -433,7 +433,7 @@ def RewriteAnnotatedOps(fallback_device):
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that rewrites an expression with annotated
         `on_device` operators.
     """
@@ -448,7 +448,7 @@ def ToANormalForm():
 
     Returns
     -------
-    ret: Union[tvm.relay.Pass, tvm.relay.Expr]
+    ret: Union[tvm.transform.Pass, tvm.relay.Expr]
         The registered pass that transforms an expression into A Normal Form.
     """
     return _ffi_api.ToANormalForm()
@@ -462,7 +462,7 @@ def ToCPS(expr, mod=None):
 
     Returns
     -------
-    result: tvm.relay.Pass
+    result: tvm.transform.Pass
         The registered pass that transforms an expression into CPS.
     """
     return _ffi_api.to_cps(expr, mod)
@@ -481,7 +481,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False):
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that eta expands an expression.
     """
     return _ffi_api.EtaExpand(expand_constructor, expand_global_var)
@@ -492,7 +492,7 @@ def ToGraphNormalForm():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that transforms an expression into Graph Normal Form.
     """
     return _ffi_api.ToGraphNormalForm()
@@ -509,7 +509,7 @@ def EliminateCommonSubexpr(fskip=None):
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that eliminates common subexpressions.
     """
     return _ffi_api.EliminateCommonSubexpr(fskip)
@@ -527,7 +527,7 @@ def PartialEvaluate():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that performs partial evaluation on an expression.
     """
     return _ffi_api.PartialEvaluate()
@@ -539,7 +539,7 @@ def CanonicalizeCast():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that canonicalizes cast expression.
     """
     return _ffi_api.CanonicalizeCast()
@@ -551,36 +551,19 @@ def LambdaLift():
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that lifts the lambda function.
     """
     return _ffi_api.LambdaLift()
 
 
-def PrintIR(show_meta_data=True):
-    """
-    Print the IR for a module to help debugging.
-
-    Parameters
-    ----------
-    show_meta_data : bool
-        A boolean flag to indicate if meta data should be printed.
-
-    Returns
-    -------
-    ret : tvm.relay.Pass
-        The registered pass that prints the module IR.
-    """
-    return _ffi_api.PrintIR(show_meta_data)
-
-
 def PartitionGraph():
     """Partition a Relay program into regions that can be executed on different
     backends.
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that partitions the Relay program.
     """
     return _ffi_api.PartitionGraph()
@@ -598,7 +581,7 @@ def AnnotateTarget(targets):
 
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The annotated pass that wrapps ops with subgraph_start and
         subgraph_end.
     """
@@ -614,7 +597,7 @@ def Inline():
 
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that performs inlining for a Relay IR module.
     """
     return _ffi_api.Inline()
@@ -809,7 +792,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
     def create_function_pass(pass_arg):
         """Internal function that creates a function pass"""
         fname = name if name else pass_arg.__name__
-        info = PassInfo(opt_level, fname, required)
+        info = tvm.transform.PassInfo(opt_level, fname, required)
         if inspect.isclass(pass_arg):
             return _wrap_class_function_pass(pass_arg, info)
         if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
index 0161cb3..c1547d5 100644 (file)
@@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext")
 .set_body_typed(PassContext::Internal::ExitScope);
 
 
-Pass PrintIR(std::string header) {
-  auto pass_func =[header](IRModule mod, const PassContext& ctx) {
+Pass PrintIR(std::string header, bool show_meta_data) {
+  auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) {
     LOG(INFO) << "PrintIR(" << header << "):\n"
-              << mod;
+              << AsText(mod, show_meta_data);
     return mod;
   };
   return CreateModulePass(pass_func, 0, "PrintIR", {});
diff --git a/src/relay/transforms/print_ir.cc b/src/relay/transforms/print_ir.cc
deleted file mode 100644 (file)
index cf06b50..0000000
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/relay/transforms/print_ir.cc
- *
- * \brief Print the module IR to help debugging.
- */
-#include <tvm/relay/expr.h>
-#include <tvm/relay/transform.h>
-
-namespace tvm {
-namespace relay {
-
-namespace transform {
-
-Pass PrintIR(bool show_meta_data) {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-    [=](IRModule m, PassContext pc) {
-      LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
-      return m;
-  };
-  return CreateModulePass(pass_func, 0, "PrintIR", {});
-}
-
-TVM_REGISTER_GLOBAL("relay._transform.PrintIR")
-.set_body_typed(PrintIR);
-
-}  // namespace transform
-
-}  // namespace relay
-}  // namespace tvm
index 30e2506..5e57c80 100644 (file)
@@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal():
     df = transform.gradient(run_infer_type(f))
 
     # run PE and DCE
-    with transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         passes = [transform.PartialEvaluate(),
                   transform.DeadCodeElimination(inline_once=True)]
-        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]
 
     df_parsed = relay.parser.fromtext(
@@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple():
     df = transform.gradient(run_infer_type(f))
 
     # run PE and DCE
-    with transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         passes = [transform.PartialEvaluate(),
                   transform.DeadCodeElimination(inline_once=True)]
-        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]
 
     df_parsed = relay.parser.fromtext(
index a30492f..2a2e265 100644 (file)
@@ -26,8 +26,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
index ea92546..582d46a 100644 (file)
@@ -28,8 +28,8 @@ from tvm.relay import transform
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     return mod["main"]
 
index 7b6617a..e13547b 100644 (file)
@@ -54,9 +54,9 @@ def test_canonicalize_cast():
         bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
         y = before(data, conv_weight, bias1, bias2)
         mod = tvm.IRModule.from_expr(y)
-        seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+        seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
                                      _transform.InferType()])
-        with _transform.PassContext(opt_level=3):
+        with tvm.transform.PassContext(opt_level=3):
             mod = seq(mod)
         y = mod["main"]
         y_expected = expected(data, conv_weight, bias1, bias2)
index 345f068..7f7f185 100644 (file)
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     return mod["main"]
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     return mod["main"]
index f0f2e18..12beafb 100644 (file)
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     return mod["main"]
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     return mod["main"]
index c783971..c5a7b0e 100644 (file)
@@ -26,8 +26,8 @@ from tvm.relay import transform, analysis
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
index 60dfa62..35fd444 100644 (file)
@@ -47,7 +47,7 @@ e = env()
 
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
index 89e3b67..7af524d 100644 (file)
@@ -24,7 +24,7 @@ from tvm.relay import transform, analysis
 
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
index 84ff54a..e0a189b 100644 (file)
@@ -33,8 +33,8 @@ def test_eta_expand_global_var():
             @aux
         }
     """)
-    seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     expected = relay.fromtext(r"""
         v0.0.4
@@ -62,8 +62,8 @@ def test_eta_expand_constructor():
             Cons
         }
     """)
-    seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     expected = relay.fromtext(r"""
         v0.0.4
index 3ddafd7..4f44d2b 100644 (file)
@@ -24,7 +24,7 @@ from tvm.relay.testing import run_infer_type, create_workload
 
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
 
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
@@ -174,7 +174,7 @@ def test_fold_batch_norm():
         add = relay.add(conv, bias)
         return relay.Function(relay.analysis.free_vars(add), add)
 
-    remove_bn_pass = transform.Sequential([
+    remove_bn_pass = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.SimplifyInference(),
         relay.transform.FoldConstant(),
index bf2a708..d7c437a 100644 (file)
@@ -26,7 +26,7 @@ def _get_positive_scale(size):
 
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
index f9c762e..4149268 100644 (file)
@@ -80,7 +80,7 @@ def test_add_tuple():
 
   mod["main"] = y
   mod = transform.LazyGradientInit()(mod)
-  mod = transform.PrintIR(show_meta_data=True)(mod)
+  mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
   y = mod["main"]
 
   assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
@@ -116,7 +116,7 @@ def test_mult():
 def test_ret_tuple():
   """Test tuple return type. Check types and semantic equivalence."""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -141,7 +141,7 @@ def test_ret_tuple():
 def test_add_broadcast():
   """Test adding matrices of different size. Check types and semantic equivalence."""
   mod = tvm.IRModule()
-  
+
   shape1 = (3, 4, 1)
   shape2 = (1, 5)
   dtype = 'float32'
@@ -173,7 +173,7 @@ def test_reverse_ad_identity():
   """Simple test with reverse mode ad."""
   # of f(x) = x
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -201,7 +201,7 @@ def test_reverse_ad_identity():
 def test_multivar_reverse_ad():
   """Simple test with multivariate reverse mode ad."""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -232,7 +232,7 @@ def test_multivar_reverse_ad():
 def test_after_partial_eval():
   """Test transformation following reverse mode ad and PartialEval"""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -248,7 +248,7 @@ def test_after_partial_eval():
   mod["main"] = back_func
   back_func = mod["main"]
 
-  seq = transform.Sequential([
+  seq = tvm.transform.Sequential([
     transform.PartialEvaluate(),
     transform.LazyGradientInit(),
     transform.DeadCodeElimination()
@@ -270,7 +270,7 @@ def test_after_partial_eval():
 def test_before_partial_eval():
   """Test transformation before PartialEval"""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -284,7 +284,7 @@ def test_before_partial_eval():
   back_func = run_infer_type(back_func)
 
   mod["main"] = back_func
-  seq = transform.Sequential([
+  seq = tvm.transform.Sequential([
     transform.LazyGradientInit(),
     transform.PartialEvaluate(),
     transform.DeadCodeElimination()
@@ -306,7 +306,7 @@ def test_before_partial_eval():
 def test_zeros():
   """Simple test using "zeros" op"""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -328,7 +328,7 @@ def test_zeros():
 def test_ones():
   """Simple test using "ones" op"""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -350,7 +350,7 @@ def test_ones():
 def test_zeros_like():
   """Simple test using "zeros_like" op"""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
@@ -372,7 +372,7 @@ def test_zeros_like():
 def test_ones_like():
   """Simple test using "ones_like" op"""
   mod = tvm.IRModule()
-  
+
   shape = (10, 10)
   dtype = 'float32'
   t = relay.TensorType(shape, dtype)
index 1456700..0882149 100644 (file)
@@ -28,8 +28,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
index 697aad8..d490ac7 100644 (file)
@@ -23,7 +23,7 @@ from tvm.relay import analysis, transform
 
 
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
index 0a6555b..28ccf6f 100644 (file)
@@ -129,13 +129,13 @@ def test_module_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None
 
-    @_transform.module_pass(opt_level=opt_level, name=pass_name)
+    @tvm.transform.module_pass(opt_level=opt_level, name=pass_name)
     def transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
 
     def test_pass_registration():
         mod_pass = transform
-        assert isinstance(mod_pass, _transform.ModulePass)
+        assert isinstance(mod_pass, tvm.transform.ModulePass)
         pass_info = mod_pass.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level
@@ -143,8 +143,8 @@ def test_module_pass():
     def test_pass_registration_no_decorator():
         def direct_transform(expr, ctx):
             return opt_tester.transform(expr, ctx)
-        mod_pass = _transform.module_pass(direct_transform, opt_level=3)
-        assert isinstance(mod_pass, _transform.ModulePass)
+        mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3)
+        assert isinstance(mod_pass, tvm.transform.ModulePass)
         pass_info = mod_pass.info
         assert pass_info.name == "direct_transform"
         assert pass_info.opt_level == 3
@@ -285,7 +285,7 @@ def test_function_pass():
 
 
 def test_module_class_pass():
-    @relay.transform.module_pass(opt_level=1)
+    @tvm.transform.module_pass(opt_level=1)
     class TestPipeline:
         """Simple test function to replace one argument to another."""
         def __init__(self, new_mod, replace):
@@ -309,7 +309,7 @@ def test_module_class_pass():
 
 
 def test_pass_info():
-    info = relay.transform.PassInfo(opt_level=1, name="xyz")
+    info = tvm.transform.PassInfo(opt_level=1, name="xyz")
     assert info.opt_level == 1
     assert info.name == "xyz"
 
@@ -350,7 +350,7 @@ def test_sequential_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None
 
-    @_transform.module_pass(opt_level=1)
+    @tvm.transform.module_pass(opt_level=1)
     def mod_transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
 
@@ -367,21 +367,21 @@ def test_sequential_pass():
         passes = [module_pass, function_pass]
         opt_level = 2
         pass_name = "sequential"
-        sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+        sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
         pass_info = sequential.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level
 
     def test_no_pass():
         passes = []
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         ret_mod = sequential(mod)
         mod_func = ret_mod[v_sub]
         check_func(sub, mod_func)
 
     def test_only_module_pass():
         passes = [module_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         with relay.build_config(required_pass=["mod_transform"]):
             ret_mod = sequential(mod)
         # Check the subtract function.
@@ -396,7 +396,7 @@ def test_sequential_pass():
     def test_only_function_pass():
         # Check the subtract function.
         passes = [function_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         with relay.build_config(required_pass=["func_transform"]):
             ret_mod = sequential(mod)
         _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
@@ -411,7 +411,7 @@ def test_sequential_pass():
         # function pass.
         mod = tvm.IRModule({v_sub: sub, v_log: log})
         passes = [module_pass, function_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         required = ["mod_transform", "func_transform"]
         with relay.build_config(required_pass=required):
             ret_mod = sequential(mod)
@@ -482,7 +482,7 @@ def test_sequential_with_scoping():
         z1 = relay.add(z, z)
         return relay.Function([x], z1)
 
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.EliminateCommonSubexpr(),
@@ -507,10 +507,10 @@ def test_print_ir(capfd):
     y = relay.multiply(y, relay.const(2, "float32"))
     func = relay.Function([x], y)
 
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
-        relay.transform.PrintIR(),
+        tvm.transform.PrintIR(),
         relay.transform.DeadCodeElimination()
     ])
 
@@ -520,7 +520,7 @@ def test_print_ir(capfd):
 
     out = capfd.readouterr().err
 
-    assert "Dumping the module IR" in out
+    assert "PrintIR" in out
     assert "multiply" in out
 
 __TRACE_COUNTER__ = 0
@@ -539,7 +539,7 @@ def test_print_debug_callback():
     y = relay.multiply(y, relay.const(2, "float32"))
     func = relay.Function([x], y)
 
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.DeadCodeElimination()
index 0f3eea6..45593b4 100644 (file)
@@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
@@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False):
     if mod:
         assert isinstance(expr, Function)
         mod["main"] = expr
-        seq = transform.Sequential(passes)
+        seq = tvm.transform.Sequential(passes)
         mod = seq(mod)
         return mod["main"]
     return run_opt_pass(expr, passes)
index 5148d4e..2ee8538 100644 (file)
@@ -496,7 +496,7 @@ def test_function_lifting():
         op_list = ["nn.batch_norm", "nn.conv2d"]
         mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
 
-        opt_pass = transform.Sequential([
+        opt_pass = tvm.transform.Sequential([
             transform.InferType(),
             transform.PartitionGraph(),
             transform.SimplifyInference(),
@@ -578,7 +578,7 @@ def test_function_lifting_inline():
         op_list = ["nn.batch_norm", "nn.conv2d"]
         mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
 
-        opt_pass = transform.Sequential([
+        opt_pass = tvm.transform.Sequential([
             transform.InferType(),
             transform.PartitionGraph(),
             transform.SimplifyInference(),
@@ -878,13 +878,13 @@ def test_dnnl_fuse():
         # This is required for constant folding
         mod["main"] = bind_params_by_name(mod["main"], params)
 
-        remove_bn_pass = transform.Sequential([
+        remove_bn_pass = tvm.transform.Sequential([
             transform.InferType(),
             transform.SimplifyInference(),
             transform.FoldConstant(),
             transform.FoldScaleAxis(),
         ])
-        composite_partition = transform.Sequential([
+        composite_partition = tvm.transform.Sequential([
             remove_bn_pass,
             transform.MergeComposite(pattern_table),
             transform.AnnotateTarget("dnnl"),
index c291c4e..5f7deff 100644 (file)
@@ -37,8 +37,8 @@ def alpha_equal(x, y):
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
index d7babf3..5a63db7 100644 (file)
@@ -28,8 +28,8 @@ from tvm.relay.analysis import Feature
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
index 4aaa9a0..6edf185 100644 (file)
@@ -71,7 +71,8 @@ def test_cps_pe():
         x = run_infer_type(x)
         y = un_cps(x)
         y = run_infer_type(y)
-        x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+        x = run_opt_pass(x, tvm.transform.Sequential(
+            [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
         assert Feature.fRefCreate not in detect_feature(x)
     unit = relay.Function([], relay.const(0., dtype='float32'))
     f_ref = relay.Var("f_ref")
index 6b844ff..980d96c 100644 (file)
@@ -29,7 +29,7 @@ introduced an infrastructure to manage the optimization passes.
 The optimizations of a Relay program could be applied at various granularity,
 namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`
 and py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes
+respectively. Or users can rely on py:class:`tvm.transform.Sequential` to apply a sequence of passes
 on a Relay program where the dependencies between passes can be resolved by the
 pass infra. For more details about each type of these passes, please refer to
 the :ref:`relay-pass-infra`
@@ -130,22 +130,22 @@ print(mod)
 # fusion, as this pass generates let bindings for each expression to
 # canonicalize a Relay program.
 #
-# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to alleviate developers from handling
+# Relay, hence, provides :py:class:`tvm.transform.Sequential` to alleviate developers from handling
 # these issues explicitly by specifying the required passes of each pass and
 # packing them as a whole to execute. For example, the same passes can now be
-# applied using the sequential style as the following. :py:class:`tvm.relay.transform.Sequential` is
+# applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is
 # similiar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
 # and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
 # For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
 # `Modules` that will be added to build a network. It focuses on the network
-# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing
+# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimizing
 # pass.
 
-# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential`
+# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
 f = example()
 mod = tvm.IRModule.from_expr(f)
 # Glob the interested passes.
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
                                   relay.transform.EliminateCommonSubexpr(),
                                   relay.transform.FuseOps(fuse_opt_level=2)])
 mod1 = seq(mod)
@@ -156,7 +156,7 @@ print(mod1)
 # identical addition operations. This is because `EliminateCommonSubexpr`
 # was not actually performed. The reason is because only the passes that have
 # optimization level less or equal to 2 will be executed by default under
-# :py:class:`tvm.relay.transform.Sequential`. The pass infra,
+# :py:class:`tvm.transform.Sequential`. The pass infra,
 # however, provides a configuration interface
 # for users to customize the optimization level that they want to execute.
 
@@ -186,7 +186,7 @@ with relay.build_config(opt_level=3):
     mod4 = seq(mod)
 print(mod4)
 
-seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()])
+seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
 with relay.build_config(opt_level=3):
     with tvm.target.create("llvm"):
         mod5 = seq1(mod)
@@ -237,11 +237,11 @@ print(mod3)
 
 f = example()
 mod = tvm.IRModule.from_expr(f)
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
-                                  relay.transform.PrintIR(False),
-                                  relay.transform.EliminateCommonSubexpr(),
-                                  relay.transform.FuseOps(),
-                                  relay.transform.PrintIR(False)])
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
+                                tvm.transform.PrintIR(),
+                                relay.transform.EliminateCommonSubexpr(),
+                                relay.transform.FuseOps(),
+                                tvm.transform.PrintIR()])
 with relay.build_config(opt_level=3):
     mod = seq(mod)
 
index aca00a6..2334de7 100644 (file)
@@ -24,7 +24,7 @@ from tvm.relay import ExprMutator
 
 def run_opt_pass(expr, opt_pass):
     """Exectue a relay pass."""
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]