* [REFACTOR][TYPE] Finish moving all types to IR.
- Move definition of Ref and TensorType to ir
- Move type_functor.h to public header.
- Rename RefType -> RelayRefType for clarity.
* Add atol
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/tensor_type.h
+ * \brief Polymorphic tensor types.
+ */
+#ifndef TVM_IR_TENSOR_TYPE_H_
+#define TVM_IR_TENSOR_TYPE_H_
+
+#include <tvm/ir/type.h>
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+/*!
+ * \brief Base of all Tensor types.
+ * This container can hold TensorType or GenericTensorType.
+ * \sa BaseTensorType, TensorTypeNode
+ */
+class BaseTensorTypeNode : public TypeNode {
+ public:
+ static constexpr const char* _type_key = "relay.BaseTensorType";
+ TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to BaseTensorTypeNode.
+ * \sa BaseTensorTypeNode, TensorType
+ */
+class BaseTensorType : public Type {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief This is the most commonly used type in relay.
+ * TensorType has a fixed dimension and data type.
+ *
+ * The elements of shape can be either IntImm(constant integer),
+ * or any symbolic integer expression.
+ * The symbolic integer allows generic shape inference in certain cases.
+ * \sa TensorType
+ */
+class TensorTypeNode : public BaseTensorTypeNode {
+ public:
+ /*!
+ * \brief The shape of the tensor,
+ * represented by PrimExpr(tvm::Expr).
+ */
+ Array<PrimExpr> shape;
+ /*! \brief The content data type. */
+ DataType dtype;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("shape", &shape);
+ v->Visit("dtype", &dtype);
+ v->Visit("span", &span);
+ }
+
+ /*! \brief Return product of elements in the shape.
+ * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
+ */
+ TVM_DLL PrimExpr Size() const;
+
+ static constexpr const char* _type_key = "relay.TensorType";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorTypeNode.
+ * \sa TensorTypeNode.
+ */
+class TensorType : public Type {
+ public:
+ /*!
+ * \brief Constructor.
+ * \param shape The shape of the tensor.
+ * \param dtype The runtime dtype of the tensor's elements.
+ */
+ TVM_DLL TensorType(Array<PrimExpr> shape, DataType dtype);
+
+ /*!
+ * \brief Construct a scalar containing elements of dtype.
+ * \param dtype The runtime dtype of the tensor's elements.
+ * \return The constructed type.
+ */
+ TVM_DLL static TensorType Scalar(DataType dtype);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
+};
+
+// The following fields contains advanced typing
+// Only keep the class name and reserved for future usage.
+class GenericTensorType;
+// stores a DataType.
+class GenericDataType;
+// stores a DataType.
+class GenericShape;
+
+} // namespace tvm
+#endif // TVM_IR_TENSOR_TYPE_H_
TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode);
};
+/*!
+ * \brief Intermediate values that are used to indicate incomplete types
+ * during type inference.
+ *
+ * If we view the type relations as "computational graph of types",
+ * then IncompleteType represents intermediate values of the graph,
+ * TypeVar represents the input to the graph.
+ *
+ * \sa IncompleteType
+ */
+class IncompleteTypeNode : public TypeNode {
+ public:
+ /*! \brief Kind of the type. */
+ TypeKind kind;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("kind", &kind);
+ v->Visit("span", &span);
+ }
+
+ static constexpr const char* _type_key = "relay.IncompleteType";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to IncompleteTypeNode.
+ * \sa IncompleteTypeNode
+ */
+class IncompleteType : public Type {
+ public:
+ /*!
+ * \brief Constructor.
+ * \param kind The kind of the incomplete type.
+ */
+ TVM_DLL explicit IncompleteType(TypeKind kind);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
+};
+
+
+/*!
+ * \brief Reference Type in the high-level Relay IR.
+ *
+ * \sa RelayRefType.
+ */
+class RelayRefTypeNode : public TypeNode {
+ public:
+ /*! \brief The type of value in the Reference. */
+ Type value;
+
+ RelayRefTypeNode() {}
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("value", &value);
+ v->Visit("span", &span);
+ }
+
+ static constexpr const char* _type_key = "relay.RefType";
+ TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
+};
+
+/*!
+ * \brief Managed reference to RelayRefTypeNode.
+ * \sa RelayRefTypeNode.
+ */
+class RelayRefType : public Type {
+ public:
+ /*!
+ * \brief Constructor.
+ * \param value The type of the value being referenced.
+ */
+ TVM_DLL explicit RelayRefType(Type value);
+ TVM_DEFINE_OBJECT_REF_METHODS(RelayRefType, Type, RelayRefTypeNode);
+};
} // namespace tvm
#endif // TVM_IR_TYPE_H_
*/
/*!
- * \file type_functor.h
+ * \file tvm/ir/type_functor.h
* \brief A way to defined arbitrary function signature with dispatch on types.
*/
-#ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
-#define TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#ifndef TVM_IR_TYPE_FUNCTOR_H_
+#define TVM_IR_TYPE_FUNCTOR_H_
#include <tvm/node/functor.h>
#include <tvm/relay/expr.h>
#include <utility>
namespace tvm {
-namespace relay {
template <typename FType>
class TypeFunctor;
// functions to be overriden.
-#define TYPE_FUNCTOR_DEFAULT \
+#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
+#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.get()), \
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
- virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+ virtual R VisitType_(const RelayRefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+ virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning
static FType InitVTable() {
FType vtable;
// Set dispatch
- RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
- RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TypeVarNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(RelayRefTypeNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
+ TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
return vtable;
}
};
+#undef TVM_TYPE_FUNCTOR_DISPATCH
+
/*!
* \brief A type visitor that recursively visit types.
*/
-class TypeVisitor : public TypeFunctor<void(const Type& n)> {
+class TVM_DLL TypeVisitor :
+ public TypeFunctor<void(const Type& n)> {
public:
void VisitType_(const TypeVarNode* op) override;
void VisitType_(const IncompleteTypeNode* op) override;
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
- void VisitType_(const RefTypeNode* op) override;
+ void VisitType_(const RelayRefTypeNode* op) override;
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
+ void VisitType_(const PrimTypeNode* op) override;
};
-// Mutator that transform a type to another one.
-class TypeMutator : public TypeFunctor<Type(const Type& n)> {
+/*!
+ * \brief TypeMutator that recursively mutates types.
+ */
+class TVM_DLL TypeMutator :
+ public TypeFunctor<Type(const Type& n)> {
public:
Type VisitType(const Type& t) override;
Type VisitType_(const TypeVarNode* op) override;
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
- Type VisitType_(const RefTypeNode* op) override;
+ Type VisitType_(const RelayRefTypeNode* op) override;
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
+ Type VisitType_(const PrimTypeNode* op) override;
private:
Array<Type> MutateArray(Array<Type> arr);
*/
Type Bind(const Type& type, const Map<TypeVar, Type>& args_map);
-} // namespace relay
} // namespace tvm
-#endif // TVM_RELAY_IR_TYPE_FUNCTOR_H_
+#endif // TVM_IR_TYPE_FUNCTOR_H_
#define TVM_RELAY_TYPE_H_
#include <tvm/ir/type.h>
+#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
+using IncompleteType = tvm::IncompleteType;
+using IncompleteTypeNode = tvm::IncompleteTypeNode;
+using RelayRefType = tvm::RelayRefType;
+using RelayRefTypeNode = tvm::RelayRefTypeNode;
+using TensorType = tvm::TensorType;
+using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;
-/*!
- * \brief Base of all Tensor types
- * This container can hold TensorType or GenericTensorType.
- */
-class BaseTensorTypeNode : public TypeNode {
- public:
- static constexpr const char* _type_key = "relay.BaseTensorType";
- TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode);
-};
-
-class BaseTensorType : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(BaseTensorType, Type, BaseTensorTypeNode);
-};
-
-/*!
- * \brief This is the most commonly used type in relay.
- * TensorType have a fixed dimension, data type.
- *
- * The elements of shape can be either IntImm(constant integer),
- * or any symbolic integer expression.
- * The symbolic integer allows generic shape inference in certain cases.
- * \sa TensorTypeNode The container class of TensorType.
- */
-class TensorType;
-/*! \brief TensorType container node */
-class TensorTypeNode : public BaseTensorTypeNode {
- public:
- /*!
- * \brief The shape of the tensor,
- * represented by IndexExpr(tvm::Expr).
- */
- Array<IndexExpr> shape;
- /*! \brief The content data type */
- DataType dtype;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("shape", &shape);
- v->Visit("dtype", &dtype);
- v->Visit("span", &span);
- }
-
- /*! \brief Return product of elements in the shape.
- * \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
- */
- TVM_DLL IndexExpr Size() const;
-
- TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
-
- /*! \brief Construct an scalar containing elements of dtype. */
- TVM_DLL static TensorType Scalar(DataType dtype);
-
- static constexpr const char* _type_key = "relay.TensorType";
- TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
-};
-
-class TensorType : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode);
-};
-
-/*!
- * \brief IncompleteType.
- * This is intermediate values that is used during type inference.
- *
- * If we view the type relations as "computational graph of types",
- * then IncompleteType represents intermediate values of the graph,
- * TypeVar represents the input to the graph.
- */
-class IncompleteType;
-
-/*! \brief IncompleteType container node */
-class IncompleteTypeNode : public TypeNode {
- public:
- Kind kind;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("kind", &kind);
- v->Visit("span", &span);
- }
-
- TVM_DLL static IncompleteType make(Kind kind);
-
- static constexpr const char* _type_key = "relay.IncompleteType";
- TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
-};
-
-class IncompleteType : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode);
-};
-
-/*!
- * \brief The type of reference values.
- */
-class RefType;
-/*!
- * \brief Reference Type in relay.
- */
-class RefTypeNode : public TypeNode {
- public:
- /*! \brief The type of value in the Reference. */
- Type value;
-
- RefTypeNode() {}
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("value", &value);
- v->Visit("span", &span);
- }
-
- TVM_DLL static RefType make(Type value);
-
- static constexpr const char* _type_key = "relay.RefType";
- TVM_DECLARE_FINAL_OBJECT_INFO(RefTypeNode, TypeNode);
-};
-
-class RefType : public Type {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
-};
-
-// The following fields contains advanced typing
-// Only keep the class name and reserved for future usage.
-class GenericTensorType;
-// stores a DataType.
-class GenericDataType;
-// stores a DataType.
-class GenericShape;
-
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TYPE_H_
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
-template <typename RefType, typename ObjectType>
-inline RefType GetRef(const ObjectType* ptr);
+template <typename RelayRefType, typename ObjectType>
+inline RelayRefType GetRef(const ObjectType* ptr);
/*!
* \brief Downcast a base reference type to a more specific type.
friend class TVMArgsSetter;
friend class TVMRetValue;
friend class TVMArgValue;
- template <typename RefType, typename ObjType>
- friend RefType GetRef(const ObjType* ptr);
+ template <typename RelayRefType, typename ObjType>
+ friend RelayRefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};
}
}
-template <typename RefType, typename ObjType>
-inline RefType GetRef(const ObjType* ptr) {
- static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
+template <typename RelayRefType, typename ObjType>
+inline RelayRefType GetRef(const ObjType* ptr) {
+ static_assert(std::is_base_of<typename RelayRefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type");
- return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
+ return RelayRefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
template <typename BaseType, typename ObjType>
*/
/*!
- * \file src/tvm/ir/type.cc
+ * \file src/tvm/ir/tensor_type.cc
* \brief The type system AST nodes of Relay.
*/
-#include <tvm/relay/type.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/ir/tensor_type.h>
#include <tvm/tir/op.h>
namespace tvm {
-namespace relay {
using tvm::NodePrinter;
using namespace tvm::runtime;
-TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
+TensorType::TensorType(Array<PrimExpr> shape, DataType dtype) {
ObjectPtr<TensorTypeNode> n = make_object<TensorTypeNode>();
n->shape = std::move(shape);
n->dtype = std::move(dtype);
- return TensorType(n);
+ data_ = std::move(n);
}
-TensorType TensorTypeNode::Scalar(DataType dtype) {
- return TensorTypeNode::make({}, dtype);
+TensorType TensorType::Scalar(DataType dtype) {
+ return TensorType({}, dtype);
}
-IndexExpr TensorTypeNode::Size() const {
+PrimExpr TensorTypeNode::Size() const {
if (shape.size() == 0) {
return tir::make_const(DataType::Int(64), 1);
}
- IndexExpr size = shape[0];
+ PrimExpr size = shape[0];
for (size_t i = 1; i < shape.size(); ++i) {
size *= shape[i];
}
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_GLOBAL("relay._make.TensorType")
-.set_body_typed(TensorTypeNode::make);
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype) {
+ return TensorType(shape, dtype);
+});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
-IncompleteType IncompleteTypeNode::make(Kind kind) {
- auto n = make_object<IncompleteTypeNode>();
- n->kind = std::move(kind);
- return IncompleteType(n);
-}
-
-TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
-
-TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
-.set_body_typed([](int kind) {
- return IncompleteTypeNode::make(static_cast<Kind>(kind));
- });
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
- p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
- });
-
-
-RefType RefTypeNode::make(Type value) {
- ObjectPtr<RefTypeNode> n = make_object<RefTypeNode>();
- n->value = std::move(value);
- return RefType(n);
-}
-
-TVM_REGISTER_GLOBAL("relay._make.RefType")
-.set_body_typed(RefTypeNode::make);
-
-TVM_REGISTER_NODE_TYPE(RefTypeNode);
-
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<RefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
- auto* node = static_cast<const RefTypeNode*>(ref.get());
- p->stream << "RefTypeNode(" << node->value << ")";
-});
-
-TVM_REGISTER_GLOBAL("relay._make.Any")
-.set_body_typed([]() { return Any::make(); });
-
-} // namespace relay
} // namespace tvm
<< node->type_constraints << ")";
});
+
TupleType::TupleType(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
n->fields = std::move(fields);
p->stream << "TupleTypeNode(" << node->fields << ")";
});
+
+IncompleteType::IncompleteType(TypeKind kind) {
+ auto n = make_object<IncompleteTypeNode>();
+ n->kind = std::move(kind);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
+
+TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
+.set_body_typed([](int kind) {
+ return IncompleteType(static_cast<TypeKind>(kind));
+ });
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
+ auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
+ p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
+ });
+
+
+RelayRefType::RelayRefType(Type value) {
+ ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>();
+ n->value = std::move(value);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relay._make.RefType")
+.set_body_typed([](Type value) {
+ return RelayRefType(value);
+});
+
+TVM_REGISTER_NODE_TYPE(RelayRefTypeNode);
+
+TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
+.set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
+ auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
+ p->stream << "RelayRefTypeNode(" << node->value << ")";
+});
+
} // namespace tvm
* \file type_functor.cc
* \brief Implementations of type functors.
*/
+#include <tvm/ir/type_functor.h>
#include <utility>
-#include "type_functor.h"
namespace tvm {
-namespace relay {
void TypeVisitor::VisitType_(const TypeVarNode* op) {
}
}
}
-void TypeVisitor::VisitType_(const RefTypeNode* op) {
+void TypeVisitor::VisitType_(const RelayRefTypeNode* op) {
this->VisitType(op->value);
}
}
}
+void TypeVisitor::VisitType_(const PrimTypeNode* op) {
+}
+
Type TypeMutator::VisitType(const Type& t) {
return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
}
}
}
-Type TypeMutator::VisitType_(const RefTypeNode* op) {
- return RefTypeNode::make(this->VisitType(op->value));
+Type TypeMutator::VisitType_(const RelayRefTypeNode* op) {
+ return RelayRefType(this->VisitType(op->value));
}
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
return GetRef<Type>(op);
}
+Type TypeMutator::VisitType_(const PrimTypeNode* op) {
+ return GetRef<Type>(op);
+}
+
// Implements bind.
class TypeBinder : public TypeMutator {
public:
return TypeBinder(args_map).VisitType(type);
}
-} // namespace relay
} // namespace tvm
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
-#include "compile_engine.h"
-
+#include <tvm/ir/type_functor.h>
#include <tvm/top/schedule.h>
#include <tvm/top/operation.h>
#include <tvm/top/schedule_pass.h>
#include <functional>
#include <vector>
#include <unordered_map>
-#include "../ir/type_functor.h"
+
+#include "compile_engine.h"
namespace tvm {
namespace relay {
// TODO(@icemelon): Support recursive tuple
Type call_node_type = call_node->checked_type();
if (const auto* tt = call_node->checked_type().as<TensorTypeNode>()) {
- call_node_type = TensorTypeNode::make(GetShape(tt->shape), tt->dtype);
+ call_node_type = TensorType(GetShape(tt->shape), tt->dtype);
} else if (const auto* tuple_t = call_node->checked_type().as<TupleTypeNode>()) {
std::vector<Type> new_fields;
for (auto field : tuple_t->fields) {
if (const auto* tt = field.as<TensorTypeNode>()) {
- new_fields.push_back(TensorTypeNode::make(GetShape(tt->shape), tt->dtype));
+ new_fields.push_back(TensorType(GetShape(tt->shape), tt->dtype));
} else {
new_fields.push_back(field);
}
if (is_dyn) {
auto sh = out_shapes[i];
auto tt = Downcast<TensorType>(rtype->fields[i]);
- fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
+ fields.push_back(fset_output(i, TensorType(sh, tt->dtype)));
} else {
fields.push_back(fset_output(i, rtype->fields[i]));
}
CHECK_EQ(out_shapes.size(), 1);
auto sh = out_shapes[0];
auto tt = Downcast<TensorType>(ret_type);
- out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
+ out_tensor = fset_output(0, TensorType(sh, tt->dtype));
} else {
out_tensor = fset_output(0, ret_type);
}
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
-#include "type_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
}
}
- bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
- if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
+ bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
+ if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
return TypeEqual(lhs->value, rhs->value);
}
return false;
tvm::IntImm(DataType::Int(32), data->shape[i]));
}
- return TensorTypeNode::make(shape, dtype);
+ return TensorType(shape, dtype);
}
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
Array<Type> param_types;
for (auto param : this->params) {
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
- : IncompleteTypeNode::make(Kind::kType);
+ : IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
Type ret_type = (this->ret_type.defined()) ? this->ret_type
- : IncompleteTypeNode::make(Kind::kType);
+ : IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
return FunctionSetAttr(func, name, ref);
});
+TVM_REGISTER_GLOBAL("relay._make.Any")
+.set_body_typed([]() { return Any::make(); });
+
} // namespace relay
} // namespace tvm
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
-#include "type_functor.h"
namespace tvm {
namespace relay {
* \file src/tvm/relay/ir/hash.cc
* \brief Hash functions for Relay types and expressions.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/attrs.h>
-#include "type_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
return hash;
}
- size_t VisitType_(const RefTypeNode* rtn) final {
- size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
+ size_t VisitType_(const RelayRefTypeNode* rtn) final {
+ size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
hash = Combine(hash, TypeHash(rtn->value));
return hash;
}
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
-
+#include <tvm/ir/type_functor.h>
#include <tvm/node/serialization.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
-#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../ir/attr_functor.h"
return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
}
- Doc VisitType_(const RefTypeNode* node) final {
+ Doc VisitType_(const RelayRefTypeNode* node) final {
Doc doc;
return doc << "ref(" << Print(node->value) << ")";
}
<< types[0];
return false;
}
- reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
+ reporter->Assign(types[1], TensorType(data->shape, param->dtype));
return true;
}
out_shape.push_back(param->k);
}
}
- auto values_ty = TensorTypeNode::make(out_shape, data->dtype);
- auto indices_ty = TensorTypeNode::make(out_shape, param->dtype);
+ auto values_ty = TensorType(out_shape, data->dtype);
+ auto indices_ty = TensorType(out_shape, param->dtype);
if (param->ret_type == "both") {
reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
// assign output type
reporter->Assign(types[1],
- TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+ TensorType(layout_converter.BackwardShape(oshape),
out_dtype));
return true;
}
auto bshape = layout_converter.BackwardShape(oshape);
// assign output type
reporter->Assign(types[3],
- TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+ TensorType(layout_converter.BackwardShape(oshape),
out_dtype));
return true;
}
for (auto i = 0u; i < dims; i++) {
out_shape.push_back(tvm::Integer(sh[i]));
}
- alloc_type = TensorTypeNode::make(out_shape, alloc_attrs->dtype);
+ alloc_type = TensorType(out_shape, alloc_attrs->dtype);
} else {
CHECK(alloc_attrs->assert_shape.defined())
<< "the assert_shape must be set when const_shape is not";
- alloc_type = TensorTypeNode::make(alloc_attrs->assert_shape, alloc_attrs->dtype);
+ alloc_type = TensorType(alloc_attrs->assert_shape, alloc_attrs->dtype);
return true;
}
shape_func_ins.push_back(in_type);
} else {
auto shape = RankShape(in_type->shape);
- shape_func_ins.push_back(TensorTypeNode::make(shape, DataType::Int(64)));
+ shape_func_ins.push_back(TensorType(shape, DataType::Int(64)));
}
}
for (auto out_type : out_types) {
auto rank_shape = RankShape(out_type->shape);
- shape_func_outs.push_back(TensorTypeNode::make(rank_shape, DataType::Int(64)));
+ shape_func_outs.push_back(TensorType(rank_shape, DataType::Int(64)));
}
auto input_type = TupleType(shape_func_ins);
out_shape.push_back(bits);
}
- reporter->Assign(types[1], TensorTypeNode::make(out_shape, pack_type));
+ reporter->Assign(types[1], TensorType(out_shape, pack_type));
return true;
}
DataType out_dtype = param->out_dtype;
oshape = trans_in_layout.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
}
// Assign output type.
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
channels = param->channels;
// assign result to reporter
- reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+ reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
channels = param->channels;
// assign result to reporter
- reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+ reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
data->shape[1],
};
- reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
- reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape), out_dtype));
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
return true;
}
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
- reporter->Assign(types[2], TensorTypeNode::make(wshape, data->dtype));
+ reporter->Assign(types[2], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
// infer offset shape
Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
oshape[2], oshape[3]});
- reporter->Assign(types[1], TensorTypeNode::make(offset_shape, data->dtype));
+ reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
- reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
weight_dtype = weight->dtype;
}
// assign result to reporter
- reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+ reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
weight_dtype = weight->dtype;
}
// assign result to reporter
- reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+ reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
}
// assign result to reporter
- reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+ reporter->Assign(types[1], TensorType(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
<< "axis " << param->axis << " is out of range";
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(
+ reporter->Assign(types[1], TensorType(
{data->shape[axis]}, data->dtype));
reporter->Assign(types[2], types[0]);
return true;
Array<tvm::PrimExpr> oshape = buffer->shape;
- reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
+ reporter->Assign(types[2], TensorType(oshape, buffer->dtype));
return true;
}
// assign alpha type
Array<IndexExpr> alpha_shape({data->shape[param->axis]});
- reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));
+ reporter->Assign(types[1], TensorType(alpha_shape, data->dtype));
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
+ reporter->Assign(types[2], TensorType(data->shape, data->dtype));
return true;
}
std::vector<IndexExpr> oshape({data->shape[0], target_dim});
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
// dropout returns the original tensor with dropout applied
// and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
- auto ret_type = TensorTypeNode::make(data->shape, data->dtype);
+ auto ret_type = TensorType(data->shape, data->dtype);
reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
return true;
}
auto axis_size = data->shape[axis];
// if we are using beta and gamma, they need to be of shape (dim,)
- reporter->Assign(types[1], TensorTypeNode::make({axis_size}, data->dtype));
- reporter->Assign(types[2], TensorTypeNode::make({axis_size}, data->dtype));
- reporter->Assign(types[3], TensorTypeNode::make({axis_size}, data->dtype));
- reporter->Assign(types[4], TensorTypeNode::make({axis_size}, data->dtype));
+ reporter->Assign(types[1], TensorType({axis_size}, data->dtype));
+ reporter->Assign(types[2], TensorType({axis_size}, data->dtype));
+ reporter->Assign(types[3], TensorType({axis_size}, data->dtype));
+ reporter->Assign(types[4], TensorType({axis_size}, data->dtype));
// output is a tuple of the normed data (same shape as input), new running mean,
// and new running average (the latter two are both vectors of length dim)
std::vector<Type> fields;
- auto vec_ty = TensorTypeNode::make(Array<IndexExpr>({data->shape[axis]}),
+ auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}),
data->dtype);
- fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
+ fields.push_back(TensorType(data->shape, data->dtype));
fields.push_back(vec_ty);
fields.push_back(vec_ty);
reporter->Assign(types[5], TupleType(Array<Type>(fields)));
const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
- reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
- reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
- reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+ reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
+ reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
+ reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}
const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
- reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
- reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
- reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
+ reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
+ reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
+ reporter->Assign(types[3], TensorType(data->shape, data->dtype));
return true;
}
oshape.Set(2, y->shape[1]);
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
+ reporter->Assign(types[2], TensorType(oshape, x->dtype));
return true;
}
<< "x shape = " << x->shape << ", "
<< "y shape = " << y->shape;
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make({}, x->dtype));
+ reporter->Assign(types[2], TensorType({}, x->dtype));
return true;
}
// Assign output type
reporter->Assign(types[1],
- TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));
+ TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
// Assign output type
reporter->Assign(types[1],
- TensorTypeNode::make(layout_converter.BackwardShape(oshape), data->dtype));
+ TensorType(layout_converter.BackwardShape(oshape), data->dtype));
return true;
}
// data dtype as the weight dtype. However if weight dtype is explicitly
// present we will use that.
auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype);
- reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
+ reporter->Assign(types[1], TensorType(wshape, weight_dtype));
oshape.Set((oshape.size() - 1), param->units);
} else {
if (weight == nullptr) return false;
out_dtype = data->dtype;
}
// assign output type
- reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
}
}
- reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
oshape.push_back(data->shape[i] + padding);
}
- reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
+ reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
}
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
oshape.Set(widx, 1);
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
oshape.Set(widx, output_width);
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
}
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
}
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
if (weight_data->shape.size() == 1) {
// CSR case.
Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
- reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}
Array<IndexExpr> oshape({
data->shape[0],
(weight_indptr->shape[0] - 1) * weight_data->shape[1]});
- reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[4], TensorType(oshape, data->dtype));
return true;
}
LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
const auto* sparse_indptr = types[2].as<TensorTypeNode>();
std::vector<Type> output_types;
- output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
- output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
- output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
+ output_types.push_back(TensorType(sparse_data->shape, sparse_data->dtype));
+ output_types.push_back(TensorType(sparse_indices->shape, sparse_indices->dtype));
+ output_types.push_back(TensorType(sparse_indptr->shape, sparse_indptr->dtype));
reporter->Assign(types[3], TupleType(Array<Type>(output_types)));
return true;
// assign output type
reporter->Assign(types[1],
- TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+ TensorType(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
// assign output type
reporter->Assign(types[1],
- TensorTypeNode::make(layout_converter.BackwardShape(oshape),
+ TensorType(layout_converter.BackwardShape(oshape),
data->dtype));
return true;
}
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
- reporter->Assign(types[1], TensorTypeNode::make(oshape, DataType::Int(32)));
+ reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
return true;
}
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
return false;
}
const auto* param = attrs.as<CastAttrs>();
- reporter->Assign(types[1], TensorTypeNode::make(
+ reporter->Assign(types[1], TensorType(
data->shape, param->dtype));
return true;
}
<< types[1];
return false;
}
- reporter->Assign(types[2], TensorTypeNode::make(data->shape, dtype_like->dtype));
+ reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype));
return true;
}
for (int i = pivot; i < ndim; ++i) {
oshape.emplace_back(data->shape[i]);
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
for (int i = axis; i < ndim; ++i) {
oshape.emplace_back(first->shape[i]);
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype));
+ reporter->Assign(types[1], TensorType(oshape, dtype));
return true;
}
for (int axis : int_axes) {
oshape.push_back(data->shape[axis]);
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
}
if (param->reverse) {
- reporter->Assign(types[1], TensorTypeNode::make(
+ reporter->Assign(types[1], TensorType(
Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
} else {
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
}
return true;
}
CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
<< "Reshape inputs size should be compatible.";
}
- reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype));
+ reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
return true;
}
std::vector<IndexExpr> result_shape;
result_shape.push_back(Any::make());
result_shape.push_back(IntImm(DataType::Int(32), input_rank));
- reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
+ reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32)));
return true;
}
if (!param->axis.defined()) {
std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
oshape.emplace_back(data->shape[i]);
}
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
<< "Fill value should be a scalar but has dimension "
<< fill_value->shape.size() << ".";
- reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype));
+ reporter->Assign(types[1], TensorType(param->shape, out_dtype));
return true;
}
CHECK_EQ(types.size(), 1);
const InitOpAttrs* param = attrs.as<InitOpAttrs>();
- reporter->Assign(types[0], TensorTypeNode::make(param->shape, param->dtype));
+ reporter->Assign(types[0], TensorType(param->shape, param->dtype));
return true;
}
<< "The fill value should be a scalar but here it has dimension "
<< fill_value->shape.size() << ".";
- reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
+ reporter->Assign(types[2], TensorType(data->shape, data->dtype));
return true;
}
reporter->Assign(types[0], types[1]);
reporter->Assign(types[1], types[2]);
- reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype));
+ reporter->Assign(types[2], TensorType({}, attrs->dtype));
if ((cstart = attrs->start.as<ConstantNode>()) &&
(cstop = attrs->stop.as<ConstantNode>()) &&
CHECK_GT(num_elem, 0)
<< "Invalid arange attributes (start, stop, step): " << attrs->start
<< ", " << attrs->stop << ", " << attrs->step;
- reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype));
+ reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype));
return true;
} else {
- reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype));
+ reporter->Assign(types[3], TensorType({Any::make()}, attrs->dtype));
return true;
}
}
for (int i = pivot + 1; i < ndim; ++i) {
oshape.emplace_back(data->shape[i]);
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
oshape.emplace_back(data_shape[i] * reps_shape[i]);
}
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
<< "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
- reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype));
+ reporter->Assign(types[3], TensorType(x_shape, x->dtype));
return true;
}
}
}
}
- reporter->Assign(types[1], TensorTypeNode::make(result_shape, data->dtype));
+ reporter->Assign(types[1], TensorType(result_shape, data->dtype));
return true;
}
CHECK(ioattrs);
auto intt = types[0].as<TensorTypeNode>();
if (intt == nullptr) { return false; }
- auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
+ auto type = TensorType(ioattrs->shape, intt->dtype);
reporter->Assign(types[1], type);
return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
}
}
oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
for (int i = 0; i < sections->value; ++i) {
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = indexdiv(oshape[axis], sections->value);
- auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+ auto vec_type = TensorType(oshape, data->dtype);
fields.push_back(vec_type);
}
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = Downcast<IndexExpr>(indices[i]) - begin;
begin = Downcast<IndexExpr>(indices[i]);
- auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+ auto vec_type = TensorType(oshape, data->dtype);
fields.push_back(vec_type);
}
CHECK(reporter->Assert(begin < data->shape[axis]))
<< "The sum of sections must match the input.shape[axis]";
std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
oshape[axis] = data->shape[axis] - begin;
- auto vec_type = TensorTypeNode::make(oshape, data->dtype);
+ auto vec_type = TensorType(oshape, data->dtype);
fields.push_back(vec_type);
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
}
}
}
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
<< "cannot convert from " << params->src_layout << " to " << params->dst_layout;
const auto& out_shape = layout_converter.ForwardShape(data->shape);
- reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
+ reporter->Assign(types[1], TensorType(out_shape, data->dtype));
return true;
}
oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value; i < ndim; ++i)
oshape.push_back(data->shape[i]);
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
Array<IndexExpr> valid_length_shape;
CHECK(param->axis == 0 || param->axis == 1);
valid_length_shape.push_back(data->shape[1 - param->axis]);
- reporter->Assign(types[1], TensorTypeNode::make(valid_length_shape, valid_length->dtype));
+ reporter->Assign(types[1], TensorType(valid_length_shape, valid_length->dtype));
reporter->Assign(types[2], types[0]);
return true;
}
}
}
- reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
+ reporter->Assign(types[3], TensorType(oshape, param->dtype));
return true;
}
concat_dim = Any::make();
}
- auto rtype = TensorTypeNode::make(oshape, dtype);
+ auto rtype = TensorType(oshape, dtype);
reporter->Assign(types[1], rtype);
return true;
}
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
- reporter->Assign(types[1], TensorTypeNode::make(rank_shape, param->dtype));
+ reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
return true;
}
CHECK(tt != nullptr);
const auto* param = attrs.as<NdarraySizeAttrs>();
CHECK(param != nullptr);
- reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
+ reporter->Assign(types[1], TensorType({1}, param->dtype));
return true;
}
for (; i <= max_ndim; ++i) {
oshape.push_back(rshape[max_ndim - i]);
}
- return TensorTypeNode::make(Array<IndexExpr>(
+ return TensorType(Array<IndexExpr>(
oshape.rbegin(), oshape.rend()), output_dtype);
}
{1, in_height * in_width * (num_sizes + num_ratios - 1), 4});
// assign output type
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
std::vector<IndexExpr> oshape0({cls_shape[0], anchor_shape[1], 6});
std::vector<IndexExpr> oshape1({cls_shape[0]});
std::vector<Type> fields;
- fields.push_back(TensorTypeNode::make(oshape0, cls_prob->dtype));
- fields.push_back(TensorTypeNode::make(oshape1, DataType::Int(32)));
+ fields.push_back(TensorType(oshape0, cls_prob->dtype));
+ fields.push_back(TensorType(oshape1, DataType::Int(32)));
// assign output type
reporter->Assign(types[3], TupleType(Array<Type>(fields)));
std::vector<IndexExpr> oshape({data->shape[0]});
std::vector<Type> fields;
- fields.push_back(TensorTypeNode::make(oshape, DataType::Int(32)));
- fields.push_back(TensorTypeNode::make(data->shape, data->dtype));
+ fields.push_back(TensorType(oshape, DataType::Int(32)));
+ fields.push_back(TensorType(data->shape, data->dtype));
// assign output type
reporter->Assign(types[1], TupleType(Array<Type>(fields)));
// assign output type
if (param->return_indices) {
std::vector<IndexExpr> oshape({dshape[0], dshape[1]});
- reporter->Assign(types[2], TensorTypeNode::make(oshape, DataType::Int(32)));
+ reporter->Assign(types[2], TensorType(oshape, DataType::Int(32)));
} else {
- reporter->Assign(types[2], TensorTypeNode::make(dshape, data->dtype));
+ reporter->Assign(types[2], TensorType(dshape, data->dtype));
}
return true;
}
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
- reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
std::vector<IndexExpr> oshape(
{batch * proposal_attrs->rpn_post_nms_top_n, 5});
- reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype));
+ reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype));
return true;
}
oshape[1] = oshape[1] * param->stride * param->stride;
oshape[2] = indexdiv(oshape[2], param->stride);
oshape[3] = indexdiv(oshape[3], param->stride);
- reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+ reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
* \file de_duplicate.cc
* \brief Use a fresh Id for every Var to make the result well-formed.
*/
-
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/pattern_functor.h>
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
* \brief Add an abstraction over constructors and/or global variables bound to a function.
*
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr_functor.h>
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
* \file ad.cc
* \brief API for Automatic Differentiation for the Relay IR.
*/
-
+#include <tvm/ir/type_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/top/operation.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "let_list.h"
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
- return TupleType({t, RefTypeNode::make(t)});
+ return TupleType({t, RelayRefType(t)});
}
};
* We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/error.h>
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
return Kind::kType;
}
- Kind VisitType_(const RefTypeNode* op) override {
+ Kind VisitType_(const RelayRefTypeNode* op) override {
// ref types should only contain normal types
- RefType rt = GetRef<RefType>(op);
+ RelayRefType rt = GetRef<RelayRefType>(op);
CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
return Kind::kType;
}
*
* These assumptions do not affect the correctness of the algorithm, however.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
-#include "../ir/type_functor.h"
#include "pass_util.h"
#include "let_list.h"
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
- subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
+ subst.Set(func->type_params[i], IncompleteType(kType));
}
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
} else {
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
- reporter->Assign(types[1], TensorTypeNode::make({}, DataType::Float(32))); // dom_scale
- reporter->Assign(types[2], TensorTypeNode::make({}, DataType::Float(32))); // clip_min
- reporter->Assign(types[3], TensorTypeNode::make({}, DataType::Float(32))); // clip_max
+ reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale
+ reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min
+ reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max
reporter->Assign(types[4], types[0]); // output
return true;
}
* All cases in the transform must return via the mcont,
* wheter directly invoking it, or indirectly by recursion.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
-#include "../ir/type_functor.h"
#include "let_list.h"
#include "pass_util.h"
* If we can not infer a type or there are conflicting typing
* constraints we will trigger an error.
*/
-
+#include <tvm/ir/type_functor.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
#include "./pass_util.h"
#include "type_solver.h"
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
if (op->type_annotation.defined()) {
return op->type_annotation;
} else {
- return IncompleteTypeNode::make(Kind::kType);
+ return IncompleteType(Kind::kType);
}
}
EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
}
Type tuple_type = GetType(op->tuple);
- Type rtype = IncompleteTypeNode::make(Kind::kType);
+ Type rtype = IncompleteType(Kind::kType);
auto attrs = make_object<TupleGetItemAttrs>();
attrs->index = op->index;
solver_.AddConstraint(TypeRelation(
// we can expect a certain number of arguments
Array<Type> unknown_args;
for (size_t i = 0; i < td->type_vars.size(); i++) {
- unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+ unknown_args.push_back(IncompleteType(Kind::kType));
}
Type expected = TypeCall(con->constructor->belong_to, unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(con));
// we can expect a certain number of arguments
Array<Type> unknown_args;
for (size_t i = 0; i < tup->patterns.size(); i++) {
- unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+ unknown_args.push_back(IncompleteType(Kind::kType));
}
Type expected = TupleType(unknown_args);
Type unified = Unify(t, expected, GetRef<ObjectRef>(tup));
for (const auto& c : op->clauses) {
VisitPattern(c->lhs, dtype);
}
- Type rtype = IncompleteTypeNode::make(Kind::kType);
+ Type rtype = IncompleteType(Kind::kType);
for (const auto& c : op->clauses) {
rtype = this->Unify(rtype,
GetType(c->rhs),
Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
- Type let_type = IncompleteTypeNode::make(Kind::kType);
+ Type let_type = IncompleteType(Kind::kType);
if (is_functional_literal) {
let_type = GetType(let->var);
// that is a rank-0 boolean tensor.
Type cond_type = this->GetType(ite->cond);
this->Unify(cond_type,
- TensorTypeNode::Scalar(tvm::DataType::Bool()),
+ TensorType::Scalar(tvm::DataType::Bool()),
ite->cond);
Type checked_true = this->GetType(ite->true_branch);
Type checked_false = this->GetType(ite->false_branch);
for (size_t i = 0; i < op->type_params.size(); ++i) {
if (!op->type_params[i].same_as(rel->args[i])) return Type();
}
- Type rtype = IncompleteTypeNode::make(Kind::kType);
+ Type rtype = IncompleteType(Kind::kType);
arg_types.push_back(rtype);
// we can do simple replacement here
solver_.AddConstraint(TypeRelation(
}
for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
- subst_map.Set(fn_ty->type_params[i], IncompleteTypeNode::make(Kind::kType));
+ subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
}
Type ret_type = fn_ty->ret_type;
// This is a temporary work around to check recursive functions whose
// return type is not yet known.
if (!ret_type.defined()) {
- ret_type = IncompleteTypeNode::make(Kind::kType);
+ ret_type = IncompleteType(Kind::kType);
}
Type inst_ty = FuncType(fn_ty->arg_types,
Array<Type> type_args;
for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
- type_args.push_back(IncompleteTypeNode::make(Kind::kType));
+ type_args.push_back(IncompleteType(Kind::kType));
}
return InstantiateFuncType(fn_ty, type_args);
}
// incomplete type => it must be a function taking the arg types
// with an unknown return type
if (inc_ty_node != nullptr) {
- Type ret_type = IncompleteTypeNode::make(Kind::kType);
+ Type ret_type = IncompleteType(Kind::kType);
Type func_type = FuncType(arg_types, ret_type, {}, {});
Type unified = this->Unify(ftype, func_type, GetRef<Call>(call));
fn_ty_node = unified.as<FuncTypeNode>();
}
Type VisitExpr_(const RefCreateNode* op) final {
- return RefTypeNode::make(GetType(op->value));
+ return RelayRefType(GetType(op->value));
}
Type VisitExpr_(const RefReadNode* op) final {
- Type it = IncompleteTypeNode::make(Kind::kType);
- this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
+ Type it = IncompleteType(Kind::kType);
+ this->Unify(GetType(op->ref), RelayRefType(it), GetRef<RefRead>(op));
return it;
}
Type VisitExpr_(const RefWriteNode* op) final {
- Type it = IncompleteTypeNode::make(Kind::kType);
- this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
+ Type it = IncompleteType(Kind::kType);
+ this->Unify(GetType(op->ref), RelayRefType(it), GetRef<RefWrite>(op));
this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
return TupleType::Empty();
}
* \file type_solver.cc
* \brief Type solver implementations.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/tir/op.h>
#include <string>
#include <memory>
#include <tuple>
#include <utility>
#include "type_solver.h"
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
return Type(nullptr);
}
- return TensorTypeNode::make(shape, tt1->dtype);
+ return TensorType(shape, tt1->dtype);
}
Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
}
for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) {
- subst_map.Set(op->type_params[i], IncompleteTypeNode::make(kType));
+ subst_map.Set(op->type_params[i], IncompleteType(kType));
}
FuncType ft = FuncType(op->arg_types,
return FuncType(arg_types, ret_type, ft2->type_params, type_constraints);
}
- Type VisitType_(const RefTypeNode* op, const Type& tn) final {
- const auto* rtn = tn.as<RefTypeNode>();
+ Type VisitType_(const RelayRefTypeNode* op, const Type& tn) final {
+ const auto* rtn = tn.as<RelayRefTypeNode>();
if (!rtn) {
return Type(nullptr);
}
- return RefTypeNode::make(Unify(op->value, rtn->value));
+ return RelayRefType(Unify(op->value, rtn->value));
}
Type VisitType_(const TypeCallNode* op, const Type& tn) override {
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
Expr e = VarNode::make("dummy_var",
- IncompleteTypeNode::make(Kind::kType));
+ IncompleteType(Kind::kType));
return solver->AddConstraint(c, e);
});
} else {
*
* \brief Utility functions for Relay.
*/
+#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
-#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type, output will always be float 32.
- reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
+ reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
return true;
}
out_dtype == DataType::Int(32))
 << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
// assign output type
- reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
- reporter->Assign(types[5], TensorTypeNode::make(oshape, out_dtype));
+ reporter->Assign(types[5], TensorType(oshape, out_dtype));
return true;
}
const auto tensor_dtype = tensor_type->dtype;
CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
if (tensor_type->shape.size() != 0) {
- reporter->Assign(expr_type, TensorTypeNode::make({shape}, tensor_type->dtype));
+ reporter->Assign(expr_type, TensorType({shape}, tensor_type->dtype));
}
}
TEST(Relay, BuildModule) {
using namespace tvm;
- auto tensor_type = relay::TensorTypeNode::make({2, 3}, DataType::Float(32));
+ auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
TEST(Relay, SelfReference) {
using namespace tvm;
- auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool());
+ auto tensor_type = relay::TensorType({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
TEST(Relay, Sequential) {
using namespace tvm;
- auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, DataType::Float(32));
+ auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32));
auto c_data =
tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
TEST(MicroStandaloneRuntime, BuildModule) {
using namespace tvm;
- auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
+ auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
- tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
verify((2, 5))
verify((2, 5), axis=0)
verify((2, 5, 6))
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x.astype("float32"))
- tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
verify((3,), 3, 1, 0, "int32")
verify((3,), 3, 1.0, 0.0, "float32")
verify((2, 2), 5, 2, -2, "int32")
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, weight, bias)
- tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(20,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)