[Relay] Add Python type functor and tests (#4209)
authorLogan Weber <36520469+weberlo@users.noreply.github.com>
Wed, 30 Oct 2019 04:51:20 +0000 (21:51 -0700)
committerJared Roesch <roeschinc@gmail.com>
Wed, 30 Oct 2019 04:51:20 +0000 (21:51 -0700)
* Add Python type functor and tests

* Lint roller

python/tvm/relay/__init__.py
python/tvm/relay/type_functor.py [new file with mode: 0644]
tests/python/relay/test_type_functor.py [new file with mode: 0644]

index f05098b..bd3f5bd 100644 (file)
@@ -23,6 +23,7 @@ from ..api import register_func
 from . import base
 from . import ty
 from . import expr
+from . import type_functor
 from . import expr_functor
 from . import module
 from . import adt
@@ -118,6 +119,11 @@ module_pass = transform.module_pass
 function_pass = transform.function_pass
 alpha_equal = analysis.alpha_equal
 
+# TypeFunctor
+TypeFunctor = type_functor.TypeFunctor
+TypeVisitor = type_functor.TypeVisitor
+TypeMutator = type_functor.TypeMutator
+
 # ExprFunctor
 ExprFunctor = expr_functor.ExprFunctor
 ExprVisitor = expr_functor.ExprVisitor
diff --git a/python/tvm/relay/type_functor.py b/python/tvm/relay/type_functor.py
new file mode 100644 (file)
index 0000000..1331058
--- /dev/null
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The type functor of Relay."""
+from .ty import (TypeVar, IncompleteType, TensorType, FuncType,
+                 TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
+from .adt import TypeData
+
+class TypeFunctor:
+    """
+    An abstract visitor defined over Type.
+
+    Defines the default dispatch over types.
+    """
+    def __init__(self):
+        # TODO(weberlo): make type vars hashable, so we can memoize
+        pass
+
+    # pylint: disable=no-else-return
+    def visit(self, typ):
+        """Apply the visitor to a type."""
+        if isinstance(typ, TypeVar):
+            return self.visit_type_var(typ)
+        elif isinstance(typ, IncompleteType):
+            return self.visit_incomplete_type(typ)
+        elif isinstance(typ, TensorType):
+            return self.visit_tensor_type(typ)
+        elif isinstance(typ, FuncType):
+            return self.visit_func_type(typ)
+        elif isinstance(typ, TupleType):
+            return self.visit_tuple_type(typ)
+        elif isinstance(typ, TypeRelation):
+            return self.visit_type_relation(typ)
+        elif isinstance(typ, RefType):
+            return self.visit_ref_type(typ)
+        elif isinstance(typ, GlobalTypeVar):
+            return self.visit_global_type_var(typ)
+        elif isinstance(typ, TypeCall):
+            return self.visit_type_call(typ)
+        elif isinstance(typ, TypeData):
+            return self.visit_type_data(typ)
+        else:
+            raise Exception('unhandled case: {0}'.format(type(typ)))
+
+    def visit_type_var(self, _):
+        raise NotImplementedError()
+
+    def visit_incomplete_type(self, _):
+        raise NotImplementedError()
+
+    def visit_tensor_type(self, _):
+        raise NotImplementedError()
+
+    def visit_func_type(self, _):
+        raise NotImplementedError()
+
+    def visit_tuple_type(self, _):
+        raise NotImplementedError()
+
+    def visit_type_relation(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_type(self, _):
+        raise NotImplementedError()
+
+    def visit_global_type_var(self, _):
+        raise NotImplementedError()
+
+    def visit_type_call(self, _):
+        raise NotImplementedError()
+
+    def visit_type_data(self, _):
+        raise NotImplementedError()
+
+
+class TypeVisitor(TypeFunctor):
+    """
+    A visitor over Type.
+
+    The default behavior recursively traverses the AST.
+    """
+    def visit_type_var(self, tv):
+        pass
+
+    def visit_incomplete_type(self, it):
+        pass
+
+    def visit_tensor_type(self, tt):
+        pass
+
+    def visit_func_type(self, ft):
+        for arg_type in ft.arg_types:
+            self.visit(arg_type)
+        self.visit(ft.ret_type)
+        for type_param in getattr(ft, 'type_params', []):
+            self.visit(type_param)
+        for type_constraint in getattr(ft, 'type_constraints', []):
+            self.visit(type_constraint)
+
+    def visit_tuple_type(self, tt):
+        for field in tt.fields:
+            self.visit(field)
+
+    def visit_type_relation(self, tr):
+        for arg in tr.args:
+            self.visit(arg)
+
+    def visit_ref_type(self, rt):
+        self.visit(rt.value)
+
+    def visit_global_type_var(self, gtv):
+        pass
+
+    def visit_type_call(self, tc):
+        self.visit(tc.func)
+        for arg in tc.args:
+            self.visit(arg)
+
+    def visit_type_data(self, td):
+        self.visit(td.header)
+        for type_var in td.type_vars:
+            self.visit(type_var)
+
+
+class TypeMutator(TypeFunctor):
+    """
+    A functional visitor over Type.
+
+    The default behavior recursively traverses the AST
+    and reconstructs the AST.
+    """
+    def visit_type_var(self, tv):
+        return TypeVar(tv.var.name, tv.kind)
+
+    def visit_incomplete_type(self, it):
+        return IncompleteType(it.kind)
+
+    def visit_tensor_type(self, tt):
+        return TensorType(tt.shape, tt.dtype)
+
+    def visit_func_type(self, ft):
+        new_arg_types = [self.visit(arg_type) for arg_type in ft.arg_types]
+        new_ret_type = self.visit(ft.ret_type)
+        new_type_params = [
+            self.visit(type_param)
+            for type_param in getattr(ft, 'type_params', [])]
+        new_type_constraints = [
+            self.visit(type_constraint)
+            for type_constraint in getattr(ft, 'type_constraints', [])]
+        return FuncType(
+            new_arg_types,
+            new_ret_type,
+            new_type_params,
+            new_type_constraints)
+
+    def visit_tuple_type(self, tt):
+        return TupleType([self.visit(field) for field in tt.fields])
+
+    def visit_type_relation(self, tr):
+        return TypeRelation(
+            tr.func,
+            [self.visit(arg) for arg in tr.args],
+            tr.num_inputs,
+            tr.attrs)
+
+    def visit_ref_type(self, rt):
+        return RefType(self.visit(rt.value))
+
+    def visit_global_type_var(self, gtv):
+        return GlobalTypeVar(gtv.var.name, gtv.kind)
+
+    def visit_type_call(self, tc):
+        return TypeCall(
+            self.visit(tc.func),
+            [self.visit(arg) for arg in tc.args])
+
+    def visit_type_data(self, td):
+        return TypeData(
+            self.visit(td.header),
+            [self.visit(type_var) for type_var in td.type_vars],
+            td.constructors)
diff --git a/tests/python/relay/test_type_functor.py b/tests/python/relay/test_type_functor.py
new file mode 100644 (file)
index 0000000..d09a893
--- /dev/null
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
+from tvm.relay.analysis import assert_graph_equal
+from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
+                 TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
+from tvm.relay.adt import TypeData
+
+def check_visit(typ):
+    try:
+        ef = TypeFunctor()
+        ef.visit(typ)
+        assert False
+    except NotImplementedError:
+        pass
+
+    ev = TypeVisitor()
+    ev.visit(typ)
+
+    assert_graph_equal(TypeMutator().visit(typ), typ)
+
+
+def test_type_var():
+    tv = TypeVar('a')
+    check_visit(tv)
+
+
+def test_incomplete_type():
+    it = IncompleteType()
+    check_visit(it)
+
+
+def test_tensor_type():
+    tt = TensorType([])
+    check_visit(tt)
+
+
+def test_func_type():
+    tv = TypeVar('tv')
+    tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
+    ft = FuncType([tt], tt, type_params=[tv])
+    check_visit(ft)
+
+
+def test_tuple_type():
+    tt = TupleType([TupleType([])])
+    check_visit(tt)
+
+
+def test_type_relation():
+    func = tvm.get_env_func('tvm.relay.type_relation.Broadcast')
+    attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4))
+    tp = TypeVar('tp')
+    tf = FuncType([], TupleType([]), [], [])
+    tt = TensorType([1, 2, 3], 'float32')
+    tr = TypeRelation(func, [tp, tf, tt], 2, attrs)
+
+    check_visit(tr)
+
+
+def test_ref_type():
+    rt = RefType(TupleType([]))
+    check_visit(rt)
+
+
+def test_global_type_var():
+    gtv = GlobalTypeVar('gtv')
+    check_visit(gtv)
+
+
+def test_type_call():
+    tc = TypeCall(GlobalTypeVar('tf'), [TupleType([])])
+    check_visit(tc)
+
+
+def test_type_data():
+    td = TypeData(GlobalTypeVar('td'), [TypeVar('tv')], [])
+    check_visit(td)
+
+
+if __name__ == "__main__":
+    test_type_var()
+    test_incomplete_type()
+    test_tensor_type()
+    test_func_type()
+    test_tuple_type()
+    test_type_relation()
+    test_ref_type()
+    test_global_type_var()
+    test_type_call()
+    test_type_data()