[REFACTOR][TYPE] Finish move all types to IR. (#4746)
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 20 Jan 2020 22:01:31 +0000 (14:01 -0800)
committerGitHub <noreply@github.com>
Mon, 20 Jan 2020 22:01:31 +0000 (14:01 -0800)
* [REFACTOR][TYPE] Finish move all types to IR.

- Move definition of Ref and TensorType to ir
- Move type_functor.h to public header.
- Rename RefType -> RelayRefType for clarity.

* Add atol

56 files changed:
include/tvm/ir/tensor_type.h [new file with mode: 0644]
include/tvm/ir/type.h
include/tvm/ir/type_functor.h [moved from src/relay/ir/type_functor.h with 77% similarity]
include/tvm/relay/type.h
include/tvm/runtime/object.h
src/ir/tensor_type.cc [moved from src/relay/ir/type.cc with 50% similarity]
src/ir/type.cc
src/ir/type_functor.cc [moved from src/relay/ir/type_functor.cc with 94% similarity]
src/relay/backend/compile_engine.cc
src/relay/backend/interpreter.cc
src/relay/ir/alpha_equal.cc
src/relay/ir/expr.cc
src/relay/ir/expr_functor.cc
src/relay/ir/hash.cc
src/relay/ir/pretty_printer.cc
src/relay/op/algorithm/argsort.cc
src/relay/op/algorithm/topk.cc
src/relay/op/image/resize.cc
src/relay/op/memory/memory.cc
src/relay/op/nn/bitserial.cc
src/relay/op/nn/convolution.cc
src/relay/op/nn/convolution.h
src/relay/op/nn/nn.cc
src/relay/op/nn/nn.h
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/nn/sparse.cc
src/relay/op/nn/upsampling.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
src/relay/op/tensor/transform.h
src/relay/op/tensor/unary.cc
src/relay/op/type_relations.cc
src/relay/op/vision/multibox_op.cc
src/relay/op/vision/nms.cc
src/relay/op/vision/rcnn_op.cc
src/relay/op/vision/yolo.cc
src/relay/pass/de_duplicate.cc
src/relay/pass/eta_expand.cc
src/relay/pass/gradient.cc
src/relay/pass/kind_check.cc
src/relay/pass/partial_eval.cc
src/relay/pass/quantize/quantize.cc
src/relay/pass/to_cps.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
src/relay/pass/util.cc
src/relay/qnn/op/dequantize.cc
src/relay/qnn/op/quantize.cc
src/relay/qnn/op/requantize.cc
src/relay/qnn/util.h
tests/cpp/relay_build_module_test.cc
tests/cpp/relay_pass_type_infer_test.cc
tests/cpp/relay_transform_sequential.cc
tests/cpp/utvm_runtime_standalone_test.cc
tests/python/frontend/mxnet/test_forward.py

diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h
new file mode 100644 (file)
index 0000000..70a2df1
--- /dev/null
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/tensor_type.h
+ * \brief Polymorphic tensor types.
+ */
+#ifndef TVM_IR_TENSOR_TYPE_H_
+#define TVM_IR_TENSOR_TYPE_H_
+
+#include <tvm/ir/type.h>
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+/*!
+ * \brief Base of all Tensor types
+ *  This container can hold TensorType or GenericTensorType.
+ * \sa BaseTensorType, TensorTypeNode
+ */
+class BaseTensorTypeNode : public TypeNode {
+ public:
+  static constexpr const char* _type_key = "relay.BaseTensorType";
+  TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to BaseTensorTypeNode.
+ * \sa BaseTensorTypeNode.
+ */
+class BaseTensorType : public Type {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief This is the most commonly used type in relay.
+ *  TensorType have a fixed dimension, data type.
+ *
+ *  The elements of shape can be either IntImm(constant integer),
+ *  or any symbolic integer expression.
+ *  The symbolic integer allows generic shape inference in certain cases.
+ * \sa TensorType
+ */
+class TensorTypeNode : public BaseTensorTypeNode {
+ public:
+  /*!
+   * \brief The shape of the tensor,
+   *  represented by PrimExpr(tvm::Expr).
+   */
+  Array<PrimExpr> shape;
+  /*! \brief The content data type */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("span", &span);
+  }
+
+  /*! \brief Return product of elements in the shape.
+   *  \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
+   */
+  TVM_DLL PrimExpr Size() const;
+
+  static constexpr const char* _type_key = "relay.TensorType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorTypeNode.
+ * \sa TensorTypeNode.
+ */
+class TensorType : public Type {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param shape The shape of the tensor.
+   * \param dtype The runtime dtype of the tensor's elements.
+   */
+  TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);
+
+  /*!
+   * \brief Construct an scalar containing elements of dtype.
+   * \param dtype The runtime dtype of the tensor's elements.
+   * \return THe constructed type.
+   */
+  TVM_DLL static TensorType Scalar(DataType dtype);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
+};
+
+// The following fields contains advanced typing
+// Only keep the class name and reserved for future usage.
+class GenericTensorType;
+// stores a DataType.
+class GenericDataType;
+// stores a DataType.
+class GenericShape;
+
+}  // namespace tvm
+#endif  // TVM_IR_TENSOR_TYPE_H_
index e143588..56f2389 100644 (file)
@@ -352,5 +352,75 @@ class FuncType : public Type {
   TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
 };
 
+/*!
+ * \brief Intermediate values that is used to indicate incomplete type
+ *         during type inference.
+ *
+ * If we view the type relations as "computational graph of types",
+ * then IncompleteType represents intermediate values of the graph,
+ * TypeVar represents the input to the graph.
+ *
+ * \sa IncompleteType
+ */
+class IncompleteTypeNode : public TypeNode {
+ public:
+  /*! \brief kind of the type. */
+  TypeKind kind;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("kind", &kind);
+    v->Visit("span", &span);
+  }
+
+  static constexpr const char* _type_key = "relay.IncompleteType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to IncompleteTypeNode.
+ * \sa IncompleteTypeNode
+ */
+class IncompleteType : public Type {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param kind kind of the type.
+   */
+  TVM_DLL explicit IncompleteType(TypeKind kind);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
+};
+
+
+/*!
+ * \brief Reference Type High-level Relay IR.
+ *
+ * \sa RelayRefType.
+ */
+class RelayRefTypeNode : public TypeNode {
+ public:
+  /*! \brief The type of value in the Reference. */
+  Type value;
+
+  RelayRefTypeNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("value", &value);
+    v->Visit("span", &span);
+  }
+
+  static constexpr const char* _type_key = "relay.RefType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to RelayRefTypeNode.
+ * \sa RelayRefTypeNode.
+ */
+class RelayRefType : public Type {
+ public:
+  TVM_DLL explicit RelayRefType(Type value);
+  TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
+};
 }  // namespace tvm
 #endif  // TVM_IR_TYPE_H_
similarity index 77%
rename from src/relay/ir/type_functor.h
rename to include/tvm/ir/type_functor.h
index 09049cf..476538c 100644 (file)
  */
 
 /*!
- * \file type_functor.h
+ * \file tvm/ir/type_functor.h
  * \brief A way to defined arbitrary function signature with dispatch on types.
  */
-#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
-#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#ifndef TVM_IR_TYPE_FUNCTOR_H_
+#define TVM_IR_TYPE_FUNCTOR_H_
 
 #include <tvm/node/functor.h>
 #include <tvm/relay/expr.h>
 #include <utility>
 
 namespace tvm {
-namespace relay {
 
 template <typename FType>
 class TypeFunctor;
 
 // functions to be overriden.
-#define TYPE_FUNCTOR_DEFAULT \
+#define TYPE_FUNCTOR_DEFAULT                                            \
   { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
 
 
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                 \
+#define TVM_TYPE_FUNCTOR_DISPATCH(OP)                                   \
   vtable.template set_dispatch<OP>(                                     \
       [](const ObjectRef& n, TSelf* self, Args... args) {               \
         return self->VisitType_(static_cast<const OP*>(n.get()),        \
@@ -89,10 +88,11 @@ class TypeFunctor<R(const Type& n, Args...)> {
   virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
-  virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitTypeDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     throw;  // unreachable, written to stop compiler warning
@@ -103,25 +103,29 @@ class TypeFunctor<R(const Type& n, Args...)> {
   static FType InitVTable() {
     FType vtable;
     // Set dispatch
-    RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
-    RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
     return vtable;
   }
 };
 
+#undef TVM_TYPE_FUNCTOR_DISPATCH
+
 /*!
  * \brief A type visitor that recursively visit types.
  */
-class TypeVisitor : public TypeFunctor<void(const Type& n)> {
+class TVM_DLL TypeVisitor :
+      public TypeFunctor<void(const Type& n)> {
  public:
   void VisitType_(const TypeVarNode* op) override;
   void VisitType_(const IncompleteTypeNode* op) override;
@@ -129,14 +133,18 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
   void VisitType_(const FuncTypeNode* op) override;
   void VisitType_(const TupleTypeNode* op) override;
   void VisitType_(const TypeRelationNode* op) override;
-  void VisitType_(const RefTypeNode* op) override;
+  void VisitType_(const RelayRefTypeNode* op) override;
   void VisitType_(const GlobalTypeVarNode* op) override;
   void VisitType_(const TypeCallNode* op) override;
   void VisitType_(const TypeDataNode* op) override;
+  void VisitType_(const PrimTypeNode* op) override;
 };
 
-// Mutator that transform a type to another one.
-class TypeMutator : public TypeFunctor<Type(const Type& n)> {
+/*!
+ * \brief TypeMutator that mutates expressions.
+ */
+class TVM_DLL TypeMutator :
+      public TypeFunctor<Type(const Type& n)> {
  public:
   Type VisitType(const Type& t) override;
   Type VisitType_(const TypeVarNode* op) override;
@@ -145,10 +153,11 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
   Type VisitType_(const FuncTypeNode* op) override;
   Type VisitType_(const TupleTypeNode* op) override;
   Type VisitType_(const TypeRelationNode* type_rel) override;
-  Type VisitType_(const RefTypeNode* op) override;
+  Type VisitType_(const RelayRefTypeNode* op) override;
   Type VisitType_(const GlobalTypeVarNode* op) override;
   Type VisitType_(const TypeCallNode* op) override;
   Type VisitType_(const TypeDataNode* op) override;
+  Type VisitType_(const PrimTypeNode* op) override;
 
  private:
   Array<Type> MutateArray(Array<Type> arr);
@@ -161,6 +170,5 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
  */
 Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
 
-}  // namespace relay
 }  // namespace tvm
-#endif  // TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#endif  // TVM_IR_TYPE_FUNCTOR_H_
index adf1380..e8f402a 100644 (file)
@@ -25,6 +25,7 @@
 #define TVM_RELAY_TYPE_H_
 
 #include <tvm/ir/type.h>
+#include <tvm/ir/tensor_type.h>
 #include <tvm/ir/type_relation.h>
 #include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
@@ -54,6 +55,12 @@ using TypeConstraint = tvm::TypeConstraint;
 using TypeConstraintNode = tvm::TypeConstraintNode;
 using FuncType = tvm::FuncType;
 using FuncTypeNode = tvm::FuncTypeNode;
+using IncompleteType = tvm::IncompleteType;
+using IncompleteTypeNode = tvm::IncompleteTypeNode;
+using RelayRefType = tvm::RelayRefType;
+using RelayRefTypeNode = tvm::RelayRefTypeNode;
+using TensorType = tvm::TensorType;
+using TensorTypeNode = tvm::TensorTypeNode;
 using TypeCall = tvm::TypeCall;
 using TypeCallNode = tvm::TypeCallNode;
 using TypeRelation = tvm::TypeRelation;
@@ -62,136 +69,6 @@ using TypeRelationFn = tvm::TypeRelationFn;
 using TypeReporter = tvm::TypeReporter;
 using TypeReporterNode = tvm::TypeReporterNode;
 
-/*!
- * \brief Base of all Tensor types
- *  This container can hold TensorType or GenericTensorType.
- */
-class BaseTensorTypeNode : public TypeNode {
- public:
-  static constexpr const char* _type_key = "relay.BaseTensorType";
-  TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
-};
-
-class BaseTensorType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
-};
-
-/*!
- * \brief This is the most commonly used type in relay.
- *  TensorType have a fixed dimension, data type.
- *
- *  The elements of shape can be either IntImm(constant integer),
- *  or any symbolic integer expression.
- *  The symbolic integer allows generic shape inference in certain cases.
- * \sa TensorTypeNode The container class of TensorType.
- */
-class TensorType;
-/*! \brief TensorType container node */
-class TensorTypeNode : public BaseTensorTypeNode {
- public:
-  /*!
-   * \brief The shape of the tensor,
-   *  represented by IndexExpr(tvm::Expr).
-   */
-  Array<IndexExpr> shape;
-  /*! \brief The content data type */
-  DataType dtype;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("shape", &shape);
-    v->Visit("dtype", &dtype);
-    v->Visit("span", &span);
-  }
-
-  /*! \brief Return product of elements in the shape.
-   *  \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
-   */
-  TVM_DLL IndexExpr Size() const;
-
-  TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
-
-  /*! \brief Construct an scalar containing elements of dtype.  */
-  TVM_DLL static TensorType Scalar(DataType dtype);
-
-  static constexpr const char* _type_key = "relay.TensorType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
-};
-
-class TensorType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
-};
-
-/*!
- * \brief IncompleteType.
- * This is intermediate values that is used during type inference.
- *
- * If we view the type relations as "computational graph of types",
- * then IncompleteType represents intermediate values of the graph,
- * TypeVar represents the input to the graph.
- */
-class IncompleteType;
-
-/*! \brief IncompleteType container node */
-class IncompleteTypeNode : public TypeNode {
- public:
-  Kind kind;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("kind", &kind);
-    v->Visit("span", &span);
-  }
-
-  TVM_DLL static IncompleteType make(Kind kind);
-
-  static constexpr const char* _type_key = "relay.IncompleteType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
-};
-
-class IncompleteType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
-};
-
-/*!
- * \brief The type of reference values.
- */
-class RefType;
-/*!
- * \brief Reference Type in relay.
- */
-class RefTypeNode : public TypeNode {
- public:
-  /*! \brief The type of value in the Reference. */
-  Type value;
-
-  RefTypeNode() {}
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("value", &value);
-    v->Visit("span", &span);
-  }
-
-  TVM_DLL static RefType make(Type value);
-
-  static constexpr const char* _type_key = "relay.RefType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode);
-};
-
-class RefType : public Type {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
-};
-
-// The following fields contains advanced typing
-// Only keep the class name and reserved for future usage.
-class GenericTensorType;
-// stores a DataType.
-class GenericDataType;
-// stores a DataType.
-class GenericShape;
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_TYPE_H_
index 1989afd..ba84b5f 100644 (file)
@@ -317,8 +317,8 @@ class Object {
  * \tparam ObjectType The object type
  * \return The corresponding RefType
  */
-template <typename RefType, typename ObjectType>
-inline RefType GetRef(const ObjectType* ptr);
+template <typename RelayRefType, typename ObjectType>
+inline RelayRefType GetRef(const ObjectType* ptr);
 
 /*!
  * \brief Downcast a base reference type to a more specific type.
@@ -484,8 +484,8 @@ class ObjectPtr {
   friend class TVMArgsSetter;
   friend class TVMRetValue;
   friend class TVMArgValue;
-  template <typename RefType, typename ObjType>
-  friend RefType GetRef(const ObjType* ptr);
+  template <typename RelayRefType, typename ObjType>
+  friend RelayRefType GetRef(const ObjType* ptr);
   template <typename BaseType, typename ObjType>
   friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
 };
@@ -848,11 +848,11 @@ inline const ObjectType* ObjectRef::as() const {
   }
 }
 
-template <typename RefType, typename ObjType>
-inline RefType GetRef(const ObjType* ptr) {
-  static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
+template <typename RelayRefType, typename ObjType>
+inline RelayRefType GetRef(const ObjType* ptr) {
+  static_assert(std::is_base_of<typename RelayRefType::ContainerType, ObjType>::value,
                 "Can only cast to the ref of same container type");
-  return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
+  return RelayRefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
 }
 
 template <typename BaseType, typename ObjType>
similarity index 50%
rename from src/relay/ir/type.cc
rename to src/ir/tensor_type.cc
index f1e59a4..0a9ed4e 100644 (file)
  */
 
 /*!
- * \file src/tvm/ir/type.cc
+ * \file src/tvm/ir/tensor_type.cc
  * \brief The type system AST nodes of Relay.
  */
-#include <tvm/relay/type.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/ir/tensor_type.h>
 #include <tvm/tir/op.h>
 
 namespace tvm {
-namespace relay {
 
 using tvm::NodePrinter;
 using namespace tvm::runtime;
 
-TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
+TensorType::TensorType(Array<PrimExpr> shape, DataType dtype) {
   ObjectPtr<TensorTypeNode> n = make_object<TensorTypeNode>();
   n->shape = std::move(shape);
   n->dtype = std::move(dtype);
-  return TensorType(n);
+  data_ = std::move(n);
 }
 
-TensorType TensorTypeNode::Scalar(DataType dtype) {
-  return TensorTypeNode::make({}, dtype);
+TensorType TensorType::Scalar(DataType dtype) {
+  return TensorType({}, dtype);
 }
 
-IndexExpr TensorTypeNode::Size() const {
+PrimExpr TensorTypeNode::Size() const {
   if (shape.size() == 0) {
     return tir::make_const(DataType::Int(64), 1);
   }
 
-  IndexExpr size = shape[0];
+  PrimExpr size = shape[0];
   for (size_t i = 1; i < shape.size(); ++i) {
     size *= shape[i];
   }
@@ -56,7 +56,9 @@ IndexExpr TensorTypeNode::Size() const {
 TVM_REGISTER_NODE_TYPE(TensorTypeNode);
 
 TVM_REGISTER_GLOBAL("relay._make.TensorType")
-.set_body_typed(TensorTypeNode::make);
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
+  return TensorType(shape, dtype);
+});
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 .set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
@@ -64,45 +66,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
 });
 
-IncompleteType IncompleteTypeNode::make(Kind kind) {
-  auto n = make_object<IncompleteTypeNode>();
-  n->kind = std::move(kind);
-  return IncompleteType(n);
-}
-
-TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
-
-TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
-.set_body_typed([](int kind) {
-    return IncompleteTypeNode::make(static_cast<Kind>(kind));
-  });
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
-    auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
-    p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
-  });
-
-
-RefType RefTypeNode::make(Type value) {
-  ObjectPtr<RefTypeNode> n = make_object<RefTypeNode>();
-  n->value = std::move(value);
-  return RefType(n);
-}
-
-TVM_REGISTER_GLOBAL("relay._make.RefType")
-.set_body_typed(RefTypeNode::make);
-
-TVM_REGISTER_NODE_TYPE(RefTypeNode);
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<RefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
-  auto* node = static_cast<const RefTypeNode*>(ref.get());
-  p->stream << "RefTypeNode(" << node->value << ")";
-});
-
-TVM_REGISTER_GLOBAL("relay._make.Any")
-.set_body_typed([]() { return Any::make(); });
-
-}  // namespace relay
 }  // namespace tvm
index 9e250db..233274a 100644 (file)
@@ -118,6 +118,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
             << node->type_constraints << ")";
 });
 
+
 TupleType::TupleType(Array<Type> fields) {
   ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
   n->fields = std::move(fields);
@@ -141,4 +142,44 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   p->stream << "TupleTypeNode(" << node->fields << ")";
 });
 
+
+IncompleteType::IncompleteType(TypeKind kind) {
+  auto n = make_object<IncompleteTypeNode>();
+  n->kind = std::move(kind);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
+
+TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
+.set_body_typed([](int kind) {
+    return IncompleteType(static_cast<TypeKind>(kind));
+  });
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
+    auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
+    p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
+  });
+
+
+RelayRefType::RelayRefType(Type value) {
+  ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>();
+  n->value = std::move(value);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relay._make.RefType")
+.set_body_typed([](Type value) {
+  return RelayRefType(value);
+});
+
+TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
+  auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
+  p->stream << "RelayRefTypeNode(" << node->value << ")";
+});
+
 }  // namespace tvm
similarity index 94%
rename from src/relay/ir/type_functor.cc
rename to src/ir/type_functor.cc
index 0180a0c..cbd3538 100644 (file)
  * \file type_functor.cc
  * \brief Implementations of type functors.
  */
+#include <tvm/ir/type_functor.h>
 #include <utility>
-#include "type_functor.h"
 
 namespace tvm {
-namespace relay {
 
 void TypeVisitor::VisitType_(const TypeVarNode* op) {
 }
@@ -57,7 +56,7 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) {
   }
 }
 
-void TypeVisitor::VisitType_(const RefTypeNode* op) {
+void TypeVisitor::VisitType_(const RelayRefTypeNode* op) {
   this->VisitType(op->value);
 }
 
@@ -91,6 +90,9 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
   }
 }
 
+void TypeVisitor::VisitType_(const PrimTypeNode* op) {
+}
+
 Type TypeMutator::VisitType(const Type& t) {
   return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
 }
@@ -169,8 +171,8 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
   }
 }
 
-Type TypeMutator::VisitType_(const RefTypeNode* op) {
-  return RefTypeNode::make(this->VisitType(op->value));
+Type TypeMutator::VisitType_(const RelayRefTypeNode* op) {
+  return RelayRefType(this->VisitType(op->value));
 }
 
 Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
@@ -203,6 +205,10 @@ Type TypeMutator::VisitType_(const TypeDataNode* op) {
   return GetRef<Type>(op);
 }
 
+Type TypeMutator::VisitType_(const PrimTypeNode* op) {
+  return GetRef<Type>(op);
+}
+
 // Implements bind.
 class TypeBinder : public TypeMutator {
  public:
@@ -227,5 +233,4 @@ Type Bind(const Type& type, const tvm::Map<TypeVar, Type>& args_map) {
   return TypeBinder(args_map).VisitType(type);
 }
 
-}  // namespace relay
 }  // namespace tvm
index 86d16a8..8ba6eb4 100644 (file)
@@ -21,8 +21,7 @@
  * \file relay/backend/compile_engine.cc
  * \brief Internal compialtion engine.
  */
-#include "compile_engine.h"
-
+#include <tvm/ir/type_functor.h>
 #include <tvm/top/schedule.h>
 #include <tvm/top/operation.h>
 #include <tvm/top/schedule_pass.h>
@@ -42,7 +41,8 @@
 #include <functional>
 #include <vector>
 #include <unordered_map>
-#include "../ir/type_functor.h"
+
+#include "compile_engine.h"
 
 namespace tvm {
 namespace relay {
@@ -239,12 +239,12 @@ class ScheduleGetter :
     // TODO(@icemelon): Support recursive tuple
     Type call_node_type = call_node->checked_type();
     if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
-      call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
+      call_node_type = TensorType(GetShape(tt->shape), tt->dtype);
     } else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
       std::vector<Type> new_fields;
       for (auto field : tuple_t->fields) {
         if (const auto* tt = field.as<TensorTypeNode>()) {
-          new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
+          new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype));
         } else {
           new_fields.push_back(field);
         }
index 95d667f..224ff77 100644 (file)
@@ -529,7 +529,7 @@ class Interpreter :
         if (is_dyn) {
           auto sh = out_shapes[i];
           auto tt = Downcast<TensorType>(rtype->fields[i]);
-          fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
+          fields.push_back(fset_output(i, TensorType(sh, tt->dtype)));
         } else {
           fields.push_back(fset_output(i, rtype->fields[i]));
         }
@@ -542,7 +542,7 @@ class Interpreter :
         CHECK_EQ(out_shapes.size(), 1);
         auto sh = out_shapes[0];
         auto tt = Downcast<TensorType>(ret_type);
-        out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
+        out_tensor = fset_output(0, TensorType(sh, tt->dtype));
       } else {
         out_tensor = fset_output(0, ret_type);
       }
index b55a4af..2d07f61 100644 (file)
@@ -21,6 +21,7 @@
  * \file src/tvm/relay/ir/alpha_equal.cc
  * \brief Alpha equality check by deep comparing two nodes.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
@@ -28,7 +29,6 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/nn.h>
-#include "type_functor.h"
 #include "../../ir/attr_functor.h"
 namespace tvm {
 namespace relay {
@@ -277,8 +277,8 @@ class AlphaEqualHandler:
     }
   }
 
-  bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
-    if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
+  bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
+    if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
       return TypeEqual(lhs->value, rhs->value);
     }
     return false;
index 7e19d51..3d8cc3a 100644 (file)
@@ -59,7 +59,7 @@ TensorType ConstantNode::tensor_type() const {
         tvm::IntImm(DataType::Int(32), data->shape[i]));
   }
 
-  return TensorTypeNode::make(shape, dtype);
+  return TensorType(shape, dtype);
 }
 
 Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
@@ -129,12 +129,12 @@ FuncType FunctionNode::func_type_annotation() const {
   Array<Type> param_types;
   for (auto param : this->params) {
     Type param_type = (param->type_annotation.defined()) ? param->type_annotation
-      : IncompleteTypeNode::make(Kind::kType);
+      : IncompleteType(Kind::kType);
     param_types.push_back(param_type);
   }
 
   Type ret_type = (this->ret_type.defined()) ? this->ret_type
-    : IncompleteTypeNode::make(Kind::kType);
+    : IncompleteType(Kind::kType);
   return FuncType(param_types, ret_type, this->type_params, {});
 }
 
@@ -359,5 +359,8 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
     return FunctionSetAttr(func, name, ref);
 });
 
+TVM_REGISTER_GLOBAL("relay._make.Any")
+.set_body_typed([]() { return Any::make(); });
+
 }  // namespace relay
 }  // namespace tvm
index 0da763a..c525b9e 100644 (file)
  * ExprMutator uses memoization and self return in order to amortize
  * the cost of using functional updates.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
-#include "type_functor.h"
 
 namespace tvm {
 namespace relay {
index b1906d3..9977b5c 100644 (file)
  * \file src/tvm/relay/ir/hash.cc
  * \brief Hash functions for Relay types and expressions.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/ir/attrs.h>
-#include "type_functor.h"
 #include "../../ir/attr_functor.h"
 
 namespace tvm {
@@ -201,8 +201,8 @@ class RelayHashHandler:
     return hash;
   }
 
-  size_t VisitType_(const RefTypeNode* rtn) final {
-    size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
+  size_t VisitType_(const RelayRefTypeNode* rtn) final {
+    size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
     hash = Combine(hash, TypeHash(rtn->value));
     return hash;
   }
index ae2089d..c21f565 100644 (file)
  *    - Var
  *  - Otherwise, inline if the node is at the end of a scope and is used at most once.
  */
-
+#include <tvm/ir/type_functor.h>
 #include <tvm/node/serialization.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/ir/module.h>
 #include <tvm/relay/pattern_functor.h>
 #include "doc.h"
-#include "type_functor.h"
 #include "../pass/dependency_graph.h"
 #include "../../ir/attr_functor.h"
 
@@ -779,7 +778,7 @@ class PrettyPrinter :
     return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
   }
 
-  Doc VisitType_(const RefTypeNode* node) final {
+  Doc VisitType_(const RelayRefTypeNode* node) final {
     Doc doc;
     return doc << "ref(" << Print(node->value) << ")";
   }
index 0d68b44..13d89a7 100644 (file)
@@ -43,7 +43,7 @@ bool ArgsortRel(const Array<Type>& types,
         << types[0];
     return false;
   }
-  reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+  reporter->Assign(types[1], TensorType(data->shape, param->dtype));
   return true;
 }
 
index 888d431..0ff30bb 100644 (file)
@@ -52,8 +52,8 @@ bool TopKRel(const Array<Type>& types,
       out_shape.push_back(param->k);
     }
   }
-  auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
-  auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
+  auto values_ty = TensorType(out_shape, data->dtype);
+  auto indices_ty = TensorType(out_shape, param->dtype);
   if (param->ret_type == "both") {
     reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
   } else if (param->ret_type == "values") {
index e796a04..4349e09 100644 (file)
@@ -60,7 +60,7 @@ bool ResizeRel(const Array<Type>& types,
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+                   TensorType(layout_converter.BackwardShape(oshape),
                                         out_dtype));
   return true;
 }
@@ -143,7 +143,7 @@ bool CropAndResizeRel(const Array<Type>& types,
   auto bshape = layout_converter.BackwardShape(oshape);
   // assign output type
   reporter->Assign(types[3],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+                   TensorType(layout_converter.BackwardShape(oshape),
                                         out_dtype));
   return true;
 }
index 6c4b3ea..aa0ba2d 100644 (file)
@@ -154,11 +154,11 @@ bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     for (auto i = 0u; i < dims; i++) {
       out_shape.push_back(tvm::Integer(sh[i]));
     }
-    alloc_type = TensorTypeNode::make(out_shape, alloc_attrs->dtype);
+    alloc_type = TensorType(out_shape, alloc_attrs->dtype);
   } else {
     CHECK(alloc_attrs->assert_shape.defined())
         << "the assert_shape must be set when const_shape is not";
-    alloc_type = TensorTypeNode::make(alloc_attrs->assert_shape, alloc_attrs->dtype);
+    alloc_type = TensorType(alloc_attrs->assert_shape, alloc_attrs->dtype);
     return true;
   }
 
@@ -309,13 +309,13 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       shape_func_ins.push_back(in_type);
     } else {
       auto shape = RankShape(in_type->shape);
-      shape_func_ins.push_back(TensorTypeNode::make(shape, DataType::Int(64)));
+      shape_func_ins.push_back(TensorType(shape, DataType::Int(64)));
     }
   }
 
   for (auto out_type : out_types) {
     auto rank_shape = RankShape(out_type->shape);
-    shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64)));
+    shape_func_outs.push_back(TensorType(rank_shape, DataType::Int(64)));
   }
 
   auto input_type = TupleType(shape_func_ins);
index eccffc8..c9e05e1 100644 (file)
@@ -81,7 +81,7 @@ bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     out_shape.push_back(bits);
   }
 
-  reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type));
+  reporter->Assign(types[1], TensorType(out_shape, pack_type));
   return true;
 }
 
@@ -144,7 +144,7 @@ bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   DataType out_dtype = param->out_dtype;
   oshape = trans_in_layout.BackwardShape(oshape);
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -220,7 +220,7 @@ bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   }
 
   // Assign output type.
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
index 82f4ba5..6977ac9 100644 (file)
@@ -271,7 +271,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
     channels = param->channels;
 
     // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+    reporter->Assign(types[1], TensorType(wshape, data->dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -310,7 +310,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
     out_dtype = data->dtype;
   }
   oshape = trans_out_layout.BackwardShape(oshape);
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -434,7 +434,7 @@ bool Conv1DTransposeRel(const Array<Type>& types,
     channels = param->channels;
 
     // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+    reporter->Assign(types[1], TensorType(wshape, data->dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -469,7 +469,7 @@ bool Conv1DTransposeRel(const Array<Type>& types,
     out_dtype = data->dtype;
   }
   oshape = trans_out_layout.BackwardShape(oshape);
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -616,7 +616,7 @@ bool Conv2DWinogradRel(const Array<Type>& types,
   }
   oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -702,7 +702,7 @@ bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
       data->shape[1],
   };
 
-  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
                                                   data->dtype));
   return true;
 }
@@ -817,7 +817,7 @@ bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
-  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape), out_dtype));
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
   return true;
 }
 
@@ -1025,7 +1025,7 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
     dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
     // assign result to reporter
-    reporter->Assign(types[2], TensorTypeNode::make(wshape, data->dtype));
+    reporter->Assign(types[2], TensorType(wshape, data->dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -1066,12 +1066,12 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
   // infer offset shape
   Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
           oshape[2], oshape[3]});
-  reporter->Assign(types[1], TensorTypeNode::make(offset_shape, data->dtype));
+  reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
 
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
 }
 
index f858efc..4061909 100644 (file)
@@ -81,7 +81,7 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       weight_dtype = weight->dtype;
     }
     // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -117,7 +117,7 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -179,7 +179,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       weight_dtype = weight->dtype;
     }
     // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -226,7 +226,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
@@ -290,7 +290,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     }
 
     // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
@@ -346,7 +346,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }
   oshape = trans_out_layout.BackwardShape(oshape);
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
index 2ff439a..1f6ad8f 100644 (file)
@@ -61,7 +61,7 @@ bool BiasAddRel(const Array<Type>& types,
       << "axis " << param->axis << " is out of range";
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(
+  reporter->Assign(types[1], TensorType(
       {data->shape[axis]}, data->dtype));
   reporter->Assign(types[2], types[0]);
   return true;
@@ -138,7 +138,7 @@ bool FIFOBufferRel(const Array<Type>& types,
 
   Array<tvm::PrimExpr> oshape = buffer->shape;
 
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
+  reporter->Assign(types[2], TensorType(oshape, buffer->dtype));
   return true;
 }
 
@@ -260,10 +260,10 @@ bool PReluRel(const Array<Type>& types,
 
   // assign alpha type
   Array<IndexExpr> alpha_shape({data->shape[param->axis]});
-  reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));
+  reporter->Assign(types[1], TensorType(alpha_shape, data->dtype));
 
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
+  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
   return true;
 }
 
@@ -419,7 +419,7 @@ bool BatchFlattenRel(const Array<Type>& types,
   std::vector<IndexExpr> oshape({data->shape[0], target_dim});
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -585,7 +585,7 @@ bool DropoutRel(const Array<Type>& types,
 
   // dropout returns the original tensor with dropout applied
   // and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
-  auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
+  auto ret_type = TensorType(data->shape, data->dtype);
   reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
   return true;
 }
@@ -661,17 +661,17 @@ bool BatchNormRel(const Array<Type>& types,
   auto axis_size = data->shape[axis];
 
   // if we are using beta and gamma, they need to be of shape (dim,)
-  reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype));
-  reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype));
-  reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype));
-  reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype));
+  reporter->Assign(types[1], TensorType({axis_size}, data->dtype));
+  reporter->Assign(types[2], TensorType({axis_size}, data->dtype));
+  reporter->Assign(types[3], TensorType({axis_size}, data->dtype));
+  reporter->Assign(types[4], TensorType({axis_size}, data->dtype));
 
   // output is a tuple of the normed data (same shape as input), new running mean,
   // and new running average (the latter two are both vectors of length dim)
   std::vector<Type> fields;
-  auto vec_ty = TensorTypeNode::make(Array<IndexExpr>({data->shape[axis]}),
+  auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}),
                                      data->dtype);
-  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
+  fields.push_back(TensorType(data->shape, data->dtype));
   fields.push_back(vec_ty);
   fields.push_back(vec_ty);
   reporter->Assign(types[5], TupleType(Array<Type>(fields)));
@@ -754,9 +754,9 @@ bool InstanceNormRel(const Array<Type>& types,
   const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
   int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
   CHECK(axis >= 0 && axis < (int)data->shape.size());
-  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
-  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
-  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
 
   return true;
 }
@@ -824,9 +824,9 @@ bool LayerNormRel(const Array<Type>& types,
   const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
   int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
   CHECK(axis >= 0 && axis < (int)data->shape.size());
-  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
-  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
-  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
 
   return true;
 }
@@ -881,7 +881,7 @@ bool BatchMatmulRel(const Array<Type>& types,
   oshape.Set(2, y->shape[1]);
 
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
+  reporter->Assign(types[2], TensorType(oshape, x->dtype));
   return true;
 }
 
@@ -940,7 +940,7 @@ bool CrossEntropyRel(const Array<Type>& types,
     << "x shape = " << x->shape << ", "
     << "y shape = " << y->shape;
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
+  reporter->Assign(types[2], TensorType({}, x->dtype));
   return true;
 }
 
@@ -1016,7 +1016,7 @@ bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
 
   // Assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));
+                   TensorType(layout_converter.BackwardShape(oshape), data->dtype));
 
   return true;
 }
@@ -1074,7 +1074,7 @@ bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attr
 
   // Assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));
+                   TensorType(layout_converter.BackwardShape(oshape), data->dtype));
 
   return true;
 }
index 7389909..dc876e8 100644 (file)
@@ -52,7 +52,7 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     // data dtype as the weight dtype. However if weight dtype is explicitly
     // present we will use that.
     auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
     oshape.Set((oshape.size() - 1), param->units);
   } else {
     if (weight == nullptr) return false;
@@ -70,7 +70,7 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     out_dtype = data->dtype;
   }
   // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
   return true;
 }
 
index e33f751..1656158 100644 (file)
@@ -155,7 +155,7 @@ bool PadRel(const Array<Type>& types,
     }
   }
 
-  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
                                                   data->dtype));
   return true;
 }
@@ -260,7 +260,7 @@ bool MirrorPadRel(const Array<Type>& types,
     oshape.push_back(data->shape[i] + padding);
   }
 
-  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+  reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
                                                   data->dtype));
   return true;
 }
index 0c74b27..7b6deff 100644 (file)
@@ -161,7 +161,7 @@ bool Pool2DRel(const Array<Type>& types,
   }
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -327,7 +327,7 @@ bool GlobalPool2DRel(const Array<Type>& types,
   oshape.Set(widx, 1);
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -462,7 +462,7 @@ bool AdaptivePool2DRel(const Array<Type>& types,
   oshape.Set(widx, output_width);
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -792,7 +792,7 @@ bool Pool1DRel(const Array<Type>& types,
   }
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -987,7 +987,7 @@ bool Pool3DRel(const Array<Type>& types,
   }
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
index caad01b..c01b760 100644 (file)
@@ -47,7 +47,7 @@ bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   if (weight_data->shape.size() == 1) {
     // CSR case.
     Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
-    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[4], TensorType(oshape, data->dtype));
     return true;
   }
 
@@ -56,7 +56,7 @@ bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     Array<IndexExpr> oshape({
         data->shape[0],
           (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
-    reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[4], TensorType(oshape, data->dtype));
     return true;
   }
   LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
@@ -105,9 +105,9 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
   const auto* sparse_indptr = types[2].as<TensorTypeNode>();
 
   std::vector<Type> output_types;
-  output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
-  output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
-  output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
+  output_types.push_back(TensorType(sparse_data->shape, sparse_data->dtype));
+  output_types.push_back(TensorType(sparse_indices->shape, sparse_indices->dtype));
+  output_types.push_back(TensorType(sparse_indptr->shape, sparse_indptr->dtype));
 
   reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
   return true;
index e78f7fd..477cec7 100644 (file)
@@ -87,7 +87,7 @@ bool UpSamplingRel(const Array<Type>& types,
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+                   TensorType(layout_converter.BackwardShape(oshape),
                                         data->dtype));
   return true;
 }
@@ -167,7 +167,7 @@ bool UpSampling3DRel(const Array<Type>& types,
 
   // assign output type
   reporter->Assign(types[1],
-                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+                   TensorType(layout_converter.BackwardShape(oshape),
                                         data->dtype));
   return true;
 }
index 5156330..880a337 100644 (file)
@@ -272,7 +272,7 @@ bool ArgReduceRel(const Array<Type>& types,
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Int(32)));
+  reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
   return true;
 }
 
@@ -297,7 +297,7 @@ bool ReduceRel(const Array<Type>& types,
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -594,7 +594,7 @@ bool VarianceRel(const Array<Type>& types,
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
index 538c92e..b958755 100644 (file)
@@ -60,7 +60,7 @@ bool CastRel(const Array<Type>& types,
     return false;
   }
   const auto* param = attrs.as<CastAttrs>();
-  reporter->Assign(types[1], TensorTypeNode::make(
+  reporter->Assign(types[1], TensorType(
       data->shape, param->dtype));
   return true;
 }
@@ -120,7 +120,7 @@ bool CastLikeRel(const Array<Type>& types,
         << types[1];
     return false;
   }
-  reporter->Assign(types[2], TensorTypeNode::make(data->shape, dtype_like->dtype));
+  reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype));
   return true;
 }
 
@@ -226,7 +226,7 @@ bool ExpandDimsRel(const Array<Type>& types,
   for (int i = pivot; i < ndim; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -408,7 +408,7 @@ bool StackRel(const Array<Type>& types,
   for (int i = axis; i < ndim; ++i) {
     oshape.emplace_back(first->shape[i]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype));
+  reporter->Assign(types[1], TensorType(oshape, dtype));
   return true;
 }
 
@@ -500,7 +500,7 @@ bool TransposeRel(const Array<Type>& types,
   for (int axis : int_axes) {
     oshape.push_back(data->shape[axis]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -679,10 +679,10 @@ bool ReshapeRel(const Array<Type>& types,
   }
 
   if (param->reverse) {
-    reporter->Assign(types[1], TensorTypeNode::make(
+    reporter->Assign(types[1], TensorType(
         Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
   } else {
-    reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[1], TensorType(oshape, data->dtype));
   }
   return true;
 }
@@ -809,7 +809,7 @@ bool ReshapeLikeRel(const Array<Type>& types,
     CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
       << "Reshape inputs size should be compatible.";
   }
-  reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
+  reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
   return true;
 }
 
@@ -853,7 +853,7 @@ bool ArgWhereRel(const Array<Type>& types,
   std::vector<IndexExpr> result_shape;
   result_shape.push_back(Any::make());
   result_shape.push_back(IntImm(DataType::Int(32), input_rank));
-  reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
+  reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32)));
   return true;
 }
 
@@ -894,7 +894,7 @@ bool TakeRel(const Array<Type>& types,
 
   if (!param->axis.defined()) {
     std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
-    reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+    reporter->Assign(types[2], TensorType(oshape, data->dtype));
     return true;
   }
 
@@ -918,7 +918,7 @@ bool TakeRel(const Array<Type>& types,
     oshape.emplace_back(data->shape[i]);
   }
 
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -1005,7 +1005,7 @@ bool FullRel(const Array<Type>& types,
     << "Fill value should be a scalar but has dimension "
     << fill_value->shape.size() << ".";
 
-  reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype));
+  reporter->Assign(types[1], TensorType(param->shape, out_dtype));
   return true;
 }
 
@@ -1049,7 +1049,7 @@ bool InitOpRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 1);
   const InitOpAttrs* param = attrs.as<InitOpAttrs>();
 
-  reporter->Assign(types[0], TensorTypeNode::make(param->shape, param->dtype));
+  reporter->Assign(types[0], TensorType(param->shape, param->dtype));
   return true;
 }
 
@@ -1113,7 +1113,7 @@ bool FullLikeRel(const Array<Type>& types,
     << "The fill value should be a scalar but here it has dimension "
     << fill_value->shape.size() << ".";
 
-  reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
+  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
   return true;
 }
 
@@ -1197,7 +1197,7 @@ bool ArangeRel(const Array<Type>& types,
 
   reporter->Assign(types[0], types[1]);
   reporter->Assign(types[1], types[2]);
-  reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype));
+  reporter->Assign(types[2], TensorType({}, attrs->dtype));
 
   if ((cstart = attrs->start.as<ConstantNode>()) &&
       (cstop = attrs->stop.as<ConstantNode>()) &&
@@ -1209,10 +1209,10 @@ bool ArangeRel(const Array<Type>& types,
     CHECK_GT(num_elem, 0)
         << "Invalid arange attributes (start, stop, step): " << attrs->start
         << ", " << attrs->stop << ", " << attrs->step;
-    reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype));
+    reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype));
     return true;
   } else {
-    reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype));
+    reporter->Assign(types[3], TensorType({Any::make()}, attrs->dtype));
     return true;
   }
 }
@@ -1320,7 +1320,7 @@ bool RepeatRel(const Array<Type>& types,
   for (int i = pivot + 1; i < ndim; ++i) {
     oshape.emplace_back(data->shape[i]);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -1431,7 +1431,7 @@ bool TileRel(const Array<Type>& types,
       oshape.emplace_back(data_shape[i] * reps_shape[i]);
     }
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -1560,7 +1560,7 @@ bool WhereRel(const Array<Type>& types,
         << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
     }
   }
-  reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype));
+  reporter->Assign(types[3], TensorType(x_shape, x->dtype));
   return true;
 }
 
@@ -1683,7 +1683,7 @@ bool SqueezeRel(const Array<Type>& types,
       }
     }
   }
-  reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
+  reporter->Assign(types[1], TensorType(result_shape, data->dtype));
   return true;
 }
 
@@ -1761,7 +1761,7 @@ bool BroadCastToRel(const Array<Type>& types,
   CHECK(ioattrs);
   auto intt = types[0].as<TensorTypeNode>();
   if (intt == nullptr) { return false; }
-  auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
+  auto type = TensorType(ioattrs->shape, intt->dtype);
   reporter->Assign(types[1], type);
   return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
 }
@@ -1942,7 +1942,7 @@ bool StridedSliceRel(const Array<Type>& types,
     }
     oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -2147,7 +2147,7 @@ bool SplitRel(const Array<Type>& types,
     for (int i = 0; i < sections->value; ++i) {
         std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
         oshape[axis] = indexdiv(oshape[axis], sections->value);
-        auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+        auto vec_type = TensorType(oshape, data->dtype);
         fields.push_back(vec_type);
     }
     reporter->Assign(types[1], TupleType(Array<Type>(fields)));
@@ -2161,14 +2161,14 @@ bool SplitRel(const Array<Type>& types,
       std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
       oshape[axis] = Downcast<IndexExpr>(indices[i]) - begin;
       begin = Downcast<IndexExpr>(indices[i]);
-      auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+      auto vec_type = TensorType(oshape, data->dtype);
       fields.push_back(vec_type);
     }
     CHECK(reporter->Assert(begin < data->shape[axis]))
         << "The sum of sections must match the input.shape[axis]";
     std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
     oshape[axis] = data->shape[axis] - begin;
-    auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+    auto vec_type = TensorType(oshape, data->dtype);
     fields.push_back(vec_type);
     reporter->Assign(types[1], TupleType(Array<Type>(fields)));
   }
@@ -2290,7 +2290,7 @@ bool SliceLikeRel(const Array<Type>& types,
     }
   }
 
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -2400,7 +2400,7 @@ bool LayoutTransformRel(const Array<Type>& types,
     << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
 
   const auto& out_shape = layout_converter.ForwardShape(data->shape);
-  reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
+  reporter->Assign(types[1], TensorType(out_shape, data->dtype));
   return true;
 }
 
@@ -2499,7 +2499,7 @@ bool GatherNDRel(const Array<Type>& types,
       oshape.push_back(indices->shape[i]);
   for (size_t i = mdim->value; i < ndim; ++i)
       oshape.push_back(data->shape[i]);
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -2552,7 +2552,7 @@ bool SequenceMaskRel(const Array<Type>& types,
   Array<IndexExpr> valid_length_shape;
   CHECK(param->axis == 0 || param->axis == 1);
   valid_length_shape.push_back(data->shape[1 - param->axis]);
-  reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype));
+  reporter->Assign(types[1], TensorType(valid_length_shape, valid_length->dtype));
   reporter->Assign(types[2], types[0]);
   return true;
 }
@@ -2666,7 +2666,7 @@ bool OneHotRel(const Array<Type>& types,
     }
   }
 
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
+  reporter->Assign(types[3], TensorType(oshape, param->dtype));
   return true;
 }
 
index a1cbf7a..b69f6e7 100644 (file)
@@ -119,7 +119,7 @@ bool ConcatenateRel(const Array<Type>& types,
     concat_dim = Any::make();
   }
 
-  auto rtype = TensorTypeNode::make(oshape, dtype);
+  auto rtype = TensorType(oshape, dtype);
   reporter->Assign(types[1], rtype);
   return true;
 }
index 331653b..98ff099 100644 (file)
@@ -286,7 +286,7 @@ bool ShapeOfRel(const Array<Type>& types,
   const auto* param = attrs.as<ShapeOfAttrs>();
   CHECK(param != nullptr);
   auto rank_shape = RankShape(tt->shape);
-  reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
+  reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
   return true;
 }
 
@@ -337,7 +337,7 @@ bool NdarraySizeRel(const Array<Type>& types,
   CHECK(tt != nullptr);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
-  reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
+  reporter->Assign(types[1], TensorType({1}, param->dtype));
   return true;
 }
 
index cd476fd..3c7d148 100644 (file)
@@ -96,7 +96,7 @@ Type ConcreteBroadcast(const TensorType& t1,
   for (; i <= max_ndim; ++i) {
     oshape.push_back(rshape[max_ndim - i]);
   }
-  return TensorTypeNode::make(Array<IndexExpr>(
+  return TensorType(Array<IndexExpr>(
       oshape.rbegin(), oshape.rend()), output_dtype);
 }
 
index b801186..eb5012f 100644 (file)
@@ -50,7 +50,7 @@ bool MultiboxPriorRel(const Array<Type>& types,
     {1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
 
   // assign output type
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -122,8 +122,8 @@ bool MultiBoxTransformLocRel(const Array<Type>& types,
   std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
   std::vector<IndexExpr> oshape1({cls_shape[0]});
   std::vector<Type> fields;
-  fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
-  fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32)));
+  fields.push_back(TensorType(oshape0, cls_prob->dtype));
+  fields.push_back(TensorType(oshape1, DataType::Int(32)));
 
   // assign output type
   reporter->Assign(types[3], TupleType(Array<Type>(fields)));
index 4524779..bec0c1d 100644 (file)
@@ -40,8 +40,8 @@ bool GetValidCountRel(const Array<Type>& types,
 
   std::vector<IndexExpr> oshape({data->shape[0]});
   std::vector<Type> fields;
-  fields.push_back(TensorTypeNode::make(oshape, DataType::Int(32)));
-  fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
+  fields.push_back(TensorType(oshape, DataType::Int(32)));
+  fields.push_back(TensorType(data->shape, data->dtype));
 
   // assign output type
   reporter->Assign(types[1], TupleType(Array<Type>(fields)));
@@ -95,9 +95,9 @@ bool NMSRel(const Array<Type>& types,
   // assign output type
   if (param->return_indices) {
     std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
-    reporter->Assign(types[2], TensorTypeNode::make(oshape, DataType::Int(32)));
+    reporter->Assign(types[2], TensorType(oshape, DataType::Int(32)));
   } else {
-    reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
+    reporter->Assign(types[2], TensorType(dshape, data->dtype));
   }
   return true;
 }
index 7b3533d..65efd04 100644 (file)
@@ -45,7 +45,7 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // assign output type
   std::vector<IndexExpr> oshape(
       {rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -96,7 +96,7 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   // assign output type
   std::vector<IndexExpr> oshape(
       {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -155,7 +155,7 @@ bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   std::vector<IndexExpr> oshape(
       {batch * proposal_attrs->rpn_post_nms_top_n, 5});
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
+  reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype));
   return true;
 }
 
index 9964a82..5a59a74 100644 (file)
@@ -56,7 +56,7 @@ bool YoloReorgRel(const Array<Type>& types,
   oshape[1] = oshape[1] * param->stride * param->stride;
   oshape[2] = indexdiv(oshape[2], param->stride);
   oshape[3] = indexdiv(oshape[3], param->stride);
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
index fc7f820..d816760 100644 (file)
  * \file de_duplicate.cc
  * \brief Use a fresh Id for every Var to make the result well-formed.
  */
-
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/pattern_functor.h>
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
index 8dece3f..b274460 100644 (file)
  * \brief Add an abstraction over constructors and/or global variables bound to a function.
  *
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr_functor.h>
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
index 20958ab..7d94d4e 100644 (file)
@@ -21,7 +21,7 @@
  * \file ad.cc
  * \brief API for Automatic Differentiation for the Relay IR.
  */
-
+#include <tvm/ir/type_functor.h>
 #include <tvm/tir/lowered_func.h>
 #include <tvm/top/operation.h>
 #include <tvm/relay/expr_functor.h>
@@ -30,7 +30,6 @@
 #include "pattern_util.h"
 #include "pass_util.h"
 #include "let_list.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -265,7 +264,7 @@ TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
 struct ReverseADType : TypeMutator {
   Type VisitType_(const TensorTypeNode* ttn) final {
     Type t = GetRef<Type>(ttn);
-    return TupleType({t, RefTypeNode::make(t)});
+    return TupleType({t, RelayRefType(t)});
   }
 };
 
index 55fd78a..d43059c 100644 (file)
@@ -31,9 +31,9 @@
  * We check this by ensuring the `dtype` field of a Tensor always
  * contains a data type such as `int`, `float`, `uint`.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/ir/error.h>
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -107,9 +107,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
     return Kind::kType;
   }
 
-  Kind VisitType_(const RefTypeNode* op) override {
+  Kind VisitType_(const RelayRefTypeNode* op) override {
     // ref types should only contain normal types
-    RefType rt = GetRef<RefType>(op);
+    RelayRefType rt = GetRef<RelayRefType>(op);
     CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
     return Kind::kType;
   }
index e9e37d2..37ce348 100644 (file)
  *
  * These assumptions do not affect the correctness of the algorithm, however.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/interpreter.h>
-#include "../ir/type_functor.h"
 #include "pass_util.h"
 #include "let_list.h"
 
@@ -863,7 +863,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
               subst.Set(func->type_params[i], type_args[i]);
             }
             for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
-              subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
+              subst.Set(func->type_params[i], IncompleteType(kType));
             }
             return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
           } else {
index 41a5a8e..2441f6e 100644 (file)
@@ -48,9 +48,9 @@ bool SimulatedQuantizeRel(const Array<Type>& types,
   CHECK(data != nullptr);
   CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
 
-  reporter->Assign(types[1], TensorTypeNode::make({}, DataType::Float(32)));    // dom_scale
-  reporter->Assign(types[2], TensorTypeNode::make({}, DataType::Float(32)));    // clip_min
-  reporter->Assign(types[3], TensorTypeNode::make({}, DataType::Float(32)));    // clip_max
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));    // dom_scale
+  reporter->Assign(types[2], TensorType({}, DataType::Float(32)));    // clip_min
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));    // clip_max
   reporter->Assign(types[4], types[0]);                               // output
   return true;
 }
index f88a7c9..293d696 100644 (file)
  * All cases in the transform must return via the mcont,
  * wheter directly invoking it, or indirectly by recursion.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
-#include "../ir/type_functor.h"
 #include "let_list.h"
 #include "pass_util.h"
 
index a513f3e..ed5f91a 100644 (file)
@@ -37,7 +37,7 @@
  * If we can not infer a type or there are conflicting typing
  * constraints we will trigger an error.
  */
-
+#include <tvm/ir/type_functor.h>
 #include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
@@ -45,7 +45,6 @@
 #include <tvm/relay/transform.h>
 #include "./pass_util.h"
 #include "type_solver.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -180,7 +179,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     if (op->type_annotation.defined()) {
       return op->type_annotation;
     } else {
-      return IncompleteTypeNode::make(Kind::kType);
+      return IncompleteType(Kind::kType);
     }
   }
 
@@ -215,7 +214,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
           EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
     }
     Type tuple_type = GetType(op->tuple);
-    Type rtype = IncompleteTypeNode::make(Kind::kType);
+    Type rtype = IncompleteType(Kind::kType);
     auto attrs = make_object<TupleGetItemAttrs>();
     attrs->index = op->index;
     solver_.AddConstraint(TypeRelation(
@@ -233,7 +232,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     // we can expect a certain number of arguments
     Array<Type> unknown_args;
     for (size_t i = 0; i < td->type_vars.size(); i++) {
-      unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+      unknown_args.push_back(IncompleteType(Kind::kType));
     }
     Type expected = TypeCall(con->constructor->belong_to, unknown_args);
     Type unified = Unify(t, expected, GetRef<ObjectRef>(con));
@@ -275,7 +274,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     // we can expect a certain number of arguments
     Array<Type> unknown_args;
     for (size_t i = 0; i < tup->patterns.size(); i++) {
-      unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+      unknown_args.push_back(IncompleteType(Kind::kType));
     }
     Type expected = TupleType(unknown_args);
     Type unified = Unify(t, expected, GetRef<ObjectRef>(tup));
@@ -302,7 +301,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     for (const auto& c : op->clauses) {
       VisitPattern(c->lhs, dtype);
     }
-    Type rtype = IncompleteTypeNode::make(Kind::kType);
+    Type rtype = IncompleteType(Kind::kType);
     for (const auto& c : op->clauses) {
       rtype = this->Unify(rtype,
                           GetType(c->rhs),
@@ -336,7 +335,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   Type VisitExpr_(const LetNode* let) final {
     // if the definition is a function literal, permit recursion
     bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
-    Type let_type = IncompleteTypeNode::make(Kind::kType);
+    Type let_type = IncompleteType(Kind::kType);
 
     if (is_functional_literal) {
       let_type = GetType(let->var);
@@ -362,7 +361,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     // that is a rank-0 boolean tensor.
     Type cond_type = this->GetType(ite->cond);
     this->Unify(cond_type,
-                TensorTypeNode::Scalar(tvm::DataType::Bool()),
+                TensorType::Scalar(tvm::DataType::Bool()),
                 ite->cond);
     Type checked_true = this->GetType(ite->true_branch);
     Type checked_false = this->GetType(ite->false_branch);
@@ -385,7 +384,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     for (size_t i = 0; i < op->type_params.size(); ++i) {
       if (!op->type_params[i].same_as(rel->args[i])) return Type();
     }
-    Type rtype = IncompleteTypeNode::make(Kind::kType);
+    Type rtype = IncompleteType(Kind::kType);
     arg_types.push_back(rtype);
     // we can do simple replacement here
     solver_.AddConstraint(TypeRelation(
@@ -404,7 +403,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     }
 
     for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
-      subst_map.Set(fn_ty->type_params[i], IncompleteTypeNode::make(Kind::kType));
+      subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
     }
 
     Type ret_type = fn_ty->ret_type;
@@ -415,7 +414,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     // This is a temporary work around to check recursive functions whose
     // return type is not yet known.
     if (!ret_type.defined()) {
-      ret_type = IncompleteTypeNode::make(Kind::kType);
+      ret_type = IncompleteType(Kind::kType);
     }
 
     Type inst_ty = FuncType(fn_ty->arg_types,
@@ -433,7 +432,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 
     Array<Type> type_args;
     for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
-      type_args.push_back(IncompleteTypeNode::make(Kind::kType));
+      type_args.push_back(IncompleteType(Kind::kType));
     }
     return InstantiateFuncType(fn_ty, type_args);
   }
@@ -466,7 +465,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     // incomplete type => it must be a function taking the arg types
     // with an unknown return type
     if (inc_ty_node != nullptr) {
-      Type ret_type = IncompleteTypeNode::make(Kind::kType);
+      Type ret_type = IncompleteType(Kind::kType);
       Type func_type = FuncType(arg_types, ret_type, {}, {});
       Type unified = this->Unify(ftype, func_type, GetRef<Call>(call));
       fn_ty_node = unified.as<FuncTypeNode>();
@@ -562,18 +561,18 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   }
 
   Type VisitExpr_(const RefCreateNode* op) final {
-    return RefTypeNode::make(GetType(op->value));
+    return RelayRefType(GetType(op->value));
   }
 
   Type VisitExpr_(const RefReadNode* op) final {
-    Type it = IncompleteTypeNode::make(Kind::kType);
-    this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
+    Type it = IncompleteType(Kind::kType);
+    this->Unify(GetType(op->ref), RelayRefType(it), GetRef<RefRead>(op));
     return it;
   }
 
   Type VisitExpr_(const RefWriteNode* op) final {
-    Type it = IncompleteTypeNode::make(Kind::kType);
-    this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
+    Type it = IncompleteType(Kind::kType);
+    this->Unify(GetType(op->ref), RelayRefType(it), GetRef<RefWrite>(op));
     this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
     return TupleType::Empty();
   }
index ec6d721..0ad43d0 100644 (file)
  * \file type_solver.cc
  * \brief Type solver implementations.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/tir/op.h>
 #include <string>
 #include <memory>
 #include <tuple>
 #include <utility>
 #include "type_solver.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
@@ -270,7 +270,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       return Type(nullptr);
     }
 
-    return TensorTypeNode::make(shape, tt1->dtype);
+    return TensorType(shape, tt1->dtype);
   }
 
   Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
@@ -312,7 +312,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     }
 
     for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
-      subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
+      subst_map.Set(op->type_params[i], IncompleteType(kType));
     }
 
     FuncType ft = FuncType(op->arg_types,
@@ -343,12 +343,12 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     return FuncType(arg_types, ret_type, ft2->type_params, type_constraints);
   }
 
-  Type VisitType_(const RefTypeNode* op, const Type& tn) final {
-    const auto* rtn = tn.as<RefTypeNode>();
+  Type VisitType_(const RelayRefTypeNode* op, const Type& tn) final {
+    const auto* rtn = tn.as<RelayRefTypeNode>();
     if (!rtn) {
       return Type(nullptr);
     }
-    return RefTypeNode::make(Unify(op->value, rtn->value));
+    return RelayRefType(Unify(op->value, rtn->value));
   }
 
   Type VisitType_(const TypeCallNode* op, const Type& tn) override {
@@ -690,7 +690,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
       } else if (name == "AddConstraint") {
         return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
             Expr e = VarNode::make("dummy_var",
-              IncompleteTypeNode::make(Kind::kType));
+              IncompleteType(Kind::kType));
             return solver->AddConstraint(c, e);
           });
       } else {
index e45b15a..b341966 100644 (file)
  *
  * \brief Utility functions for Relay.
  */
+#include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/pattern_functor.h>
 #include "pass_util.h"
-#include "../ir/type_functor.h"
 
 namespace tvm {
 namespace relay {
index e01a47d..ee67997 100644 (file)
@@ -52,7 +52,7 @@ bool DequantizeRel(const Array<Type>& types,
 
   const Array<tvm::PrimExpr> oshape = data->shape;
   // assign output type, output will always be float 32.
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
+  reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
   return true;
 }
 
index f53d2c5..e2472c6 100644 (file)
@@ -63,7 +63,7 @@ bool QuantizeRel(const Array<Type>& types,
         out_dtype == DataType::Int(32))
       << "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
   // assign output type
-  reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
 }
 
index 2686965..cf5b313 100644 (file)
@@ -197,7 +197,7 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         out_dtype == DataType::UInt(8) ||
         out_dtype == DataType::Int(32))
       << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
-  reporter->Assign(types[5], TensorTypeNode::make(oshape, out_dtype));
+  reporter->Assign(types[5], TensorType(oshape, out_dtype));
   return true;
 }
 
index 2d4bcb4..6362421 100644 (file)
@@ -176,7 +176,7 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons
   const auto tensor_dtype = tensor_type->dtype;
   CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
   if (tensor_type->shape.size() != 0) {
-    reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype));
+    reporter->Assign(expr_type, TensorType({shape}, tensor_type->dtype));
   }
 }
 
index 5ddb6d4..9d954ea 100644 (file)
@@ -36,7 +36,7 @@ TVM_REGISTER_GLOBAL("test.sch")
 
 TEST(Relay, BuildModule) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({2, 3}, DataType::Float(32));
+  auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
   auto a = relay::VarNode::make("a", tensor_type);
   auto b = relay::VarNode::make("b", tensor_type);
   auto add_op = relay::Op::Get("add");
index 68d5d0d..dcb4443 100644 (file)
@@ -26,7 +26,7 @@
 
 TEST(Relay, SelfReference) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool());
+  auto tensor_type = relay::TensorType({}, DataType::Bool());
   auto x = relay::VarNode::make("x", relay::Type());
   auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
   CHECK(f->IsInstance<BaseFuncNode>());
index d429554..5593f07 100644 (file)
@@ -36,7 +36,7 @@ TVM_REGISTER_GLOBAL("schedule")
 
 TEST(Relay, Sequential) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, DataType::Float(32));
+  auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
   auto c_data =
       tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
 
index 7d3c809..55f5c97 100644 (file)
@@ -51,7 +51,7 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue*
 
 TEST(MicroStandaloneRuntime, BuildModule) {
   using namespace tvm;
-  auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
+  auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32));
   auto a = relay::VarNode::make("a", tensor_type);
   auto b = relay::VarNode::make("b", tensor_type);
   auto add_op = relay::Op::Get("add");
index 18250d0..7381a07 100644 (file)
@@ -793,7 +793,7 @@ def test_forward_layer_norm():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x, gamma, beta)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
     verify((2, 5))
     verify((2, 5), axis=0)
     verify((2, 5, 6))
@@ -809,7 +809,7 @@ def test_forward_one_hot():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x.astype("float32"))
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
     verify((3,), 3, 1, 0, "int32")
     verify((3,), 3, 1.0, 0.0, "float32")
     verify((2, 2), 5, 2, -2, "int32")
@@ -898,7 +898,7 @@ def test_forward_deconvolution():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                 op_res = intrp.evaluate()(x, weight, bias)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
 
     verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
     verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)