From 6027412bcb443572088d71f1060370317eb6e671 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 15 Mar 2020 15:39:47 -0700 Subject: [PATCH] [IR] Update the type_keys to reflect the code-org (#5074) --- include/tvm/ir/expr.h | 2 +- include/tvm/ir/module.h | 2 +- include/tvm/ir/span.h | 4 +- include/tvm/ir/transform.h | 6 +- include/tvm/ir/type.h | 18 +++--- include/tvm/ir/type_relation.h | 6 +- python/tvm/ir/__init__.py | 2 +- python/tvm/ir/base.py | 4 +- python/tvm/ir/expr.py | 2 +- python/tvm/ir/json_compact.py | 24 +++++++ python/tvm/ir/module.py | 2 +- python/tvm/ir/transform.py | 10 +-- python/tvm/ir/type.py | 25 ++++++-- python/tvm/ir/type_relation.py | 3 +- src/ir/transform.cc | 4 +- tests/python/relay/test_ir_nodes.py | 78 ----------------------- tests/python/relay/test_json_compact.py | 73 ++++++++++++++++++++- tests/python/unittest/test_ir_type.py | 108 ++++++++++++++++++++++++++++++++ 18 files changed, 255 insertions(+), 118 deletions(-) create mode 100644 tests/python/unittest/test_ir_type.py diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index e37374a..c8b1a3f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode { v->Visit("_checked_type_", &checked_type_); } - static constexpr const char* _type_key = "relay.GlobalVar"; + static constexpr const char* _type_key = "GlobalVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); }; diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 23d1f6e..1ee7c32 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -226,7 +226,7 @@ class IRModuleNode : public Object { */ TVM_DLL std::unordered_set Imports() const; - static constexpr const char* _type_key = "relay.Module"; + static constexpr const char* _type_key = "IRModule"; TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); private: diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 8cbfff7..4720dfe 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -44,7 +44,7 @@ class SourceNameNode : public Object { // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } - static constexpr const char* _type_key = "relay.SourceName"; + static constexpr const char* _type_key = "SourceName"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); }; @@ -89,7 +89,7 @@ class SpanNode : public Object { TVM_DLL static Span make(SourceName source, int lineno, int col_offset); - static constexpr const char* _type_key = "relay.Span"; + static constexpr const char* _type_key = "Span"; TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 2afcb17..1b6ea25 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -110,7 +110,7 @@ class PassContextNode : public Object { v->Visit("disabled_pass", &disabled_pass); } - static constexpr const char* _type_key = "relay.PassContext"; + static constexpr const char* _type_key = "transform.PassContext"; TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; @@ -206,7 +206,7 @@ class PassInfoNode : public Object { v->Visit("required", &required); } - static constexpr const char* _type_key = "relay.PassInfo"; + static constexpr const char* _type_key = "transform.PassInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); }; @@ -265,7 +265,7 @@ class PassNode : public Object { void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "relay.Pass"; + static constexpr const char* _type_key = "transform.Pass"; TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 7fd224b..a9475a1 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -78,7 +78,7 @@ class TypeNode : public Object { */ mutable Span span; - static constexpr const char* _type_key = "relay.Type"; + static constexpr const char* _type_key = "Type"; TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); }; @@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode { v->Visit("dtype", &dtype); } - static constexpr const char* _type_key = "relay.PrimType"; + static constexpr const char* _type_key = "PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; @@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode { v->Visit("span", &span); } - static constexpr const char* _type_key = "relay.TypeVar"; + static constexpr const char* _type_key = "TypeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); }; @@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode { v->Visit("kind", &kind); } - static constexpr const char* _type_key = "relay.GlobalTypeVar"; + static constexpr const char* _type_key = "GlobalTypeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); }; @@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode { v->Visit("span", &span); } - static constexpr const char* _type_key = "relay.TupleType"; + static constexpr const char* _type_key = "TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); }; @@ -289,7 +289,7 @@ inline Type VoidType() { */ class TypeConstraintNode : public TypeNode { public: - static constexpr const char* _type_key = "relay.TypeConstraint"; + static constexpr const char* _type_key = "TypeConstraint"; TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); }; @@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode { v->Visit("span", &span); } - static constexpr const char* _type_key = "relay.FuncType"; + static constexpr const char* _type_key = "FuncType"; TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; @@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode { v->Visit("span", &span); } - static constexpr const char* _type_key = "relay.IncompleteType"; + static constexpr const char* _type_key = "IncompleteType"; TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); }; @@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode { v->Visit("span", &span); } + // Keep the relay prefix in the type as this type is specific + // to the relay itself. static constexpr const char* _type_key = "relay.RefType"; TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode); }; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index ff36b96..f7bfb68 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode { v->Visit("span", &span); } - static constexpr const char* _type_key = "relay.TypeCall"; + static constexpr const char* _type_key = "TypeCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); }; @@ -119,7 +119,7 @@ class TypeReporterNode : public Object { // solver is not serializable. void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "relay.TypeReporter"; + static constexpr const char* _type_key = "TypeReporter"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); }; @@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode { v->Visit("span", &span); } - static constexpr const char* _type_key = "relay.TypeRelation"; + static constexpr const char* _type_key = "TypeRelation"; TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); }; diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 4160326..8418d63 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -17,7 +17,7 @@ # pylint: disable=unused-import """Common data structures across all IR variants.""" from .base import SourceName, Span, Node, EnvFunc, load_json, save_json -from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType +from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 944daa1..810d78f 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -56,7 +56,7 @@ class Node(Object): return _ffi_api.PrettyPrint(self) -@tvm._ffi.register_object("relay.SourceName") +@tvm._ffi.register_object("SourceName") class SourceName(Object): """A identifier for a source location. @@ -69,7 +69,7 @@ class SourceName(Object): self.__init_handle_by_constructor__(_ffi_api.SourceName, name) -@tvm._ffi.register_object("relay.Span") +@tvm._ffi.register_object("Span") class Span(Object): """Specifies a location in a source program. diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 4e6bf16..eedfff8 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -51,7 +51,7 @@ class RelayExpr(BaseExpr): return ret -@tvm._ffi.register_object("relay.GlobalVar") +@tvm._ffi.register_object("GlobalVar") class GlobalVar(RelayExpr): """A global variable in the IR. diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index d1cac95..10ecbaa 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -62,11 +62,35 @@ def create_updater_06_to_07(): # set vindex to null nodes[vindex]["type_key"] = "" del item["attrs"]["var"] + assert item["type_key"].startswith("relay.") + item["type_key"] = item["type_key"][len("relay."):] return item + def _rename(new_name): + def _convert(item, _): + item["type_key"] = new_name + return item + return _convert + node_map = { "relay.TypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var, + "relay.Type": _rename("Type"), + "relay.TupleType": _rename("TupleType"), + "relay.TypeConstraint": _rename("TypeConstraint"), + "relay.FuncType": _rename("FuncType"), + "relay.IncompleteType": _rename("IncompleteType"), + "relay.TypeRelation": _rename("TypeRelation"), + "relay.TypeCall": _rename("TypeCall"), + "relay.Module": _rename("IRModule"), + "relay.SourceName": _rename("SourceName"), + "relay.Span": _rename("Span"), + "relay.GlobalVar": _rename("GlobalVar"), + "relay.Pass": _rename("transform.Pass"), + "relay.PassInfo": _rename("transform.PassInfo"), + "relay.PassContext": _rename("transform.PassContext"), + "relay.ModulePass": _rename("transform.ModulePass"), + "relay.Sequantial": _rename("transform.Sequantial"), } return create_updater(node_map, "0.6", "0.7") diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 2d7481f..24f5211 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -24,7 +24,7 @@ from . import type as _ty from . import _ffi_api -@tvm._ffi.register_object("relay.Module") +@tvm._ffi.register_object("IRModule") class IRModule(Node): """IRModule that holds functions and type definitions. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index a35feb3..cdb9257 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd from . import _ffi_transform_api -@tvm._ffi.register_object("relay.PassInfo") +@tvm._ffi.register_object("transform.PassInfo") class PassInfo(Object): """The class contains the meta data required by a pass. It is the container of information needed by running an optimization or analysis. @@ -51,7 +51,7 @@ class PassInfo(Object): _ffi_transform_api.PassInfo, opt_level, name, required) -@tvm._ffi.register_object("relay.PassContext") +@tvm._ffi.register_object("transform.PassContext") class PassContext(Object): """The basis where a Relay optimization/analysis runs on. Each pass context contains a number of auxiliary information that is used @@ -112,7 +112,7 @@ class PassContext(Object): return _ffi_transform_api.GetCurrentPassContext() -@tvm._ffi.register_object("relay.Pass") +@tvm._ffi.register_object("transform.Pass") class Pass(Object): """The base class of all passes. All methods here are just simple wrappers that are implemented in the backend. They are defined for users to @@ -141,7 +141,7 @@ class Pass(Object): return _ffi_transform_api.RunPass(self, mod) -@tvm._ffi.register_object("relay.ModulePass") +@tvm._ffi.register_object("transform.ModulePass") class ModulePass(Pass): """A pass that works on tvm.IRModule. Users don't need to interact with this class directly. Instead, a module pass should be created through @@ -152,7 +152,7 @@ class ModulePass(Pass): """ -@tvm._ffi.register_object("relay.Sequential") +@tvm._ffi.register_object("transform.Sequential") class Sequential(Pass): """A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class. diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index ebe2aae..ebbb629 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -46,7 +46,20 @@ class TypeKind(IntEnum): TypeData = 6 -@tvm._ffi.register_object("relay.TypeVar") +class PrimType(Type): + """Primitive data type in the low level IR + + Parameters + ---------- + dtype : str + The runtime data type relates to the primtype. + """ + def __init__(self, dtype): + self.__init_handle_by_constructor__( + _ffi_api.PrimType, dtype) + + +@tvm._ffi.register_object("TypeVar") class TypeVar(Type): """Type parameter in functions. @@ -85,7 +98,7 @@ class TypeVar(Type): return TypeCall(self, args) -@tvm._ffi.register_object("relay.GlobalTypeVar") +@tvm._ffi.register_object("GlobalTypeVar") class GlobalTypeVar(Type): """A global type variable that is used for defining new types or type aliases. @@ -120,7 +133,7 @@ class GlobalTypeVar(Type): return TypeCall(self, args) -@tvm._ffi.register_object("relay.TupleType") +@tvm._ffi.register_object("TupleType") class TupleType(Type): """The type of tuple values. @@ -135,12 +148,12 @@ class TupleType(Type): _ffi_api.TupleType, fields) -@tvm._ffi.register_object("relay.TypeConstraint") +@tvm._ffi.register_object("TypeConstraint") class TypeConstraint(Type): """Abstract class representing a type constraint.""" -@tvm._ffi.register_object("relay.FuncType") +@tvm._ffi.register_object("FuncType") class FuncType(Type): """Function type. @@ -179,7 +192,7 @@ class FuncType(Type): _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints) -@tvm._ffi.register_object("relay.IncompleteType") +@tvm._ffi.register_object("IncompleteType") class IncompleteType(Type): """Incomplete type during type inference. diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py index 63c83d9..bacb2c2 100644 --- a/python/tvm/ir/type_relation.py +++ b/python/tvm/ir/type_relation.py @@ -21,6 +21,7 @@ from .type import Type, TypeConstraint from . import _ffi_api +@tvm._ffi.register_object("TypeCall") class TypeCall(Type): """Type function application. @@ -41,7 +42,7 @@ class TypeCall(Type): self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) -@tvm._ffi.register_object("relay.TypeRelation") +@tvm._ffi.register_object("TypeRelation") class TypeRelation(TypeConstraint): """User defined type relation, it is an input-output relation on types. diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 2b5010b..6878abc 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -132,7 +132,7 @@ class ModulePassNode : public PassNode { */ PassInfo Info() const override { return pass_info; } - static constexpr const char* _type_key = "relay.ModulePass"; + static constexpr const char* _type_key = "transform.ModulePass"; TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); }; @@ -206,7 +206,7 @@ class SequentialNode : public PassNode { */ IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; - static constexpr const char* _type_key = "relay.Sequential"; + static constexpr const char* _type_key = "transform.Sequential"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index d3d0808..968a3bb 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -30,13 +30,6 @@ def check_json_roundtrip(node): assert graph_equal(back, node) -def test_bad_constructor(): - try: - x = relay.ty.TensorType("xx", "xx") - except tvm.error.TVMError: - pass - - # Span def test_span(): span = relay.Span(None, 1, 1) @@ -55,71 +48,6 @@ def test_span(): assert back.lineno == span.lineno assert back.col_offset == span.col_offset -# Types - -def test_tensor_type(): - shape = tvm.runtime.convert([1, 2, 3]) - dtype = 'float32' - tt = relay.TensorType(shape, dtype) - assert tt.dtype == dtype - assert tt.shape == shape - assert tt.span == None - str(tt) - check_json_roundtrip(tt) - - -def test_type_param(): - tp = relay.TypeVar('name', relay.TypeKind.Type) - assert tp.kind == relay.TypeKind.Type - # assert tp.span # TODO allow us to set span - str(tp) - check_json_roundtrip(tp) - - -def test_func_type(): - type_params = tvm.runtime.convert([]) - type_constraints = tvm.runtime.convert([]) # TODO: fill me in - arg_types = tvm.runtime.convert([]) - ret_type = relay.TensorType((1, 2, 3), 'float32') - tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) - assert tf.type_params == type_params - assert tf.type_constraints == type_constraints - assert tf.arg_types == arg_types - assert tf.ret_type == ret_type - assert tf.span == None - # TODO make sure we can set span - str(tf) - check_json_roundtrip(tf) - - -def test_tuple_type(): - tp = relay.TypeVar('tp', relay.TypeKind.Type) - tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([])) - tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32') - fields = tvm.runtime.convert([tp, tf, tt]) - - tup_ty = relay.TupleType(fields) - assert tup_ty.fields == fields - str(tup_ty) - check_json_roundtrip(tup_ty) - - -def test_type_relation(): - tp = relay.TypeVar('tp', relay.TypeKind.Type) - tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([])) - tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32') - args = tvm.runtime.convert([tp, tf, tt]) - - num_inputs = 2 - func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") - attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) - - tr = relay.TypeRelation(func, args, num_inputs, attrs) - assert tr.args == args - assert tr.num_inputs == num_inputs - str(tr) - check_json_roundtrip(tr) - def test_constant(): arr = tvm.nd.array(10) @@ -280,13 +208,7 @@ def test_conv2d_attrs(): if __name__ == "__main__": - test_bad_constructor() test_span() - test_tensor_type() - test_type_param() - test_func_type() - test_tuple_type() - test_type_relation() test_constant() test_tuple() test_local_var() diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 6316791..d58ddd5 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -17,7 +17,6 @@ import tvm from tvm import te -from tvm import relay import json def test_type_var(): @@ -36,13 +35,81 @@ def test_type_var(): "b64ndarrays": [], } tvar = tvm.ir.load_json(json.dumps(data)) - assert isinstance(tvar, relay.TypeVar) + assert isinstance(tvar, tvm.ir.TypeVar) assert tvar.name_hint == "in0" nodes[1]["type_key"] = "relay.GlobalTypeVar" tvar = tvm.ir.load_json(json.dumps(data)) - assert isinstance(tvar, relay.GlobalTypeVar) + assert isinstance(tvar, tvm.ir.GlobalTypeVar) assert tvar.name_hint == "in0" +def test_incomplete_type(): + nodes = [ + {"type_key": ""}, + {"type_key": "relay.IncompleteType", + "attrs": {"kind": "0", "span": "0"}}] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.IncompleteType) + + +def test_func_tuple_type(): + nodes = [ + {"type_key": ""}, + {"type_key": "relay.FuncType", + "attrs": { + "arg_types": "2", + "ret_type": "3", + "span": "0", + "type_constraints": "6", + "type_params": "5" + } + }, + {"type_key": "Array"}, + {"type_key": "relay.TupleType", + "attrs": { "fields": "4", "span": "0" }}, + {"type_key": "Array"}, + {"type_key": "Array"}, + {"type_key": "Array"} + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.FuncType) + + +def test_global_var(): + nodes = [ + {"type_key": ""}, + {"type_key": "relay.GlobalVar", + "attrs": { + "_checked_type_": "0", + "name_hint": "x", + "span": "0" + } + } + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.GlobalVar) + + if __name__ == "__main__": test_type_var() + test_incomplete_type() + test_func_tuple_type() + test_global_var() diff --git a/tests/python/unittest/test_ir_type.py b/tests/python/unittest/test_ir_type.py new file mode 100644 index 0000000..f919f92 --- /dev/null +++ b/tests/python/unittest/test_ir_type.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test type nodes in the IR""" +import tvm + +def check_json_roundtrip(node): + from tvm.relay.analysis import graph_equal + json_str = tvm.ir.save_json(node) + back = tvm.ir.load_json(json_str) + assert graph_equal(back, node) + + +def test_prim_type(): + x = tvm.ir.PrimType("int32") + assert isinstance(x, tvm.ir.PrimType) + assert x.dtype == "int32" + + +def test_tensor_type_bad_constructor(): + try: + x = tvm.ir.TensorType("xx", "xx") + except tvm.error.TVMError: + pass + +def test_tensor_type(): + shape = tvm.runtime.convert([1, 2, 3]) + dtype = 'float32' + tt = tvm.ir.TensorType(shape, dtype) + assert tt.dtype == dtype + assert tt.shape == shape + assert tt.span == None + str(tt) + check_json_roundtrip(tt) + + +def test_type_param(): + tp = tvm.ir.TypeVar('name', tvm.ir.TypeKind.Type) + assert tp.kind == tvm.ir.TypeKind.Type + # assert tp.span # TODO allow us to set span + str(tp) + check_json_roundtrip(tp) + + +def test_func_type(): + type_params = tvm.runtime.convert([]) + type_constraints = tvm.runtime.convert([]) # TODO: fill me in + arg_types = tvm.runtime.convert([]) + ret_type = tvm.ir.TensorType((1, 2, 3), 'float32') + tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints) + assert tf.type_params == type_params + assert tf.type_constraints == type_constraints + assert tf.arg_types == arg_types + assert tf.ret_type == ret_type + assert tf.span == None + # TODO make sure we can set span + str(tf) + check_json_roundtrip(tf) + + +def test_tuple_type(): + tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type) + tf = tvm.ir.FuncType([], None, [], []) + tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32') + fields = tvm.runtime.convert([tp, tf, tt]) + + tup_ty = tvm.ir.TupleType(fields) + assert tup_ty.fields == fields + str(tup_ty) + check_json_roundtrip(tup_ty) + +def test_type_relation(): + tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type) + tf = tvm.ir.FuncType([], None, [], []) + tt = tvm.ir.TensorType( + tvm.runtime.convert([1, 2, 3]), 'float32') + args = tvm.runtime.convert([tp, tf, tt]) + + num_inputs = 2 + func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast") + attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4)) + + tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs) + assert tr.args == args + assert tr.num_inputs == num_inputs + str(tr) + check_json_roundtrip(tr) + +if __name__ == "__main__": + test_tensor_type_bad_constructor() + test_tensor_type() + test_type_param() + test_func_type() + test_tuple_type() + test_type_relation() -- 2.7.4