[RELAY][PASS] Enable decorating python class as Pass (#3364)
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 17 Jun 2019 16:54:48 +0000 (09:54 -0700)
committerGitHub <noreply@github.com>
Mon, 17 Jun 2019 16:54:48 +0000 (09:54 -0700)
3rdparty/dmlc-core
python/tvm/relay/__init__.py
python/tvm/relay/transform.py
tests/python/relay/test_pass_manager.py

index 3943914..fbe142b 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
+Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
index 1c8f5d6..5536e50 100644 (file)
@@ -101,6 +101,7 @@ const = expr.const
 bind = expr.bind
 module_pass = transform.module_pass
 function_pass = transform.function_pass
+alpha_equal = ir_pass.alpha_equal
 
 # ExprFunctor
 ExprFunctor = expr_functor.ExprFunctor
index b76c236..d7a7c26 100644 (file)
@@ -19,6 +19,8 @@
 Relay pass transformation infrastructure.
 """
 import types
+import inspect
+import functools
 
 from tvm._ffi.runtime_ctypes import TVMContext
 from . import _transform
@@ -444,16 +446,47 @@ def PartialEvaluate():
     return _transform.PartialEvaluate()
 
 
+def _wrap_class_module_pass(pass_cls, pass_info):
+    """Wrap a python class as function pass"""
+    class PyModulePass(ModulePass):
+        """Internal wrapper class to create a class instance."""
+        def __init__(self, *args, **kwargs):
+            # initialize handle in cass pass_cls creation failed.fg
+            self.handle = None
+            inst = pass_cls(*args, **kwargs)
+            # it is important not to capture self to
+            # avoid a cyclic dependency
+            def _pass_func(mod, ctx):
+                return inst.transform_module(mod, ctx)
+            self.__init_handle_by_constructor__(
+                _transform.MakeModulePass, _pass_func, pass_info)
+            self._inst = inst
+
+        def __getattr__(self, name):
+            # fall back to instance attribute if there is not any
+            return self._inst.__getattribute__(name)
+
+    functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
+    PyModulePass.__name__ = pass_cls.__name__
+    PyModulePass.__doc__ = pass_cls.__doc__
+    PyModulePass.__module__ = pass_cls.__module__
+    return PyModulePass
+
+
 def module_pass(pass_func=None, opt_level=None, name=None, required=None):
-    """Create a module pass. This function returns a callback when pass_func
-    is provided. Otherwise, it returns the created module level pass using the
-    given optimization function.
+    """Decorate a module pass.
+
+    This function returns a callback when pass_func is provided.
+    Otherwise, it serves a decorator function.
+
+    pass_func can also be a class type with a method transform_module.
+    This function will create a decorated ModulePass using transform_module
+    as the pass function.
 
     Parameters
     ----------
-    pass_func : Optional[Callable[(Module/Function, PassContext) ->
-                Module/Function]]
-        The implemented optimization pass.
+    pass_func : Optional[Callable[(Module, PassContext) ->Module]]
+        The transformation function or class.
 
     opt_level : int
         The optimization level of this module pass.
@@ -468,14 +501,39 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
     Returns
     -------
     create_module_pass : Union[Callable, ModulePass]
-        The callable that will create a module pass is returned when
-        pass_func is not passed in. Otherwise, a ModulePass object will be
-        directly created.
+        A decorator will be returned if pass_func is not provided,
+        otherwise return the decorated result.
+        The returned decorator has two behaviors depending on the input:
+        A new ModulePass will be returned when we decorate a pass function.
+        A new ModulePass class will be returned when we decorate a class type.
 
     Examples
     --------
-    The following code creates a module level pass and adds an abs function to
-    the module.
+    The following code block decorates a module pass class.
+
+    .. code-block:: python
+
+        @relay.transform.module_pass
+        class CustomPipeline:
+            def __init__(self, enable_fold):
+                self.enable_fold = enable_fold
+                self.cse = relay.transform.EliminateCommonSubexpr()
+                self.const_fold = relay.transform.FoldConstant()
+
+            def transform_module(self, mod, ctx):
+                mod = self.cse(mod, ctx)
+                if self.enable_fold:
+                    mod = self.const_fold(mod, ctx)
+                return mod
+
+        # create an instance of customized pipeline
+        pipeline = CustomPipeline(enable_fold=False)
+        assert isinstance(pipeline, transform.ModulePass)
+        # run the pipeline.
+        output_module = pipeline(input_module)
+
+    The following code creates a module pass by decorating
+    a user defined transform function.
 
     .. code-block:: python
 
@@ -497,7 +555,6 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
         updated_mod = module_pass(m)
         # Now a function abs should be added to the module m.
     """
-
     if opt_level is None:
         raise ValueError("Please provide opt_level for the module pass.")
 
@@ -506,30 +563,59 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
         raise TypeError("Required is expected to be the type of " +
                         "list/tuple.")
 
-    def create_module_pass(pass_func):
+    def create_module_pass(pass_arg):
         """Internal function that creates a module pass"""
-        if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
-            raise TypeError("pass_func must be a callable for Module pass")
-
-        fname = name if name else pass_func.__name__
+        fname = name if name else pass_arg.__name__
         info = PassInfo(opt_level, fname, required)
-        return _transform.MakeModulePass(pass_func, info)
+        if inspect.isclass(pass_arg):
+            return _wrap_class_module_pass(pass_arg, info)
+        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+            raise TypeError("pass_func must be a callable for Module pass")
+        return _transform.MakeModulePass(pass_arg, info)
 
     if pass_func:
         return create_module_pass(pass_func)
     return create_module_pass
 
 
+def _wrap_class_function_pass(pass_cls, pass_info):
+    """Wrap a python class as function pass"""
+    class PyFunctionPass(FunctionPass):
+        """Internal wrapper class to create a class instance."""
+        def __init__(self, *args, **kwargs):
+            # initialize handle in cass pass_cls creation failed.fg
+            self.handle = None
+            inst = pass_cls(*args, **kwargs)
+            # it is important not to capture self to
+            # avoid a cyclic dependency
+            def _pass_func(func, mod, ctx):
+                return inst.transform_function(func, mod, ctx)
+            self.__init_handle_by_constructor__(
+                _transform.MakeFunctionPass, _pass_func, pass_info)
+            self._inst = inst
+
+        def __getattr__(self, name):
+            # fall back to instance attribute if there is not any
+            return self._inst.__getattribute__(name)
+
+    functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
+    PyFunctionPass.__name__ = pass_cls.__name__
+    PyFunctionPass.__doc__ = pass_cls.__doc__
+    PyFunctionPass.__module__ = pass_cls.__module__
+    return PyFunctionPass
+
+
 def function_pass(pass_func=None, opt_level=None, name=None, required=None):
-    """Create a function pass. This function returns a callback when pass_func
+    """Decorate a function pass.
+
+    This function returns a callback when pass_func
     is provided. Otherwise, it returns the created function pass using the
     given optimization function.
 
     Parameters
     ----------
-    pass_func : Optional[Callable[(Module/Function, PassContext) ->
-                Module/Function]]
-        The implemented optimization pass.
+    pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
+        The transformation function or class.
 
     opt_level : int
         The optimization level of this module pass.
@@ -544,20 +630,48 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
     Returns
     -------
     create_function_pass : Union[Callable, FunctionPass]
-        The callable that will create a function pass is returned when
-        pass_func is not passed in. Otherwise, a FunctionPass object will be
-        created.
+
+        A decorator will be returned if pass_func is not provided,
+        otherwise return the decorated result.
+        The returned decorator has two behaviors depending on the input:
+        A new FunctionPass will be returned when we decorate a pass function.
+        A new FunctionPass class will be returned when we decorate a class type.
 
     Examples
     --------
-    The following code creates a function level pass that performs constant
-    folding.
+    The following code block decorates a function pass class.
+
+    .. code-block:: python
+
+        @relay.transform.function_pass(opt_level=1)
+        class TestReplaceFunc:
+            def __init__(self, new_func):
+                self.new_func = new_func
+
+            def transform_function(self, func, mod, ctx):
+                # just for demo purposes
+                # transform func to new_func
+                return self.new_func
+
+        x = relay.var("x", shape=(10, 20))
+        f1 = relay.Function([x], x)
+        f2 = relay.Function([x], relay.log(x))
+        # fpass is now a special pass that replaces every
+        # function to f1
+        fpass = TestReplaceFunc(f1)
+        # now every function in input_mod is replaced by f1
+        res_mod = fpass(input_mod)
+
+
+    The following code creates a function pass by decorating
+    a user defined transform function.
 
     .. code-block:: python
 
         @relay.transform.function_pass(opt_level=2)
-        def transform(func, ctx):
-            return ir_pass.fold_constant(func)
+        def transform(func, mod, ctx):
+            # my transformations here.
+            return func
 
         function_pass = transform
         assert isinstance(function_pass, transform.FunctionPass)
@@ -577,14 +691,15 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
         raise TypeError("Required is expected to be the type of " +
                         "list/tuple.")
 
-    def create_function_pass(pass_func):
+    def create_function_pass(pass_arg):
         """Internal function that creates a function pass"""
-        if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
-            raise TypeError("pass_func must be a callable for Module pass")
-
-        fname = name if name else pass_func.__name__
+        fname = name if name else pass_arg.__name__
         info = PassInfo(opt_level, fname, required)
-        return _transform.MakeFunctionPass(pass_func, info)
+        if inspect.isclass(pass_arg):
+            return _wrap_class_function_pass(pass_arg, info)
+        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+            raise TypeError("pass_func must be a callable for Module pass")
+        return _transform.MakeFunctionPass(pass_arg, info)
 
     if pass_func:
         return create_function_pass(pass_func)
index 7505aa9..a8f50bd 100644 (file)
@@ -189,6 +189,29 @@ def test_module_pass():
     test_pass_run()
 
 
+def test_function_class_pass():
+    @relay.transform.function_pass(opt_level=1)
+    class TestReplaceFunc:
+        """Simple test function to replace one argument to another."""
+        def __init__(self, new_func):
+            self.new_func = new_func
+
+        def transform_function(self, func, mod, ctx):
+            return self.new_func
+
+    x = relay.var("x", shape=(10, 20))
+    f1 = relay.Function([x], x)
+    f2 = relay.Function([x], relay.log(x))
+    fpass = TestReplaceFunc(f1)
+    assert fpass.info.opt_level == 1
+    assert fpass.info.name == "TestReplaceFunc"
+    mod = relay.Module.from_expr(f2)
+    mod = fpass(mod)
+    # wrap in expr
+    mod2 = relay.Module.from_expr(f1)
+    assert relay.alpha_equal(mod["main"], mod2["main"])
+
+
 def test_function_pass():
     shape = (10, )
     dtype = 'float32'
@@ -259,6 +282,30 @@ def test_function_pass():
     test_pass_run()
 
 
+def test_module_class_pass():
+    @relay.transform.module_pass(opt_level=1)
+    class TestPipeline:
+        """Simple test function to replace one argument to another."""
+        def __init__(self, new_mod, replace):
+            self.new_mod = new_mod
+            self.replace = replace
+
+        def transform_module(self, mod, ctx):
+            if self.replace:
+                return self.new_mod
+            return mod
+
+    x = relay.var("x", shape=(10, 20))
+    m1 = relay.Module.from_expr(relay.Function([x], x))
+    m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
+    fpass = TestPipeline(m2, replace=True)
+    assert fpass.info.name == "TestPipeline"
+    mod3 = fpass(m1)
+    assert mod3.same_as(m2)
+    mod4 = TestPipeline(m2, replace=False)(m1)
+    assert mod4.same_as(m1)
+
+
 def test_pass_info():
     info = relay.transform.PassInfo(opt_level=1, name="xyz")
     assert info.opt_level == 1
@@ -451,6 +498,8 @@ def test_sequential_with_scoping():
 
 
 if __name__ == "__main__":
+    test_function_class_pass()
+    test_module_class_pass()
     test_module_pass()
     test_function_pass()
     test_sequential_pass()