:members:
:imported-members:
:autosummary:
+
+
+tvm.transform
+-------------
+.. automodule:: tvm.transform
+ :members:
+ :imported-members:
+ :autosummary:
# Convert the layout to NCHW
# RemoveUnusedFunctions is used to clean up the graph.
- seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+ seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
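For reference, this pipeline can be exercised end to end. Below is a minimal sketch; the NHWC conv2d network and its shapes are illustrative assumptions, not part of this change:

.. code:: python

    import tvm
    from tvm import relay

    # An illustrative NHWC convolution (assumed shapes).
    data = relay.var("data", shape=(1, 56, 56, 64))
    weight = relay.var("weight", shape=(3, 3, 64, 64))
    conv = relay.nn.conv2d(data, weight, padding=(1, 1),
                           data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    # Clean up unused functions, then convert the layout to NCHW.
    seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                    relay.transform.ConvertLayout('NCHW')])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    print(mod)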
func = relay.Function([x], z2)
# Customize the optimization pipeline.
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
.. code:: python
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
+ * \param header The header to be attached to the output.
+ * \param show_meta_data Whether to show the meta data.
* \return The pass.
*/
-TVM_DLL Pass PrintIR(std::string header);
+TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
} // namespace transform
} // namespace tvm
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
- "relay.Sequantial": _rename("transform.Sequantial"),
+ "relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
return create_module_pass
-def PrintIR(header):
+def PrintIR(header="", show_meta_data=False):
"""A special trace pass that prints the header and IR.
Parameters
----------
header : str
The header to be displayed along with the dump.
+ show_meta_data : bool
+ A boolean flag to indicate if meta data should be printed.
+
Returns
-------
ret : tvm.transform.Pass
The pass that prints the header and IR.
"""
- return _ffi_transform_api.PrintIR(header)
+ return _ffi_transform_api.PrintIR(header, show_meta_data)
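A usage sketch of the updated signature (the two-line module here is an illustrative assumption):

.. code:: python

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        # Logs "PrintIR(after InferType):" followed by the module IR.
        tvm.transform.PrintIR(header="after InferType", show_meta_data=False),
        relay.transform.FoldConstant(),
    ])
    mod = seq(mod)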
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder
-module_pass = transform.module_pass
-function_pass = transform.function_pass
-
# Parser
fromtext = parser.fromtext
# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict
-
-# Pass manager
-PassInfo = transform.PassInfo
-PassContext = transform.PassContext
-Pass = transform.Pass
-ModulePass = transform.ModulePass
-FunctionPass = transform.FunctionPass
-Sequential = transform.Sequential
opt_mod : tvm.IRModule
The optimized module.
"""
- seq = transform.Sequential([transform.SimplifyInference(),
- transform.FuseOps(0),
- transform.ToANormalForm(),
- transform.InferType()])
+ seq = tvm.transform.Sequential([transform.SimplifyInference(),
+ transform.FuseOps(0),
+ transform.ToANormalForm(),
+ transform.InferType()])
return seq(self.mod)
def _make_executor(self, expr=None):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that legalizes QNN ops.
"""
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
+import tvm
from tvm.runtime import Object
from . import _quantize
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizePartition()
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
- optimize = _transform.Sequential([_transform.SimplifyInference(),
- _transform.FoldConstant(),
- _transform.FoldScaleAxis(),
- _transform.CanonicalizeOps(),
- _transform.FoldConstant()])
+ optimize = tvm.transform.Sequential(
+ [_transform.SimplifyInference(),
+ _transform.FoldConstant(),
+ _transform.FoldScaleAxis(),
+ _transform.CanonicalizeOps(),
+ _transform.FoldConstant()])
if params:
mod['main'] = _bind_params(mod['main'], params)
"""
mod = prerequisite_optimize(mod, params)
- calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
- name="QuantizeCalibrate")
+ calibrate_pass = tvm.transform.module_pass(
+ calibrate(dataset), opt_level=1,
+ name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
- quantize_seq = _transform.Sequential(quant_passes)
- with _transform.PassContext(opt_level=3,
- required_pass=["QuantizeAnnotate",
- "QuantizeCalibrate",
- "QuantizeRealize"]):
+ quantize_seq = tvm.transform.Sequential(quant_passes)
+ with tvm.transform.PassContext(opt_level=3,
+ required_pass=["QuantizeAnnotate",
+ "QuantizeCalibrate",
+ "QuantizeRealize"]):
with quantize_context():
mod = quantize_seq(mod)
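For context, `tvm.transform.module_pass` as used above accepts a plain function as well as a class. A minimal sketch with a hypothetical no-op transformation:

.. code:: python

    import tvm

    @tvm.transform.module_pass(opt_level=1, name="IdentityPass")
    def identity_pass(mod, ctx):
        # A real pass would return a rewritten IRModule; this one is a no-op.
        return mod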
from ..transform import gradient
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
- opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
- relay.transform.FuseOps(fuse_opt_level=0)])
+ opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
+ relay.transform.FuseOps(fuse_opt_level=0)])
mod = opts(mod)
optimized = mod['main']
return optimized if isinstance(unwrapped, Function) else optimized.body
import inspect
import functools
-import tvm
+import tvm.ir
from tvm import te
from tvm.runtime import ndarray as _nd
-from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
from tvm import relay
from . import _ffi_api
pass_context: PassContext
The pass context for optimizations.
"""
- return PassContext(opt_level, fallback_device, required_pass,
- disabled_pass, trace)
+ return tvm.ir.transform.PassContext(
+ opt_level, fallback_device, required_pass,
+ disabled_pass, trace)
@tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(Pass):
+class FunctionPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered type inference pass.
"""
return _ffi_api.InferType()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to fold expressions.
Note
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to backward fold expressions.
Note
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to forward fold expressions.
Note
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass to perform operator simplification.
"""
return _ffi_api.SimplifyInference()
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass to perform fast math operations.
"""
return _ffi_api.FastMath()
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass performing the canonicalization.
"""
return _ffi_api.CanonicalizeOps()
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _ffi_api.DeadCodeElimination(inline_once)
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
A pass which delays and/or reduces memory allocation,
by lazily allocating zero- or one-filled tensors.
"""
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass for constant folding.
"""
return _ffi_api.FoldConstant()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass for operator fusion.
"""
return _ffi_api.FuseOps(fuse_opt_level)
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that combines parallel conv2d operators.
"""
return _ffi_api.CombineParallelConv2D(min_num_branches)
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that combines parallel dense operators.
"""
return _ffi_api.CombineParallelDense(min_num_branches)
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that alters the layout of operators.
"""
return _ffi_api.AlterOpLayout()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that rewrites an expr.
"""
return _ffi_api.Legalize(legalize_map_attr_name)
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that merges operators into a single composite
relay function.
"""
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that merges compiler regions.
"""
return _ffi_api.MergeCompilerRegions()
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
Returns
-------
- ret: Union[tvm.relay.Pass, tvm.relay.Expr]
+ ret: Union[tvm.transform.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return _ffi_api.ToANormalForm()
Returns
-------
- result: tvm.relay.Pass
+ result: tvm.transform.Pass
The registered pass that transforms an expression into CPS.
"""
return _ffi_api.to_cps(expr, mod)
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that eta expands an expression.
"""
return _ffi_api.EtaExpand(expand_constructor, expand_global_var)
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that transforms an expression into Graph Normal Form.
"""
return _ffi_api.ToGraphNormalForm()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that eliminates common subexpressions.
"""
return _ffi_api.EliminateCommonSubexpr(fskip)
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that performs partial evaluation on an expression.
"""
return _ffi_api.PartialEvaluate()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that canonicalizes cast expression.
"""
return _ffi_api.CanonicalizeCast()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that lifts the lambda function.
"""
return _ffi_api.LambdaLift()
-def PrintIR(show_meta_data=True):
- """
- Print the IR for a module to help debugging.
-
- Parameters
- ----------
- show_meta_data : bool
- A boolean flag to indicate if meta data should be printed.
-
- Returns
- -------
- ret : tvm.relay.Pass
- The registered pass that prints the module IR.
- """
- return _ffi_api.PrintIR(show_meta_data)
-
-
def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
return _ffi_api.PartitionGraph()
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The annotated pass that wraps ops with subgraph_start and
subgraph_end.
"""
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return _ffi_api.Inline()
def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
fname = name if name else pass_arg.__name__
- info = PassInfo(opt_level, fname, required)
+ info = tvm.transform.PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
.set_body_typed(PassContext::Internal::ExitScope);
-Pass PrintIR(std::string header) {
- auto pass_func =[header](IRModule mod, const PassContext& ctx) {
+Pass PrintIR(std::string header, bool show_meta_data) {
+ auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n"
- << mod;
+ << AsText(mod, show_meta_data);
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/relay/transforms/print_ir.cc
- *
- * \brief Print the module IR to help debugging.
- */
-#include <tvm/relay/expr.h>
-#include <tvm/relay/transform.h>
-
-namespace tvm {
-namespace relay {
-
-namespace transform {
-
-Pass PrintIR(bool show_meta_data) {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
- return m;
- };
- return CreateModulePass(pass_func, 0, "PrintIR", {});
-}
-
-TVM_REGISTER_GLOBAL("relay._transform.PrintIR")
-.set_body_typed(PrintIR);
-
-} // namespace transform
-
-} // namespace relay
-} // namespace tvm
df = transform.gradient(run_infer_type(f))
# run PE and DCE
- with transform.PassContext(opt_level=3):
+ with tvm.transform.PassContext(opt_level=3):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
- mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+ mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
df_parsed = relay.parser.fromtext(
df = transform.gradient(run_infer_type(f))
# run PE and DCE
- with transform.PassContext(opt_level=3):
+ with tvm.transform.PassContext(opt_level=3):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
- mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+ mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
df_parsed = relay.parser.fromtext(
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod["main"]
bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
y = before(data, conv_weight, bias1, bias2)
mod = tvm.IRModule.from_expr(y)
- seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+ seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
_transform.InferType()])
- with _transform.PassContext(opt_level=3):
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
y = mod["main"]
y_expected = expected(data, conv_weight, bias1, bias2)
return mod["main"]
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
return mod["main"]
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
@aux
}
""")
- seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
- with _transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
Cons
}
""")
- seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
- with _transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
add = relay.add(conv, bias)
return relay.Function(relay.analysis.free_vars(add), add)
- remove_bn_pass = transform.Sequential([
+ remove_bn_pass = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.SimplifyInference(),
relay.transform.FoldConstant(),
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
- mod = transform.PrintIR(show_meta_data=True)(mod)
+ mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
def test_ret_tuple():
"""Test tuple return type. Check types and semantic equivalence."""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def test_add_broadcast():
"""Test adding matrices of different size. Check types and semantic equivalence."""
mod = tvm.IRModule()
- 
+
shape1 = (3, 4, 1)
shape2 = (1, 5)
dtype = 'float32'
"""Simple test with reverse mode ad."""
# of f(x) = x
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def test_multivar_reverse_ad():
"""Simple test with multivariate reverse mode ad."""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def test_after_partial_eval():
"""Test transformation following reverse mode ad and PartialEval"""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
mod["main"] = back_func
back_func = mod["main"]
- seq = transform.Sequential([
+ seq = tvm.transform.Sequential([
transform.PartialEvaluate(),
transform.LazyGradientInit(),
transform.DeadCodeElimination()
def test_before_partial_eval():
"""Test transformation before PartialEval"""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
back_func = run_infer_type(back_func)
mod["main"] = back_func
- seq = transform.Sequential([
+ seq = tvm.transform.Sequential([
transform.LazyGradientInit(),
transform.PartialEvaluate(),
transform.DeadCodeElimination()
def test_zeros():
"""Simple test using "zeros" op"""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def test_ones():
"""Simple test using "ones" op"""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def test_zeros_like():
"""Simple test using "zeros_like" op"""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def test_ones_like():
"""Simple test using "ones_like" op"""
mod = tvm.IRModule()
- 
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
opt_tester = OptTester(mod)
pass_ctx = None
- @_transform.module_pass(opt_level=opt_level, name=pass_name)
+ @tvm.transform.module_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx):
return opt_tester.transform(expr, ctx)
def test_pass_registration():
mod_pass = transform
- assert isinstance(mod_pass, _transform.ModulePass)
+ assert isinstance(mod_pass, tvm.transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_pass_registration_no_decorator():
def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
- mod_pass = _transform.module_pass(direct_transform, opt_level=3)
- assert isinstance(mod_pass, _transform.ModulePass)
+ mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3)
+ assert isinstance(mod_pass, tvm.transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 3
def test_module_class_pass():
- @relay.transform.module_pass(opt_level=1)
+ @tvm.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
def test_pass_info():
- info = relay.transform.PassInfo(opt_level=1, name="xyz")
+ info = tvm.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
assert info.name == "xyz"
opt_tester = OptTester(mod)
pass_ctx = None
- @_transform.module_pass(opt_level=1)
+ @tvm.transform.module_pass(opt_level=1)
def mod_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
passes = [module_pass, function_pass]
opt_level = 2
pass_name = "sequential"
- sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+ sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
pass_info = sequential.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_no_pass():
passes = []
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential(mod)
mod_func = ret_mod[v_sub]
check_func(sub, mod_func)
def test_only_module_pass():
passes = [module_pass]
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with relay.build_config(required_pass=["mod_transform"]):
ret_mod = sequential(mod)
# Check the subtract function.
def test_only_function_pass():
# Check the subtract function.
passes = [function_pass]
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with relay.build_config(required_pass=["func_transform"]):
ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
# function pass.
mod = tvm.IRModule({v_sub: sub, v_log: log})
passes = [module_pass, function_pass]
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
required = ["mod_transform", "func_transform"]
with relay.build_config(required_pass=required):
ret_mod = sequential(mod)
z1 = relay.add(z, z)
return relay.Function([x], z1)
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
- relay.transform.PrintIR(),
+ tvm.transform.PrintIR(),
relay.transform.DeadCodeElimination()
])
out = capfd.readouterr().err
- assert "Dumping the module IR" in out
+ assert "PrintIR" in out
assert "multiply" in out
__TRACE_COUNTER__ = 0
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DeadCodeElimination()
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
if mod:
assert isinstance(expr, Function)
mod["main"] = expr
- seq = transform.Sequential(passes)
+ seq = tvm.transform.Sequential(passes)
mod = seq(mod)
return mod["main"]
return run_opt_pass(expr, passes)
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
- opt_pass = transform.Sequential([
+ opt_pass = tvm.transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
- opt_pass = transform.Sequential([
+ opt_pass = tvm.transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
# This is required for constant folding
mod["main"] = bind_params_by_name(mod["main"], params)
- remove_bn_pass = transform.Sequential([
+ remove_bn_pass = tvm.transform.Sequential([
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
])
- composite_partition = transform.Sequential([
+ composite_partition = tvm.transform.Sequential([
remove_bn_pass,
transform.MergeComposite(pattern_table),
transform.AnnotateTarget("dnnl"),
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
x = run_infer_type(x)
y = un_cps(x)
y = run_infer_type(y)
- x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+ x = run_opt_pass(x, tvm.transform.Sequential(
+ [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
assert Feature.fRefCreate not in detect_feature(x)
unit = relay.Function([], relay.const(0., dtype='float32'))
f_ref = relay.Var("f_ref")
The optimizations of a Relay program can be applied at various granularities,
namely function-level and module-level, using :py:class:`tvm.relay.transform.FunctionPass`
and :py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes
+respectively. Alternatively, users can rely on :py:class:`tvm.transform.Sequential` to apply a sequence of passes
on a Relay program where the dependencies between passes can be resolved by the
pass infra. For more details about each type of pass, please refer to
the :ref:`relay-pass-infra`.
# fusion, as this pass generates let bindings for each expression to
# canonicalize a Relay program.
#
-# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to alleviate developers from handling
+# Hence, Relay provides :py:class:`tvm.transform.Sequential` to spare developers from handling
# these issues explicitly by specifying the required passes of each pass and
# packing them as a whole to execute. For example, the same passes can now be
-# applied using the sequential style as the following. :py:class:`tvm.relay.transform.Sequential` is
+# applied using the sequential style as follows. :py:class:`tvm.transform.Sequential` is
# similar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
# and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
# For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
# `Modules` that will be added to build a network. It focuses on the network
-# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing
+# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimization
# passes.
-# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential`
+# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
f = example()
mod = tvm.IRModule.from_expr(f)
# Glob the interested passes.
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
relay.transform.FuseOps(fuse_opt_level=2)])
mod1 = seq(mod)
# identical addition operations. This is because `EliminateCommonSubexpr`
# was not actually performed. The reason is that only the passes that have
# an optimization level less than or equal to 2 will be executed by default under
-# :py:class:`tvm.relay.transform.Sequential`. The pass infra,
+# :py:class:`tvm.transform.Sequential`. The pass infra,
# however, provides a configuration interface
# for users to customize the optimization level that they want to execute.
mod4 = seq(mod)
print(mod4)
-seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()])
+seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
with relay.build_config(opt_level=3):
with tvm.target.create("llvm"):
mod5 = seq1(mod)
f = example()
mod = tvm.IRModule.from_expr(f)
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
- relay.transform.PrintIR(False),
- relay.transform.EliminateCommonSubexpr(),
- relay.transform.FuseOps(),
- relay.transform.PrintIR(False)])
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
+ tvm.transform.PrintIR(),
+ relay.transform.EliminateCommonSubexpr(),
+ relay.transform.FuseOps(),
+ tvm.transform.PrintIR()])
with relay.build_config(opt_level=3):
mod = seq(mod)
def run_opt_pass(expr, opt_pass):
"""Exectue a relay pass."""
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]