From 12e51e6c4a4578c39e6e19823742358386944776 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 10 Jan 2020 19:27:00 -0800 Subject: [PATCH] [REFACTOR][IR] Initialize Unified IR Expr Data Structure (#4673) This PR moves a few base types from relay and low-level Expr into the ir sub-folder. These classes will serve as a common type system across the stack. Rationale: - PrimExpr for low-level expressions - RelayExpr for advanced features, including Function definition. - Introduce BaseFunc to host all functions, including future PrimFunc(low-level expr functions, subject to discussion). This is a minimum change we can do to unify the classes into a common hierarchy. The main data structure that are variant specific will still be kept in the sub-namespaces. We only include classes that is needed to allow a common Module class. - BaseFunc - GlobalVar - Type definition part of ADT We will only need the BaseFunc and their checked_type to decide the calling convention across the function variants. --- include/tvm/expr.h | 53 +------ include/tvm/ir/adt.h | 142 +++++++++++++++++ include/tvm/ir/expr.h | 270 ++++++++++++++++++++++++++++++++ include/tvm/ir/type.h | 2 +- include/tvm/relay/adt.h | 97 +----------- include/tvm/relay/expr.h | 131 +++------------- include/tvm/relay/feature.h | 5 +- include/tvm/relay/op.h | 2 +- include/tvm/runtime/object.h | 1 + src/ir/adt.cc | 81 ++++++++++ src/ir/expr.cc | 48 ++++++ src/relay/backend/compile_engine.cc | 2 +- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/ir/adt.cc | 44 ------ src/relay/ir/base.cc | 3 +- src/relay/ir/expr.cc | 20 +-- src/relay/ir/module.cc | 2 +- src/relay/pass/fold_constant.cc | 2 +- src/relay/pass/to_cps.cc | 2 +- src/relay/pass/type_solver.cc | 2 +- tests/cpp/relay_pass_type_infer_test.cc | 2 +- 21 files changed, 587 insertions(+), 326 deletions(-) create mode 100644 include/tvm/ir/adt.h create mode 100644 include/tvm/ir/expr.h create mode 100644 src/ir/adt.cc create mode 100644 src/ir/expr.cc diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 976af61..faae303 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -24,6 +24,7 @@ #ifndef TVM_EXPR_H_ #define TVM_EXPR_H_ +#include #include #include #include @@ -37,58 +38,6 @@ namespace tvm { -/*! - * \brief Base node of all primitive expressions. - * - * A primitive expression deals with low-level - * POD data types and handles without - * doing life-cycle management for objects. - * - * PrimExpr is used in the low-level code - * optimizations and integer analysis. - * - * \sa PrimExpr - */ -class PrimExprNode : public Object { - public: - /*! \brief The data type of the expression. */ - DataType dtype; - - static constexpr const char* _type_key = "PrimExpr"; - TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object); -}; - -/*! - * \brief Container of all primitive expressions. - * \sa PrimExprNode - */ -class PrimExpr : public ObjectRef { - public: - PrimExpr() {} - explicit PrimExpr(ObjectPtr ptr) : ObjectRef(ptr) {} - /*! - * \brief construct from integer. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(int32_t value); // NOLINT(*) - /*! - * \brief construct from float. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! - * \brief construct from string. - * \param str The value to be constructed. - */ - TVM_DLL PrimExpr(std::string str); // NOLINT(*) - - /*! \return the data type of this expression. */ - DataType dtype() const { - return static_cast(get())->dtype; - } - - using ContainerType = PrimExprNode; -}; /*! \brief Base node of all statements. */ class StmtNode : public Object { diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h new file mode 100644 index 0000000..6e87162 --- /dev/null +++ b/include/tvm/ir/adt.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/adt.h + * \brief Algebraic data type definitions. + * + * We adopt relay's ADT definition as a unified class + * for decripting structured data. + */ +#ifndef TVM_IR_ADT_H_ +#define TVM_IR_ADT_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief ADT constructor. + * Constructors compare by pointer equality. + * \sa Constructor + */ +class ConstructorNode : public RelayExprNode { + public: + /*! \brief The name (only a hint) */ + std::string name_hint; + /*! \brief Input to the constructor. */ + Array inputs; + /*! \brief The datatype the constructor will construct. */ + GlobalTypeVar belong_to; + /*! \brief Index in the table of constructors (set when the type is registered). */ + mutable int32_t tag = -1; + + ConstructorNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("inputs", &inputs); + v->Visit("belong_to", &belong_to); + v->Visit("tag", &tag); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + static constexpr const char* _type_key = "relay.Constructor"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode); +}; + +/*! + * \brief Managed reference to ConstructorNode + * \sa ConstructorNode + */ +class Constructor : public RelayExpr { + public: + /*! + * \brief Constructor + * \param name_hint the name of the constructor. + * \param inputs The input types. + * \param belong_to The data type var the constructor will construct. + */ + TVM_DLL Constructor(std::string name_hint, + Array inputs, + GlobalTypeVar belong_to); + + TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode); +}; + +/*! \brief TypeData container node */ +class TypeDataNode : public TypeNode { + public: + /*! + * \brief The header is simply the name of the ADT. + * We adopt nominal typing for ADT definitions; + * that is, differently-named ADT definitions with same constructors + * have different types. + */ + GlobalTypeVar header; + /*! \brief The type variables (to allow for polymorphism). */ + Array type_vars; + /*! \brief The constructors. */ + Array constructors; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("header", &header); + v->Visit("type_vars", &type_vars); + v->Visit("constructors", &constructors); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.TypeData"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode); +}; + +/*! + * \brief Stores all data for an Algebraic Data Type (ADT). + * + * In particular, it stores the handle (global type var) for an ADT + * and the constructors used to build it and is kept in the module. Note + * that type parameters are also indicated in the type data: this means that + * for any instance of an ADT, the type parameters must be indicated. That is, + * an ADT definition is treated as a type-level function, so an ADT handle + * must be wrapped in a TypeCall node that instantiates the type-level arguments. + * The kind checker enforces this. + */ +class TypeData : public Type { + public: + /*! + * \brief Constructor + * \param header the name of ADT. + * \param type_vars type variables. + * \param constructors constructors field. + */ + TVM_DLL TypeData(GlobalTypeVar header, + Array type_vars, + Array constructors); + + TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); +}; + +} // namespace tvm +#endif // TVM_IR_ADT_H_ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h new file mode 100644 index 0000000..7b42678 --- /dev/null +++ b/include/tvm/ir/expr.h @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/expr.h + * \brief Base expr nodes in TVM. + */ +#ifndef TVM_IR_EXPR_H_ +#define TVM_IR_EXPR_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief Base type of all the expressions. + * \sa Expr + */ +class BaseExprNode : public Object { + public: + static constexpr const char* _type_key = "Expr"; + TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); +}; + +/*! + * \brief Managed reference to BaseExprNode. + * \sa BaseExprNode + */ +class BaseExpr : public ObjectRef { + public: + /*! \brief Cosntructor */ + BaseExpr() {} + /*! + * \brief Cosntructor from object ptr. + * \param ptr The object pointer. + */ + explicit BaseExpr(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief The container type. */ + using ContainerType = BaseExprNode; +}; + +/*! + * \brief Base node of all primitive expressions. + * + * A primitive expression deals with low-level + * POD data types and handles without + * doing life-cycle management for objects. + * + * PrimExpr is used in the low-level code + * optimizations and integer analysis. + * + * \sa PrimExpr + */ +class PrimExprNode : public BaseExprNode { + public: + /*! + * \brief The runtime data type of the primitive expression. + * + * runtime::DataType(dtype) provides coarse grained type information + * during compile time and runtime. It is eagerly built in + * PrimExpr expression construction and can be used for + * quick type checking. + * + * dtype is sufficient to decide the Type of the PrimExpr + * when it corresponds to POD value types such as i32. + * + * When dtype is DataType::Handle(), the expression could corresponds to + * a more fine-grained Type, and we can get the type by running lazy type inference. + */ + DataType dtype; + + static constexpr const char* _type_key = "PrimExpr"; + TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); +}; + +/*! + * \brief Reference to PrimExprNode. + * \sa PrimExprNode + */ +class PrimExpr : public BaseExpr { + public: + /*! \brief Cosntructor */ + PrimExpr() {} + /*! + * \brief Cosntructor from object ptr. + * \param ptr The object pointer. + */ + explicit PrimExpr(ObjectPtr ptr) : BaseExpr(ptr) {} + /*! + * \brief construct from integer. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(int32_t value); // NOLINT(*) + /*! + * \brief construct from float. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(float value); // NOLINT(*) + /*! + * \brief construct from string. + * \param str The value to be constructed. + */ + TVM_DLL PrimExpr(std::string str); // NOLINT(*) + + /*! \return the data type of this expression. */ + DataType dtype() const { + return static_cast(get())->dtype; + } + /*! \brief The container type. */ + using ContainerType = PrimExprNode; +}; + +/*! + * \brief Base node of all non-primitive expressions. + * + * RelayExpr supports tensor types, functions and ADT as + * first class citizens. The life-cycle of the corresponding + * objects are implicitly managed by the language. + * + * \sa RelayExpr + */ +class RelayExprNode : public BaseExprNode { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + /*! + * \brief Stores the result of type inference(type checking). + * + * \note This can be undefined before type inference. + * This value is discarded during serialization. + */ + mutable Type checked_type_ = Type(nullptr); + /*! + * \return The checked_type + */ + const Type& checked_type() const; + /*! + * \brief Check if the inferred(checked) type of the Expr + * is backed by a TTypeNode and return it. + * + * \note This function will thrown an error if the node type + * of this Expr is not TTypeNode. + * + * \return The corresponding TTypeNode pointer. + * \tparam The specific TypeNode we look for. + */ + template + inline const TTypeNode* type_as() const; + + static constexpr const char* _type_key = "relay.Expr"; + TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); +}; + +/*! + * \brief Managed reference to RelayExprNode. + * \sa RelayExprNode + */ +class RelayExpr : public BaseExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode); +}; + +class GlobalVar; +/*! + * \brief Global variable that leaves in the top-level module. + * + * A GlobalVar only refers to function definitions. + * This is used to enable recursive calls between function. + * + * \sa GlobalVarNode + */ +class GlobalVarNode : public RelayExprNode { + public: + /*! \brief The name of the variable, this only acts as a hint. */ + std::string name_hint; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + static constexpr const char* _type_key = "relay.GlobalVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); +}; + +/*! + * \brief Managed reference to GlobalVarNode. + * \sa GlobalVarNode + */ +class GlobalVar : public RelayExpr { + public: + TVM_DLL explicit GlobalVar(std::string name_hint); + + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); +}; + +/*! + * \brief Base node of all functions. + * + * We support several variants of functions throughout the stack. + * All of the functions shares the same type system(via checked_type) + * to support cross variant calls. + * + * \sa BaseFunc + */ +class BaseFuncNode : public RelayExprNode { + public: + static constexpr const char* _type_key = "BaseFunc"; + TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); +}; + +/*! + * \brief Managed reference to BaseFuncNode. + * \sa BaseFuncNode + */ +class BaseFunc : public RelayExpr { + public: + TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); +}; + +// implementataions +inline const Type& RelayExprNode::checked_type() const { + CHECK(checked_type_.defined()) + << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " + << GetRef(this); + return this->checked_type_; +} + +template +inline const TTypeNode* RelayExprNode::type_as() const { + static_assert(std::is_base_of::value, + "TType must be a special case of type"); + CHECK(checked_type_.defined()) + << "Type inference for this Expr has not completed. Try to call infer_type pass."; + const TTypeNode* node = checked_type_.as(); + CHECK(node != nullptr) + << "Expected type to be " << TTypeNode::_type_key + << ", but get " << checked_type_->GetTypeKey(); + return node; +} + +} // namespace tvm +#endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index ffe1ba8..ab2003e 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -28,7 +28,7 @@ * * ## Relation between Type and runtime::DataType * - * Besides Type, we also store a dtype field in some of the low-level IR's Expr. + * Besides Type, we also store a dtype field in the low-level PrimExpr. * runtime::DataType(dtype) provides coarse grained type information * during compile time and runtime. It is eagerly built in * low-level expression construction and can be used for diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index dac39e0..1807696 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ADT_H_ #include +#include #include #include #include "./base.h" @@ -34,6 +35,12 @@ namespace tvm { namespace relay { +using Constructor = tvm::Constructor; +using ConstructorNode = tvm::ConstructorNode; + +using TypeData = tvm::TypeData; +using TypeDataNode = tvm::TypeDataNode; + /*! \brief Base type for declaring relay pattern. */ class PatternNode : public RelayNode { public: @@ -105,47 +112,6 @@ class PatternVar : public Pattern { TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode); }; -/*! - * \brief ADT constructor. - * Constructors compare by pointer equality. - */ -class Constructor; -/*! \brief Constructor container node. */ -class ConstructorNode : public ExprNode { - public: - /*! \brief The name (only a hint) */ - std::string name_hint; - /*! \brief Input to the constructor. */ - tvm::Array inputs; - /*! \brief The datatype the constructor will construct. */ - GlobalTypeVar belong_to; - /*! \brief Index in the table of constructors (set when the type is registered). */ - mutable int32_t tag = -1; - - ConstructorNode() {} - - TVM_DLL static Constructor make(std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to); - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - v->Visit("inputs", &inputs); - v->Visit("belong_to", &belong_to); - v->Visit("tag", &tag); - v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); - } - - static constexpr const char* _type_key = "relay.Constructor"; - TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, ExprNode); -}; - -class Constructor : public Expr { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Constructor, Expr, ConstructorNode); -}; - /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ class PatternConstructor; /*! \brief PatternVar container node */ @@ -201,53 +167,6 @@ class PatternTuple : public Pattern { TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); }; -/*! - * \brief Stores all data for an Algebraic Data Type (ADT). - * - * In particular, it stores the handle (global type var) for an ADT - * and the constructors used to build it and is kept in the module. Note - * that type parameters are also indicated in the type data: this means that - * for any instance of an ADT, the type parameters must be indicated. That is, - * an ADT definition is treated as a type-level function, so an ADT handle - * must be wrapped in a TypeCall node that instantiates the type-level arguments. - * The kind checker enforces this. - */ -class TypeData; -/*! \brief TypeData container node */ -class TypeDataNode : public TypeNode { - public: - /*! - * \brief The header is simply the name of the ADT. - * We adopt nominal typing for ADT definitions; - * that is, differently-named ADT definitions with same constructors - * have different types. - */ - GlobalTypeVar header; - /*! \brief The type variables (to allow for polymorphism). */ - tvm::Array type_vars; - /*! \brief The constructors. */ - tvm::Array constructors; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("header", &header); - v->Visit("type_vars", &type_vars); - v->Visit("constructors", &constructors); - v->Visit("span", &span); - } - - TVM_DLL static TypeData make(GlobalTypeVar header, - tvm::Array type_vars, - tvm::Array constructors); - - static constexpr const char* _type_key = "relay.TypeData"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode); -}; - -class TypeData : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); -}; - /*! \brief A clause in a match expression. */ class Clause; /*! \brief Clause container node. */ @@ -306,7 +225,7 @@ class MatchNode : public ExprNode { class Match : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Match, Expr, MatchNode); + TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); }; } // namespace relay diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 47c8369..47ae552 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -25,6 +25,7 @@ #define TVM_RELAY_EXPR_H_ #include +#include #include #include #include "./base.h" @@ -33,47 +34,12 @@ namespace tvm { namespace relay { -/*! - * \brief A Relay expression. - */ -class Expr; -/*! - * \brief Base type of the Relay expression hiearchy. - */ -class ExprNode : public RelayNode { - public: - /*! - * \brief Stores the result of type inference(type checking). - * - * \note This can be undefined before type inference. - * This value is discarded during serialization. - */ - mutable Type checked_type_ = Type(nullptr); - /*! - * \return The checked_type - */ - const Type& checked_type() const; - /*! - * \brief Check if the inferred(checked) type of the Expr - * is backed by a TTypeNode and return it. - * - * \note This function will thrown an error if the node type - * of this Expr is not TTypeNode. - * - * \return The corresponding TTypeNode pointer. - * \tparam The specific TypeNode we look for. - */ - template - inline const TTypeNode* type_as() const; - - static constexpr const char* _type_key = "relay.Expr"; - TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, RelayNode); -}; - -class Expr : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Expr, ObjectRef, ExprNode); -}; +using Expr = tvm::RelayExpr; +using ExprNode = tvm::RelayExprNode; +using BaseFunc = tvm::BaseFunc; +using BaseFuncNode = tvm::BaseFuncNode; +using GlobalVar = tvm::GlobalVar; +using GlobalVarNode = tvm::GlobalVarNode; /*! * \brief Constant tensor, backed by an NDArray on the cpu(0) device. @@ -112,7 +78,7 @@ class ConstantNode : public ExprNode { class Constant : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Constant, Expr, ConstantNode); + TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); }; /*! \brief Tuple of multiple Exprs */ @@ -137,7 +103,7 @@ class TupleNode : public ExprNode { class Tuple : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode); + TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); }; /*! @@ -193,37 +159,7 @@ class VarNode : public ExprNode { class Var : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); -}; - -/*! - * \brief Global variable that leaves in the top-level module. - * This is used to enable recursive calls between function. - * - * \note A GlobalVar may only point to functions. - */ -class GlobalVar; -/*! \brief A GlobalId from the node's current type to target type. */ -class GlobalVarNode : public ExprNode { - public: - /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - v->Visit("span", &span); - v->Visit("_checked_type_", &checked_type_); - } - - TVM_DLL static GlobalVar make(std::string name_hint); - - static constexpr const char* _type_key = "relay.GlobalVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, ExprNode); -}; - -class GlobalVar : public Expr { - public: - TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, Expr, GlobalVarNode); + TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); }; /*! @@ -231,7 +167,7 @@ class GlobalVar : public Expr { */ class Function; /*! \brief Function container */ -class FunctionNode : public ExprNode { +class FunctionNode : public BaseFuncNode { public: /*! \brief Function parameters */ tvm::Array params; @@ -312,12 +248,12 @@ class FunctionNode : public ExprNode { tvm::Map GetParams() const; static constexpr const char* _type_key = "relay.Function"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; -class Function : public Expr { +class Function : public BaseFunc { public: - TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); + TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); }; @@ -388,7 +324,7 @@ class CallNode : public ExprNode { class Call : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Call, Expr, CallNode); + TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); }; /*! @@ -429,7 +365,7 @@ class LetNode : public ExprNode { class Let : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(Let, Expr, LetNode); + TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); }; /*! @@ -470,7 +406,7 @@ class IfNode : public ExprNode { class If : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(If, Expr, IfNode); + TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); }; /*! \brief Get index-th field out of a tuple. */ @@ -497,7 +433,7 @@ class TupleGetItemNode : public ExprNode { class TupleGetItem : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, Expr, TupleGetItemNode); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); }; /*! \brief Create a new Reference out of initial value. */ @@ -521,7 +457,7 @@ class RefCreateNode : public ExprNode { class RefCreate : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, Expr, RefCreateNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); }; /*! \brief Get value out of Reference. */ @@ -545,7 +481,7 @@ class RefReadNode : public ExprNode { class RefRead : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefRead, Expr, RefReadNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); }; /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ class RefWrite; @@ -571,7 +507,7 @@ class RefWriteNode : public ExprNode { class RefWrite : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, Expr, RefWriteNode); + TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); }; /*! @@ -600,32 +536,9 @@ class TempExprNode : public ExprNode { class TempExpr : public Expr { public: - TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, Expr, TempExprNode); + TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode); }; -// implementataions -inline const Type& ExprNode::checked_type() const { - CHECK(checked_type_.defined()) - << "internal error: the type checker has " - << "not populated the checked_type " - << "field for " - << GetRef(this); - return this->checked_type_; -} - -template -inline const TTypeNode* ExprNode::type_as() const { - static_assert(std::is_base_of::value, - "TType must be a special case of type"); - CHECK(checked_type_.defined()) - << "Type inference for this Expr has not completed. Try to call infer_type pass."; - const TTypeNode* node = checked_type_.as(); - CHECK(node != nullptr) - << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->GetTypeKey(); - return node; -} - /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ std::string PrettyPrint(const ObjectRef& node); diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index d7b3b39..8292344 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -25,7 +25,7 @@ #define TVM_RELAY_FEATURE_H_ #include -#include +#include #include namespace tvm { @@ -132,7 +132,6 @@ class FeatureSet { explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } }; -class Expr; /*! * \brief Calculate the feature of the program. * @@ -140,7 +139,7 @@ class Expr; * * \return The FeatureSet. */ -FeatureSet DetectFeature(const Expr& expr); +FeatureSet DetectFeature(const RelayExpr& expr); struct Module; /*! diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index b449519..6bd0a35 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -140,7 +140,7 @@ class Op : public relay::Expr { /*! \brief default constructor */ Op() {} /*! \brief constructor from node pointer */ - explicit Op(ObjectPtr n) : Expr(n) {} + explicit Op(ObjectPtr n) : RelayExpr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 7d14947..a2e9188 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -650,6 +650,7 @@ struct ObjectEqual { * \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ static const uint32_t RuntimeTypeIndex() { \ if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ return TypeName::_type_index; \ diff --git a/src/ir/adt.cc b/src/ir/adt.cc new file mode 100644 index 0000000..2914779 --- /dev/null +++ b/src/ir/adt.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/ir/adt.cc + * \brief ADT type definitions. + */ +#include +#include + +namespace tvm { + +Constructor::Constructor(std::string name_hint, + tvm::Array inputs, + GlobalTypeVar belong_to) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + n->inputs = std::move(inputs); + n->belong_to = std::move(belong_to); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ConstructorNode); + +TVM_REGISTER_GLOBAL("relay._make.Constructor") +.set_body_typed([](std::string name_hint, + tvm::Array inputs, + GlobalTypeVar belong_to) { + return Constructor(name_hint, inputs, belong_to); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorNode(" << node->name_hint << ", " + << node->inputs << ", " << node->belong_to << ")"; +}); + +TypeData::TypeData(GlobalTypeVar header, + tvm::Array type_vars, + tvm::Array constructors) { + ObjectPtr n = make_object(); + n->header = std::move(header); + n->type_vars = std::move(type_vars); + n->constructors = std::move(constructors); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TypeDataNode); + +TVM_REGISTER_GLOBAL("relay._make.TypeData") +.set_body_typed([](GlobalTypeVar header, + tvm::Array type_vars, + tvm::Array constructors) { + return TypeData(header, type_vars, constructors); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " + << node->constructors << ")"; +}); + +} // namespace tvm diff --git a/src/ir/expr.cc b/src/ir/expr.cc new file mode 100644 index 0000000..f698a5d --- /dev/null +++ b/src/ir/expr.cc @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/ir/expr.cc + * \brief The expression AST nodes for the common IR infra. + */ +#include +#include + +namespace tvm { + +GlobalVar::GlobalVar(std::string name_hint) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name_hint); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(GlobalVarNode); + +TVM_REGISTER_GLOBAL("relay._make.GlobalVar") +.set_body_typed([](std::string name){ + return GlobalVar(name); +}); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalVar(" << node->name_hint << ")"; + }); + +} // namespace tvm diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 62de1c3..e95e03b 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -628,7 +628,7 @@ class CompileEngineImpl : public CompileEngineNode { auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); const tvm::ir::StringImmNode* symbol_name = ext_symbol.as(); CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVarNode::make(symbol_name->value); + auto gv = GlobalVar(symbol_name->value); ext_mods[code_gen->value]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index b7ecadc..601af9e 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -101,7 +101,7 @@ class LambdaLifter : public ExprMutator { } auto name = GenerateName(func); - auto global = GlobalVarNode::make(name); + auto global = GlobalVar(name); auto free_vars = FreeVars(func); auto free_type_vars = FreeTypeVars(func, module_); diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 1769298..bf9c918 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -96,50 +96,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << "PatternTupleNode(" << node->patterns << ")"; }); -Constructor ConstructorNode::make(std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { - ObjectPtr n = make_object(); - n->name_hint = std::move(name_hint); - n->inputs = std::move(inputs); - n->belong_to = std::move(belong_to); - return Constructor(n); -} - -TVM_REGISTER_NODE_TYPE(ConstructorNode); - -TVM_REGISTER_GLOBAL("relay._make.Constructor") -.set_body_typed(ConstructorNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorNode(" << node->name_hint << ", " - << node->inputs << ", " << node->belong_to << ")"; -}); - -TypeData TypeDataNode::make(GlobalTypeVar header, - tvm::Array type_vars, - tvm::Array constructors) { - ObjectPtr n = make_object(); - n->header = std::move(header); - n->type_vars = std::move(type_vars); - n->constructors = std::move(constructors); - return TypeData(n); -} - -TVM_REGISTER_NODE_TYPE(TypeDataNode); - -TVM_REGISTER_GLOBAL("relay._make.TypeData") -.set_body_typed(TypeDataNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " - << node->constructors << ")"; -}); - Clause ClauseNode::make(Pattern lhs, Expr rhs) { ObjectPtr n = make_object(); n->lhs = std::move(lhs); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 4bac1fd..82b3513 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -37,7 +37,8 @@ TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_GLOBAL("relay._base.set_span") .set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { - CHECK(rn); + rn->span = sp; + } else if (auto* rn = node_ref.as()) { rn->span = sp; } else if (auto* rn = node_ref.as()) { rn->span = sp; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f6ebadf..239a33e 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/tvm/ir/expr.cc + * \file src/tvm/relay/ir/expr.cc * \brief The expression AST nodes of Relay. */ #include @@ -109,24 +109,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->stream << ")"; }); -GlobalVar GlobalVarNode::make(std::string name_hint) { - ObjectPtr n = make_object(); - n->name_hint = std::move(name_hint); - return GlobalVar(n); -} - -TVM_REGISTER_NODE_TYPE(GlobalVarNode); - -TVM_REGISTER_GLOBAL("relay._make.GlobalVar") -.set_body_typed(GlobalVarNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); - - Function FunctionNode::make(tvm::Array params, Expr body, Type ret_type, diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index fdaa607..bf1ebf3 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -279,7 +279,7 @@ Module ModuleNode::FromExpr( } else { func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {}); } - auto main_gv = GlobalVarNode::make("main"); + auto main_gv = GlobalVar("main"); mod->Add(main_gv, func); return mod; } diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 7f00c71..bce5879 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -213,7 +213,7 @@ class ConstantFolder : public ExprMutator { {}, module_->type_definitions, module_->Imports()); - auto global = GlobalVarNode::make("main"); + auto global = GlobalVar("main"); mod->Add(global, func); auto seq = transform::Sequential(passes); mod = seq(mod); diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 3ca7a08..9e2516b 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -155,7 +155,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { auto gv = GetRef(op); if (cm->count(gv) == 0) { - auto cps_gv = GlobalVarNode::make(gv->name_hint + "_cps"); + auto cps_gv = GlobalVar(gv->name_hint + "_cps"); cm->insert({gv, cps_gv}); m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index c62520a..ceed964 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -662,7 +662,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); auto module = ModuleNode::make({}, {}); - auto dummy_fn_name = GlobalVarNode::make("test"); + auto dummy_fn_name = GlobalVar("test"); module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); auto solver = std::make_shared(dummy_fn_name, module, err_reporter); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 03ad228..f727404 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -29,7 +29,7 @@ TEST(Relay, SelfReference) { auto tensor_type = relay::TensorTypeNode::make({}, DataType::Bool()); auto x = relay::VarNode::make("x", relay::Type()); auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); - + CHECK(f->IsInstance()); auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); -- 2.7.4