Relay pass transformation infrastructure.
"""
import types
+import inspect
+import functools
from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform
return _transform.PartialEvaluate()
+def _wrap_class_module_pass(pass_cls, pass_info):
+ """Wrap a python class as function pass"""
+ class PyModulePass(ModulePass):
+ """Internal wrapper class to create a class instance."""
+ def __init__(self, *args, **kwargs):
+ # initialize handle in cass pass_cls creation failed.fg
+ self.handle = None
+ inst = pass_cls(*args, **kwargs)
+ # it is important not to capture self to
+ # avoid a cyclic dependency
+ def _pass_func(mod, ctx):
+ return inst.transform_module(mod, ctx)
+ self.__init_handle_by_constructor__(
+ _transform.MakeModulePass, _pass_func, pass_info)
+ self._inst = inst
+
+ def __getattr__(self, name):
+ # fall back to instance attribute if there is not any
+ return self._inst.__getattribute__(name)
+
+ functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
+ PyModulePass.__name__ = pass_cls.__name__
+ PyModulePass.__doc__ = pass_cls.__doc__
+ PyModulePass.__module__ = pass_cls.__module__
+ return PyModulePass
+
+
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
- """Create a module pass. This function returns a callback when pass_func
- is provided. Otherwise, it returns the created module level pass using the
- given optimization function.
+ """Decorate a module pass.
+
+ This function returns a callback when pass_func is provided.
+ Otherwise, it serves a decorator function.
+
+ pass_func can also be a class type with a method transform_module.
+ This function will create a decorated ModulePass using transform_module
+ as the pass function.
Parameters
----------
- pass_func : Optional[Callable[(Module/Function, PassContext) ->
- Module/Function]]
- The implemented optimization pass.
+ pass_func : Optional[Callable[(Module, PassContext) ->Module]]
+ The transformation function or class.
opt_level : int
The optimization level of this module pass.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
- The callable that will create a module pass is returned when
- pass_func is not passed in. Otherwise, a ModulePass object will be
- directly created.
+ A decorator will be returned if pass_func is not provided,
+ otherwise return the decorated result.
+ The returned decorator has two behaviors depending on the input:
+ A new ModulePass will be returned when we decorate a pass function.
+ A new ModulePass class will be returned when we decorate a class type.
Examples
--------
- The following code creates a module level pass and adds an abs function to
- the module.
+ The following code block decorates a module pass class.
+
+ .. code-block:: python
+
+ @relay.transform.module_pass
+ class CustomPipeline:
+ def __init__(self, enable_fold):
+ self.enable_fold = enable_fold
+ self.cse = relay.transform.EliminateCommonSubexpr()
+ self.const_fold = relay.transform.FoldConstant()
+
+ def transform_module(self, mod, ctx):
+ mod = self.cse(mod, ctx)
+ if self.enable_fold:
+ mod = self.const_fold(mod, ctx)
+ return mod
+
+ # create an instance of customized pipeline
+ pipeline = CustomPipeline(enable_fold=False)
+ assert isinstance(pipeline, transform.ModulePass)
+ # run the pipeline.
+ output_module = pipeline(input_module)
+
+ The following code creates a module pass by decorating
+ a user defined transform function.
.. code-block:: python
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
-
if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
- def create_module_pass(pass_func):
+ def create_module_pass(pass_arg):
"""Internal function that creates a module pass"""
- if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
- raise TypeError("pass_func must be a callable for Module pass")
-
- fname = name if name else pass_func.__name__
+ fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
- return _transform.MakeModulePass(pass_func, info)
+ if inspect.isclass(pass_arg):
+ return _wrap_class_module_pass(pass_arg, info)
+ if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ raise TypeError("pass_func must be a callable for Module pass")
+ return _transform.MakeModulePass(pass_arg, info)
if pass_func:
return create_module_pass(pass_func)
return create_module_pass
+def _wrap_class_function_pass(pass_cls, pass_info):
+ """Wrap a python class as function pass"""
+ class PyFunctionPass(FunctionPass):
+ """Internal wrapper class to create a class instance."""
+ def __init__(self, *args, **kwargs):
+ # initialize handle in cass pass_cls creation failed.fg
+ self.handle = None
+ inst = pass_cls(*args, **kwargs)
+ # it is important not to capture self to
+ # avoid a cyclic dependency
+ def _pass_func(func, mod, ctx):
+ return inst.transform_function(func, mod, ctx)
+ self.__init_handle_by_constructor__(
+ _transform.MakeFunctionPass, _pass_func, pass_info)
+ self._inst = inst
+
+ def __getattr__(self, name):
+ # fall back to instance attribute if there is not any
+ return self._inst.__getattribute__(name)
+
+ functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
+ PyFunctionPass.__name__ = pass_cls.__name__
+ PyFunctionPass.__doc__ = pass_cls.__doc__
+ PyFunctionPass.__module__ = pass_cls.__module__
+ return PyFunctionPass
+
+
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
- """Create a function pass. This function returns a callback when pass_func
+ """Decorate a function pass.
+
+ This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
- pass_func : Optional[Callable[(Module/Function, PassContext) ->
- Module/Function]]
- The implemented optimization pass.
+ pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
+ The transformation function or class.
opt_level : int
The optimization level of this module pass.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
- The callable that will create a function pass is returned when
- pass_func is not passed in. Otherwise, a FunctionPass object will be
- created.
+
+ A decorator will be returned if pass_func is not provided,
+ otherwise return the decorated result.
+ The returned decorator has two behaviors depending on the input:
+ A new FunctionPass will be returned when we decorate a pass function.
+ A new FunctionPass class will be returned when we decorate a class type.
Examples
--------
- The following code creates a function level pass that performs constant
- folding.
+ The following code block decorates a function pass class.
+
+ .. code-block:: python
+
+ @relay.transform.function_pass(opt_level=1)
+ class TestReplaceFunc:
+ def __init__(self, new_func):
+ self.new_func = new_func
+
+ def transform_function(self, func, mod, ctx):
+ # just for demo purposes
+ # transform func to new_func
+ return self.new_func
+
+ x = relay.var("x", shape=(10, 20))
+ f1 = relay.Function([x], x)
+ f2 = relay.Function([x], relay.log(x))
+ # fpass is now a special pass that replaces every
+ # function to f1
+ fpass = TestReplaceFunc(f1)
+ # now every function in input_mod is replaced by f1
+ res_mod = fpass(input_mod)
+
+
+ The following code creates a function pass by decorating
+ a user defined transform function.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
- def transform(func, ctx):
- return ir_pass.fold_constant(func)
+ def transform(func, mod, ctx):
+ # my transformations here.
+ return func
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
- def create_function_pass(pass_func):
+ def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
- if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
- raise TypeError("pass_func must be a callable for Module pass")
-
- fname = name if name else pass_func.__name__
+ fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
- return _transform.MakeFunctionPass(pass_func, info)
+ if inspect.isclass(pass_arg):
+ return _wrap_class_function_pass(pass_arg, info)
+ if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ raise TypeError("pass_func must be a callable for Module pass")
+ return _transform.MakeFunctionPass(pass_arg, info)
if pass_func:
return create_function_pass(pass_func)
test_pass_run()
+def test_function_class_pass():
+ @relay.transform.function_pass(opt_level=1)
+ class TestReplaceFunc:
+ """Simple test function to replace one argument to another."""
+ def __init__(self, new_func):
+ self.new_func = new_func
+
+ def transform_function(self, func, mod, ctx):
+ return self.new_func
+
+ x = relay.var("x", shape=(10, 20))
+ f1 = relay.Function([x], x)
+ f2 = relay.Function([x], relay.log(x))
+ fpass = TestReplaceFunc(f1)
+ assert fpass.info.opt_level == 1
+ assert fpass.info.name == "TestReplaceFunc"
+ mod = relay.Module.from_expr(f2)
+ mod = fpass(mod)
+ # wrap in expr
+ mod2 = relay.Module.from_expr(f1)
+ assert relay.alpha_equal(mod["main"], mod2["main"])
+
+
def test_function_pass():
shape = (10, )
dtype = 'float32'
test_pass_run()
+def test_module_class_pass():
+ @relay.transform.module_pass(opt_level=1)
+ class TestPipeline:
+ """Simple test function to replace one argument to another."""
+ def __init__(self, new_mod, replace):
+ self.new_mod = new_mod
+ self.replace = replace
+
+ def transform_module(self, mod, ctx):
+ if self.replace:
+ return self.new_mod
+ return mod
+
+ x = relay.var("x", shape=(10, 20))
+ m1 = relay.Module.from_expr(relay.Function([x], x))
+ m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
+ fpass = TestPipeline(m2, replace=True)
+ assert fpass.info.name == "TestPipeline"
+ mod3 = fpass(m1)
+ assert mod3.same_as(m2)
+ mod4 = TestPipeline(m2, replace=False)(m1)
+ assert mod4.same_as(m1)
+
+
def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
if __name__ == "__main__":
+ test_function_class_pass()
+ test_module_class_pass()
test_module_pass()
test_function_pass()
test_sequential_pass()