[NODE][REFACTOR] Refactor reflection system in node. (#4189)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 24 Oct 2019 20:40:04 +0000 (13:40 -0700)
committerGitHub <noreply@github.com>
Thu, 24 Oct 2019 20:40:04 +0000 (13:40 -0700)
* [NODE][REFACTOR] Refactor reflection system in node.

- Removed the old Node, Node is now just an alias of runtime::Object
- Introduce ReflectionVTable, a new columnar dispatcher to support reflection
  - This allows us to remove vtable from most node objects
  - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE,
    they are no longer virtual.
- Consolidated serialization and reflection features into node.

* Explicit type qualification when calling destructor.

* Fix SPIRV, more comments

76 files changed:
include/tvm/api_registry.h
include/tvm/arithmetic.h
include/tvm/attrs.h
include/tvm/base.h
include/tvm/buffer.h
include/tvm/build_module.h
include/tvm/channel.h
include/tvm/data_layout.h
include/tvm/expr.h
include/tvm/ir.h
include/tvm/lowered_func.h
include/tvm/node/container.h
include/tvm/node/node.h
include/tvm/node/reflection.h [new file with mode: 0644]
include/tvm/node/serialization.h [new file with mode: 0644]
include/tvm/operation.h
include/tvm/packed_func_ext.h
include/tvm/relay/adt.h
include/tvm/relay/base.h
include/tvm/relay/expr.h
include/tvm/relay/interpreter.h
include/tvm/relay/module.h
include/tvm/relay/op.h
include/tvm/relay/transform.h
include/tvm/relay/type.h
include/tvm/runtime/device_api.h
include/tvm/runtime/memory.h
include/tvm/runtime/object.h
include/tvm/runtime/packed_func.h
include/tvm/runtime/registry.h
include/tvm/schedule.h
include/tvm/target_info.h
include/tvm/tensor.h
include/tvm/tensor_intrin.h
nnvm/src/compiler/compile_engine.h
nnvm/src/compiler/graph_hash.h
nnvm/src/compiler/graph_runtime.cc
nnvm/src/compiler/graph_runtime.h
src/README.md
src/api/api_base.cc
src/api/dsl_api.cc [deleted file]
src/arithmetic/bound_deducer.cc
src/arithmetic/canonical_simplify.cc
src/arithmetic/int_set.cc
src/arithmetic/int_set.h
src/codegen/spirv/intrin_rule_spirv.cc
src/lang/api_registry.cc
src/lang/ir.cc
src/lang/target_info.cc
src/node/reflection.cc [new file with mode: 0644]
src/node/serialization.cc [moved from src/lang/reflection.cc with 64% similarity]
src/relay/backend/compile_engine.h
src/relay/backend/interpreter.cc
src/relay/backend/param_dict.cc
src/relay/backend/param_dict.h
src/relay/ir/adt.cc
src/relay/ir/base.cc
src/relay/ir/op.cc
src/relay/ir/pretty_printer.cc
src/relay/ir/type_functor.cc
src/relay/pass/alter_op_layout.cc
src/relay/pass/device_annotation.cc
src/relay/pass/eta_expand.cc
src/relay/pass/fold_scale_axis.cc
src/relay/pass/forward_rewrite.cc
src/relay/pass/pass_manager.cc
src/relay/pass/quantize/annotate.cc
src/relay/pass/quantize/partition.cc
src/relay/pass/quantize/quantize.cc
src/relay/pass/quantize/quantize.h
src/relay/pass/quantize/realize.cc
src/relay/pass/type_solver.cc
src/relay/pass/util.cc
src/runtime/object.cc
tests/cpp/build_module_test.cc
tests/cpp/packed_func_test.cc

index dbd0972..c41c308 100644 (file)
@@ -58,7 +58,7 @@ class EnvFuncNode : public Node {
   /*! \brief constructor */
   EnvFuncNode() {}
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
   }
 
index e81fa0a..bda6ac6 100644 (file)
@@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node {
   int64_t min_value;
   int64_t max_value;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("min_value", &min_value);
     v->Visit("max_value", &max_value);
   }
@@ -162,7 +162,7 @@ class ModularSetNode : public Node {
   /*! \brief The base */
   int64_t base;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("coeff", &coeff);
     v->Visit("base", &base);
   }
@@ -351,7 +351,7 @@ enum SignType {
  */
 struct IntSetNode : public Node {
   static constexpr const char* _type_key = "IntSet";
-  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
+  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
 };
 
 /*!
index fb8927a..2fbb9e6 100644 (file)
@@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node {
   /*! \brief detailed description of the type */
   std::string description;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("type_info", &type_info);
     v->Visit("description", &description);
@@ -197,7 +197,7 @@ class AttrsHash {
   size_t operator()(const std::string& value) const {
     return std::hash<std::string>()(value);
   }
-  size_t operator()(const Type& value) const {
+  size_t operator()(const DataType& value) const {
     return std::hash<int>()(
         static_cast<int>(value.code()) |
         (static_cast<int>(value.bits()) << 8) |
@@ -221,6 +221,8 @@ class BaseAttrsNode : public Node {
  public:
   using TVMArgs = runtime::TVMArgs;
   using TVMRetValue = runtime::TVMRetValue;
+  // visit function
+  virtual void VisitAttrs(AttrVisitor* v) {}
   /*!
    * \brief Initialize the attributes by sequence of arguments
    * \param args The postional arguments in the form
@@ -753,12 +755,12 @@ class AttrNonDefaultVisitor {
 template<typename DerivedType>
 class AttrsNode : public BaseAttrsNode {
  public:
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     ::tvm::detail::AttrNormalVisitor vis(v);
     self()->__VisitAttrs__(vis);
   }
 
-  void VisitNonDefaultAttrs(AttrVisitor* v) final {
+  void VisitNonDefaultAttrs(AttrVisitor* v) {
     ::tvm::detail::AttrNonDefaultVisitor vis(v);
     self()->__VisitAttrs__(vis);
   }
index a42de10..9b3b4cd 100644 (file)
 
 /*!
  * \file tvm/base.h
- * \brief Defines the base data structure
+ * \brief Base utilities
  */
 #ifndef TVM_BASE_H_
 #define TVM_BASE_H_
 
 #include <dmlc/logging.h>
-#include <dmlc/registry.h>
-#include <tvm/node/node.h>
-#include <string>
-#include <memory>
-#include <functional>
 #include <utility>
-#include "runtime/registry.h"
 
 namespace tvm {
 
-using ::tvm::Node;
-using ::tvm::NodeRef;
-using ::tvm::AttrVisitor;
-
-/*!
- * \brief Macro to define common node ref methods.
- * \param TypeName The name of the NodeRef.
- * \param BaseTypeName The Base type.
- * \param NodeName The node container type.
- */
-#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName)   \
-  TypeName() {}                                                         \
-  explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n)                  \
-      : BaseTypeName(n) {}                                              \
-  const NodeName* operator->() const {                                  \
-    return static_cast<const NodeName*>(data_.get());                   \
-  }                                                                     \
-  operator bool() const { return this->defined(); }                     \
-  using ContainerType = NodeName;
-
-/*!
- * \brief Macro to define CopyOnWrite function in a NodeRef.
- * \param NodeName The Type of the Node.
- *
- *  CopyOnWrite will generate a unique copy of the internal node.
- *  The node will be copied if it is referenced by multiple places.
- *  The function returns the raw pointer to the node to allow modification
- *  of the content.
- *
- * \code
- *
- *  MyCOWNodeRef ref, ref2;
- *  ref2 = ref;
- *  ref.CopyOnWrite()->value = new_value;
- *  assert(ref2->value == old_value);
- *  assert(ref->value == new_value);
- *
- * \endcode
- */
-#define TVM_DEFINE_NODE_REF_COW(NodeName)                               \
-  NodeName* CopyOnWrite() {                                             \
-      CHECK(data_ != nullptr);                                          \
-      if (!data_.unique())  {                                           \
-        NodePtr<NodeName> n = make_node<NodeName>(*(operator->()));     \
-        ObjectPtr<Object>(std::move(n)).swap(data_);                    \
-      }                                                                 \
-      return static_cast<NodeName*>(data_.get());                       \
-    }
-
-/*! \brief Macro to make it easy to define node ref type given node */
-#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                      \
-  class TypeName : public ::tvm::NodeRef {                           \
-   public:                                                           \
-    TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
-  };                                                                 \
-
-/*!
- * \brief Macro to make it easy to define node ref type that
- *  has a CopyOnWrite member function.
- */
-#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName)           \
-  class TypeName : public BaseType {                                    \
-   public:                                                              \
-    TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName);          \
-    TVM_DEFINE_NODE_REF_COW(NodeName);                                  \
-  };
-
 /*!
  * \brief RAII wrapper function to enter and exit a context object
  *        similar to python's with syntax.
@@ -146,100 +73,6 @@ class With {
   ContextType ctx_;
 };
 
-/*!
- * \brief save the node as well as all the node it depends on as json.
- *  This can be used to serialize any TVM object
- *
- * \return the string representation of the node.
- */
-std::string SaveJSON(const NodeRef& node);
-
-/*!
- * \brief Internal implementation of LoadJSON
- * Load tvm Node object from json and return a shared_ptr of Node.
- * \param json_str The json string to load from.
- *
- * \return The shared_ptr of the Node.
- */
-ObjectPtr<Object> LoadJSON_(std::string json_str);
-
-/*!
- * \brief Load the node from json string.
- *  This can be used to deserialize any TVM object.
- *
- * \param json_str The json string to load from.
- *
- * \tparam NodeType the nodetype
- *
- * \code
- *  Expr e = LoadJSON<Expr>(json_str);
- * \endcode
- */
-template<typename NodeType,
-         typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
-inline NodeType LoadJSON(const std::string& json_str) {
-  return NodeType(LoadJSON_(json_str));
-}
-
-/*!
- * \brief Registry entry for NodeFactory.
- *
- *  There are two types of Nodes that can be serialized.
- *  The normal node requires a registration a creator function that
- *  constructs an empty Node of the corresponding type.
- *
- *  The global singleton(e.g. global operator) where only global_key need to be serialized,
- *  in this case, FGlobalKey need to be defined.
- */
-struct NodeFactoryReg {
-  /*!
-   * \brief creator function.
-   * \param global_key Key that identifies a global single object.
-   *        If this is not empty then FGlobalKey
-   * \return The created function.
-   */
-  using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
-  /*!
-   * \brief Global key function, only needed by global objects.
-   * \param node The node pointer.
-   * \return node The global key to the node.
-   */
-  using FGlobalKey = std::function<std::string(const Node* node)>;
-  /*! \brief registered name */
-  std::string name;
-  /*!
-   * \brief The creator function
-   */
-  FCreate fcreator = nullptr;
-  /*!
-   * \brief The global key function.
-   */
-  FGlobalKey fglobal_key = nullptr;
-  // setter of creator
-  NodeFactoryReg& set_creator(FCreate f) {  // NOLINT(*)
-    this->fcreator = f;
-    return *this;
-  }
-  // setter of creator
-  NodeFactoryReg& set_global_key(FGlobalKey f) {  // NOLINT(*)
-    this->fglobal_key = f;
-    return *this;
-  }
-  // global registry singleton
-  TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
-};
-
-/*!
- * \brief Register a Node type
- * \note This is necessary to enable serialization of the Node.
- */
-#define TVM_REGISTER_NODE_TYPE(TypeName)                                \
-  TVM_REGISTER_OBJECT_TYPE(TypeName);                                   \
-  static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
-      ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
-      .set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
-
-
 #define TVM_STRINGIZE_DETAIL(x) #x
 #define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
 #define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
index f18ed92..d2c2b40 100644 (file)
@@ -135,7 +135,7 @@ class BufferNode : public Node {
   /*! \brief constructor */
   BufferNode() {}
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("data", &data);
     v->Visit("dtype", &dtype);
     v->Visit("shape", &shape);
index c985fbe..7114a45 100644 (file)
@@ -61,7 +61,7 @@ class TargetNode : public Node {
   /*! \return the full device string to pass to codegen::Build */
   TVM_DLL const std::string& str() const;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("target_name", &target_name);
     v->Visit("device_name", &device_name);
     v->Visit("device_type", &device_type);
@@ -229,7 +229,7 @@ class BuildConfigNode : public Node {
   /*! \brief Whether to disable loop vectorization. */
   bool disable_vectorize = false;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("data_alignment", &data_alignment);
     v->Visit("offset_factor", &offset_factor);
     v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
@@ -473,6 +473,8 @@ class GenericFuncNode : public Node {
   /* \brief map from keys to registered functions */
   std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
 
+  void VisitAttrs(AttrVisitor* v) {}
+
   static constexpr const char* _type_key = "GenericFunc";
   TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
 };
index 346291a..3a40a78 100644 (file)
@@ -54,7 +54,7 @@ struct ChannelNode : public Node {
   /*! \brief default data type in read/write */
   Type dtype;
   // visit all attributes
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("handle_var", &handle_var);
     v->Visit("dtype", &dtype);
   }
index ad3da6b..5e2cc08 100644 (file)
@@ -104,7 +104,7 @@ class LayoutNode : public Node {
    */
   Array<IterVar> axes;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("axes", &axes);
   }
@@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node {
   /*! \brief The destination layout */
   Layout dst_layout;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("src_layout", &src_layout);
     v->Visit("dst_layout", &dst_layout);
     v->Visit("forward_rule", &forward_rule);
index d884a4d..ea57815 100644 (file)
 #include <string>
 #include <algorithm>
 #include <unordered_map>
+#include <iostream>
 #include "base.h"
 #include "dtype.h"
+#include "node/node.h"
 #include "node/container.h"
 #include "node/ir_functor.h"
 #include "runtime/c_runtime_api.h"
@@ -110,7 +112,7 @@ class Variable : public ExprNode {
 
   static Var make(DataType dtype, std::string name_hint);
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("name", &name_hint);
   }
@@ -164,7 +166,7 @@ class IntImm : public ExprNode {
   /*! \brief the Internal value. */
   int64_t value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("value", &value);
   }
@@ -230,7 +232,7 @@ class RangeNode : public Node {
   RangeNode() {}
   RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("min", &min);
     v->Visit("extent", &extent);
   }
@@ -406,7 +408,7 @@ class IterVarNode : public Node {
    */
   std::string thread_tag;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dom", &dom);
     v->Visit("var", &var);
     v->Visit("iter_type", &iter_type);
@@ -490,7 +492,7 @@ class IRPrinter {
 };
 
 // default print function for all nodes
-inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
+inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) {  // NOLINT(*)
   IRPrinter(os).Print(n);
   return os;
 }
index 37718fe..b6c3028 100644 (file)
@@ -45,7 +45,7 @@ class UIntImm : public ExprNode {
   /*! \brief The constant value content. */
   uint64_t value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("value", &value);
   }
@@ -62,7 +62,7 @@ class FloatImm : public ExprNode {
   /*! \brief The constant value content. */
   double value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("value", &value);
   }
@@ -79,7 +79,7 @@ class StringImm : public ExprNode {
   /*! \brief The constant value content. */
   std::string value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("value", &value);
   }
@@ -99,7 +99,7 @@ class Cast : public ExprNode {
   /*! \brief Original data type. */
   Expr value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("value", &value);
   }
@@ -122,7 +122,7 @@ class BinaryOpNode : public ExprNode {
   /*! \brief The right operand. */
   Expr b;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &(this->type));
     v->Visit("a", &a);
     v->Visit("b", &b);
@@ -214,7 +214,7 @@ class CmpOpNode : public ExprNode {
   /*! \brief The right operand. */
   Expr b;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &(this->type));
     v->Visit("a", &a);
     v->Visit("b", &b);
@@ -278,7 +278,7 @@ class And : public ExprNode {
   /*! \brief The right operand. */
   Expr b;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &(this->type));
     v->Visit("a", &a);
     v->Visit("b", &b);
@@ -298,7 +298,7 @@ class Or : public ExprNode {
   /*! \brief The right operand. */
   Expr b;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("a", &a);
     v->Visit("b", &b);
@@ -316,7 +316,7 @@ class Not : public ExprNode {
   /*! \brief The input operand. */
   Expr a;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("a", &a);
   }
@@ -343,7 +343,7 @@ class Select : public ExprNode {
   /*! \brief value to be returned when condition is false. */
   Expr false_value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("condition", &condition);
     v->Visit("true_value", &true_value);
@@ -380,7 +380,7 @@ class Load : public ExprNode {
   /*! \brief The predicate to mask which lanes would be loaded. */
   Expr predicate;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("buffer_var", &buffer_var);
     v->Visit("index", &index);
@@ -411,7 +411,7 @@ class Ramp : public ExprNode {
   /*! \brief Total number of lanes. */
   int lanes;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("base", &base);
     v->Visit("stride", &stride);
@@ -432,7 +432,7 @@ class Broadcast : public ExprNode {
   /*! \brief The numerb of lanes. */
   int lanes;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("value", &value);
     v->Visit("lanes", &lanes);
@@ -456,7 +456,7 @@ class Let : public ExprNode {
   /*! \brief The result expression. */
   Expr body;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("var", &var);
     v->Visit("value", &value);
@@ -522,7 +522,7 @@ class Call : public ExprNode {
   /*! \brief The output value index if func's value is a tuple. */
   int value_index{0};
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("name", &name);
     v->Visit("args", &args);
@@ -592,7 +592,7 @@ class Shuffle : public ExprNode {
   /*! \brief The indices of each element. */
   Array<Expr> indices;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("vectors", &vectors);
     v->Visit("indices", &indices);
   }
@@ -652,7 +652,7 @@ class CommReducerNode : public Node {
                                   Array<Expr> result,
                                   Array<Expr> identity_element);
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("lhs", &lhs);
     v->Visit("rhs", &rhs);
     v->Visit("result", &result);
@@ -694,7 +694,7 @@ class Reduce : public ExprNode {
                            Expr condition,
                            int value_index);
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &type);
     v->Visit("combiner", &combiner);
     v->Visit("source", &source);
@@ -710,7 +710,7 @@ class Reduce : public ExprNode {
 /*! \brief Any shape. */
 class Any : public ExprNode {
  public:
-  void VisitAttrs(AttrVisitor* v) final {}
+  void VisitAttrs(AttrVisitor* v) {}
   /*! \brief Convert to var. */
   Var ToVar() const {
     return Variable::make(Int(32), "any_dim");
@@ -735,7 +735,7 @@ class LetStmt : public StmtNode {
   /*! \brief The body block. */
   Stmt body;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("var", &var);
     v->Visit("value", &value);
     v->Visit("body", &body);
@@ -768,7 +768,7 @@ class AttrStmt : public StmtNode {
   /*! \brief The body statement to be executed */
   Stmt body;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("node", &node);
     v->Visit("attr_key", &attr_key);
     v->Visit("value", &value);
@@ -799,7 +799,7 @@ class AssertStmt : public StmtNode {
    */
   Stmt body;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("condition", &condition);
     v->Visit("message", &message);
     v->Visit("body", &body);
@@ -822,7 +822,7 @@ class ProducerConsumer : public StmtNode {
   /*! \brief Body to be executed. */
   Stmt body;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("is_producer", &is_producer);
     v->Visit("body", &body);
@@ -863,7 +863,7 @@ class Store : public StmtNode {
   /*! \brief The predicate to mask which lanes would be stored. */
   Expr predicate;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
     v->Visit("value", &value);
     v->Visit("index", &index);
@@ -893,7 +893,7 @@ class Provide : public StmtNode {
   /*! \brief The index arguments of the function. */
   Array<Expr> args;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("value_index", &value_index);
     v->Visit("value", &value);
@@ -929,7 +929,7 @@ class Allocate : public StmtNode {
   Expr new_expr;
   std::string free_function;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
     v->Visit("dtype", &type);
     v->Visit("extents", &extents);
@@ -972,7 +972,7 @@ class Free : public StmtNode {
   /*! \brief The buffer variable. */
   Var buffer_var;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
   }
 
@@ -1001,7 +1001,7 @@ class Realize : public StmtNode {
   /*! \brief The body of realization. */
   Stmt body;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("value_index", &value_index);
     v->Visit("dtype", &type);
@@ -1031,7 +1031,7 @@ class Block : public StmtNode {
   /*! \brief The restof statments. */
   Stmt rest;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("first", &first);
     v->Visit("rest", &rest);
   }
@@ -1055,7 +1055,7 @@ class IfThenElse : public StmtNode {
   /*! \brief The branch to be executed when condition is false, can be null. */
   Stmt else_case;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("condition", &condition);
     v->Visit("then_case", &then_case);
     v->Visit("else_case", &else_case);
@@ -1078,7 +1078,7 @@ class Evaluate : public StmtNode {
   /*! \brief The expression to be evaluated. */
   Expr value;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("value", &value);
   }
 
@@ -1142,7 +1142,7 @@ class For : public StmtNode {
                            DeviceAPI device_api,
                            Stmt body);
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("loop_var", &loop_var);
     v->Visit("min", &min);
     v->Visit("extent", &extent);
@@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode {
   /*! \brief Bounds to be prefetched. */
   Region bounds;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("value_index", &value_index);
     v->Visit("type", &type);
index e2147d0..6709f54 100644 (file)
@@ -119,7 +119,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
   int num_outputs() const final {
     return 1;
   }
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("args", &args);
     v->Visit("thread_axis", &thread_axis);
index 2e1a978..c36c6c1 100644 (file)
@@ -40,8 +40,7 @@ class ArrayNode : public Node {
   /*! \brief the data content */
   std::vector<ObjectRef> data;
 
-  void VisitAttrs(AttrVisitor* visitor) final {
-     // Visitor to array have no effect.
+  void VisitAttrs(AttrVisitor* visitor) {
   }
 
   static constexpr const char* _type_key = "Array";
@@ -51,9 +50,9 @@ class ArrayNode : public Node {
 /*! \brief map node content */
 class MapNode : public Node {
  public:
-  void VisitAttrs(AttrVisitor* visitor) final {
-     // Visitor to map have no effect.
+  void VisitAttrs(AttrVisitor* visitor) {
   }
+
   /*! \brief The corresponding conatiner type */
   using ContainerType = std::unordered_map<
     ObjectRef,
@@ -71,12 +70,12 @@ class MapNode : public Node {
 /*! \brief specialized map node with string as key */
 class StrMapNode : public Node {
  public:
-  void VisitAttrs(AttrVisitor* visitor) final {
-     // Visitor to map have no effect.
-  }
   /*! \brief The corresponding conatiner type */
   using ContainerType = std::unordered_map<std::string, ObjectRef>;
 
+  void VisitAttrs(AttrVisitor* visitor) {
+  }
+
   /*! \brief the data content */
   ContainerType data;
 
index 8203ee6..4014c37 100644 (file)
  */
 /*!
  * \file tvm/node/node.h
- * \brief Node system data structure.
+ * \brief Definitions and helper macros for IR/AST nodes.
+ *
+ *  The node folder contains base utilities for IR/AST nodes,
+ *  invariant of which specific language dialect.
+ *
+ *  We implement AST/IR nodes as sub-classes of runtime::Object.
+ *  The base class Node is just an alias of runtime::Object.
+ *
+ *  Besides the runtime type checking provided by Object,
+ *  node folder contains additional functionalities such as
+ *  reflection and serialization, which are important features
+ *  for building a compiler infra.
  */
 #ifndef TVM_NODE_NODE_H_
 #define TVM_NODE_NODE_H_
 
-#include <dmlc/logging.h>
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/object.h>
 #include <tvm/runtime/memory.h>
-#include <tvm/runtime/ndarray.h>
+#include <tvm/node/reflection.h>
+
 #include <string>
 #include <vector>
 #include <utility>
 #include <type_traits>
 
-
 namespace tvm {
-// forward declaration
-class DataType;
-class Node;
-class NodeRef;
 
-/*!
- * \brief Visitor class to each node content.
- *  The content is going to be called for each field.
- */
-class TVM_DLL AttrVisitor {
- public:
-//! \cond Doxygen_Suppress
-  virtual ~AttrVisitor() = default;
-  virtual void Visit(const char* key, double* value) = 0;
-  virtual void Visit(const char* key, int64_t* value) = 0;
-  virtual void Visit(const char* key, uint64_t* value) = 0;
-  virtual void Visit(const char* key, int* value) = 0;
-  virtual void Visit(const char* key, bool* value) = 0;
-  virtual void Visit(const char* key, std::string* value) = 0;
-  virtual void Visit(const char* key, void** value) = 0;
-  virtual void Visit(const char* key, DataType* value) = 0;
-  virtual void Visit(const char* key, NodeRef* value) = 0;
-  virtual void Visit(const char* key, runtime::NDArray* value) = 0;
-  virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
-  template<typename ENum,
-           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
-  void Visit(const char* key, ENum* ptr) {
-    static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
-                  "declare enum to be enum int to use visitor");
-    this->Visit(key, reinterpret_cast<int*>(ptr));
-  }
-//! \endcond
-};
+using runtime::TypeIndex;
+using runtime::Object;
+using runtime::ObjectPtr;
+using runtime::ObjectRef;
+using runtime::GetRef;
+using runtime::Downcast;
+using runtime::ObjectHash;
+using runtime::ObjectEqual;
+using runtime::make_object;
 
-/*! \brief Reuse the type index in he runtime. */
-using TypeIndex = runtime::TypeIndex;
+using NodeHash = ObjectHash;
+using NodeEqual = ObjectEqual;
+using Node = Object;
 
 /*!
- * \brief base class of node container in DSL AST.
+ * \brief Base class of all references to AST/IR nodes.
  */
-class Node : public runtime::Object {
+class NodeRef : public ObjectRef {
  public:
-  /*! \brief virtual destructor */
-  virtual ~Node() {}
-
-  /*!
-   * \brief Apply visitor to each field of the Node
-   *  Visitor could mutate the content of the node.
-   *  override if Node contains attribute fields.
-   * \param visitor The visitor
-   */
-  virtual void VisitAttrs(AttrVisitor* visitor) {}
-
-  static constexpr const char* _type_key = "Node";
-  static constexpr uint32_t _type_index = TypeIndex::kDynamic;
-
-  TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object);
+  NodeRef() {}
+  explicit NodeRef(ObjectPtr<Object> n) : ObjectRef(n) {}
 };
 
-
 /*!
- * \brief Base class of all node reference object
- *  NodeRef is just a alias of ObjectRef.
+ * \brief Allocate a node object.
+ * \param args arguments to the constructor.
+ * \tparam T the node type.
+ * \return The NodePtr to the allocated object.
+ * \note This function is an alias of make_object.
  */
-class NodeRef : public runtime::ObjectRef {
- public:
-  /*! \brief type indicate the container type */
-  using ContainerType = Node;
-
-  /*! \return the internal node pointer */
-  const Node* get() const {
-    return static_cast<const Node*>(ObjectRef::get());
-  }
-  /*! \return the internal node pointer */
-  const Node* operator->() const {
-    return get();
-  }
-  /*!
-   * \brief A more powerful version of as that also works with
-   *  intermediate base types.
-   * \tparam T the target type, must be subtype of IRNode
-   */
-  template<typename T>
-  const T *as_derived() const {
-    return as<T>();
-  }
-  /*! \brief default constructor */
-  NodeRef() = default;
-  explicit NodeRef(runtime::ObjectPtr<runtime::Object> ptr) : ObjectRef(ptr) {}
-};
+template<typename T, typename... Args>
+inline NodePtr<T> make_node(Args&&... args) {
+  return runtime::make_object<T>(std::forward<Args>(args)...);
+}
 
 /*!
  * \brief helper macro to declare type information in a base node.
@@ -139,27 +94,67 @@ class NodeRef : public runtime::ObjectRef {
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent);
 
 
-using runtime::Object;
-using runtime::ObjectPtr;
-using runtime::ObjectRef;
-using runtime::GetRef;
-using runtime::Downcast;
-using runtime::make_object;
-using runtime::ObjectHash;
-using runtime::ObjectEqual;
+/*!
+ * \brief Macro to define common node ref methods.
+ * \param TypeName The name of the NodeRef.
+ * \param BaseTypeName The Base type.
+ * \param NodeName The node container type.
+ */
+#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName)   \
+  TypeName() {}                                                         \
+  explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n)                  \
+      : BaseTypeName(n) {}                                              \
+  const NodeName* operator->() const {                                  \
+    return static_cast<const NodeName*>(data_.get());                   \
+  }                                                                     \
+  operator bool() const { return this->defined(); }                     \
+  using ContainerType = NodeName;
 
-using NodeHash = ObjectHash;
-using NodeEqual = ObjectEqual;
+/*!
+ * \brief Macro to define CopyOnWrite function in a NodeRef.
+ * \param NodeName The Type of the Node.
+ *
+ *  CopyOnWrite will generate a unique copy of the internal node.
+ *  The node will be copied if it is referenced by multiple places.
+ *  The function returns the raw pointer to the node to allow modification
+ *  of the content.
+ *
+ * \code
+ *
+ *  MyCOWNodeRef ref, ref2;
+ *  ref2 = ref;
+ *  ref.CopyOnWrite()->value = new_value;
+ *  assert(ref2->value == old_value);
+ *  assert(ref->value == new_value);
+ *
+ * \endcode
+ */
+#define TVM_DEFINE_NODE_REF_COW(NodeName)                               \
+  NodeName* CopyOnWrite() {                                             \
+      CHECK(data_ != nullptr);                                          \
+      if (!data_.unique())  {                                           \
+        NodePtr<NodeName> n = make_node<NodeName>(*(operator->()));     \
+        ObjectPtr<Object>(std::move(n)).swap(data_);                    \
+      }                                                                 \
+      return static_cast<NodeName*>(data_.get());                       \
+    }
+
+/*! \brief Macro to make it easy to define node ref type given node */
+#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                      \
+  class TypeName : public ::tvm::NodeRef {                           \
+   public:                                                           \
+    TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
+  };                                                                 \
 
 /*!
- * \brief Allocate a node object.
- * \param args arguments to the constructor.
- * \tparam T the node type.
- * \return The NodePtr to the allocated object.
+ * \brief Macro to make it easy to define node ref type that
+ *  has a CopyOnWrite member function.
  */
-template<typename T, typename... Args>
-inline NodePtr<T> make_node(Args&&... args) {
-  return runtime::make_object<T>(std::forward<Args>(args)...);
-}
+#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName)           \
+  class TypeName : public BaseType {                                    \
+   public:                                                              \
+    TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName);          \
+    TVM_DEFINE_NODE_REF_COW(NodeName);                                  \
+  };
 }  // namespace tvm
 #endif  // TVM_NODE_NODE_H_
diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
new file mode 100644 (file)
index 0000000..e6caa44
--- /dev/null
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/reflection.h
+ * \brief Reflection and serialization of compiler IR/AST nodes.
+ */
+#ifndef TVM_NODE_REFLECTION_H_
+#define TVM_NODE_REFLECTION_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/ndarray.h>
+
+#include <vector>
+#include <string>
+
+namespace tvm {
+
+// forward declaration
+class DataType;
+
+using runtime::Object;
+using runtime::ObjectPtr;
+using runtime::ObjectRef;
+
+/*!
+ * \brief Visitor class for to get the attributesof a AST/IR node.
+ *  The content is going to be called for each field.
+ *
+ *  Each objects that wants reflection will need to implement
+ *  a VisitAttrs function and call visitor->Visit on each of its field.
+ */
+class TVM_DLL AttrVisitor {
+ public:
+//! \cond Doxygen_Suppress
+  virtual ~AttrVisitor() = default;
+  virtual void Visit(const char* key, double* value) = 0;
+  virtual void Visit(const char* key, int64_t* value) = 0;
+  virtual void Visit(const char* key, uint64_t* value) = 0;
+  virtual void Visit(const char* key, int* value) = 0;
+  virtual void Visit(const char* key, bool* value) = 0;
+  virtual void Visit(const char* key, std::string* value) = 0;
+  virtual void Visit(const char* key, void** value) = 0;
+  virtual void Visit(const char* key, DataType* value) = 0;
+  virtual void Visit(const char* key, runtime::NDArray* value) = 0;
+  virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
+  template<typename ENum,
+           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+  void Visit(const char* key, ENum* ptr) {
+    static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
+                  "declare enum to be enum int to use visitor");
+    this->Visit(key, reinterpret_cast<int*>(ptr));
+  }
+//! \endcond
+};
+
+/*!
+ * \brief Virtual function table to support IR/AST node reflection.
+ *
+ * Functions are stored  in columar manner.
+ * Each column is a vector indexed by Object's type_index.
+ */
+class ReflectionVTable {
+ public:
+  /*!
+   * \brief Visitor function.
+   * \note We use function pointer, instead of std::function
+   *       to reduce the dispatch overhead as field visit
+   *       does not need as much customization.
+   */
+  typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
+  /*!
+   * \brief creator function.
+   * \param global_key Key that identifies a global single object.
+   *        If this is not empty then FGlobalKey must be defined for the object.
+   * \return The created function.
+   */
+  using FCreate = std::function<ObjectPtr<Object>(const std::string& global_key)>;
+  /*!
+   * \brief Global key function, only needed by global objects.
+   * \param node The node pointer.
+   * \return node The global key to the node.
+   */
+  using FGlobalKey = std::function<std::string(const Object* self)>;
+  /*!
+   * \brief Dispatch the VisitAttrs function.
+   * \param self The pointer to the object.
+   * \param visitor The attribute visitor.
+   */
+  inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
+  /*!
+   * \brief Get global key of the object, if any.
+   * \param self The pointer to the object.
+   * \return the global key if object has one, otherwise return empty string.
+   */
+  inline std::string GetGlobalKey(Object* self) const;
+  /*!
+   * \brief Create an initial object using default constructor
+   *        by type_key and global key.
+   *
+   * \param type_key The type key of the object.
+   * \param global_key A global key that can be used to uniquely identify the object if any.
+   */
+  TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
+                                             const std::string& global_key = "") const;
+  /*!
+   * \brief Get an field object by the attr name.
+   * \param self The pointer to the object.
+   * \param attr_name The name of the field.
+   * \return The corresponding attribute value.
+   * \note This function will throw an exception if the object does not contain the field.
+   */
+  TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const;
+
+  /*!
+   * \brief List all the fields in the object.
+   * \return All the fields.
+   */
+  TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
+
+  /*! \return The global singleton. */
+  TVM_DLL static ReflectionVTable* Global();
+
+  class Registry;
+  template<typename T>
+  inline Registry Register();
+
+ private:
+  /*! \brief Attribute visitor. */
+  std::vector<FVisitAttrs> fvisit_attrs_;
+  /*! \brief Creation function. */
+  std::vector<FCreate> fcreate_;
+  /*! \brief Global key function. */
+  std::vector<FGlobalKey> fglobal_key_;
+};
+
+/*! \brief Registry of a reflection table. */
+class ReflectionVTable::Registry {
+ public:
+  Registry(ReflectionVTable* parent, uint32_t type_index)
+      : parent_(parent), type_index_(type_index) { }
+  /*!
+   * \brief Set fcreate function.
+   * \param f The creator function.
+   * \return rference to self.
+   */
+  Registry& set_creator(FCreate f) {  // NOLINT(*)
+    CHECK_LT(type_index_, parent_->fcreate_.size());
+    parent_->fcreate_[type_index_] = f;
+    return *this;
+  }
+  /*!
+   * \brief Set global_key function.
+   * \param f The creator function.
+   * \return rference to self.
+   */
+  Registry& set_global_key(FGlobalKey f) {  // NOLINT(*)
+    CHECK_LT(type_index_, parent_->fglobal_key_.size());
+    parent_->fglobal_key_[type_index_] = f;
+    return *this;
+  }
+
+ private:
+  ReflectionVTable* parent_;
+  uint32_t type_index_;
+};
+
+/*!
+ * \brief Register a node type to object registry and reflection registry.
+ * \param TypeName The name of the type.
+ * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well.
+ */
+#define TVM_REGISTER_NODE_TYPE(TypeName)                                \
+  TVM_REGISTER_OBJECT_TYPE(TypeName);                                   \
+  static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry &      \
+  __make_Node ## _ ## TypeName ## __ =                                  \
+      ::tvm::ReflectionVTable::Global()->Register<TypeName>()           \
+      .set_creator([](const std::string&) {                             \
+          return ::tvm::runtime::make_object<TypeName>();               \
+        })
+
+// Implementation details
+template<typename T>
+inline ReflectionVTable::Registry
+ReflectionVTable::Register() {
+  uint32_t tindex = T::RuntimeTypeIndex();
+  if (tindex >= fvisit_attrs_.size()) {
+    fvisit_attrs_.resize(tindex + 1, nullptr);
+    fcreate_.resize(tindex + 1, nullptr);
+    fglobal_key_.resize(tindex + 1, nullptr);
+  }
+  // functor that implemnts the redirection.
+  struct Functor {
+    static void VisitAttrs(Object* self, AttrVisitor* v) {
+      static_cast<T*>(self)->VisitAttrs(v);
+     }
+  };
+
+  fvisit_attrs_[tindex] = Functor::VisitAttrs;
+  return Registry(this, tindex);
+}
+
+inline void ReflectionVTable::
+VisitAttrs(Object* self, AttrVisitor* visitor) const {
+  uint32_t tindex = self->type_index();
+  if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
+    LOG(FATAL) << "TypeError: " << self->GetTypeKey()
+               << " is not registered via TVM_REGISTER_NODE_TYPE";
+  }
+  fvisit_attrs_[tindex](self, visitor);
+}
+
+inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
+  uint32_t tindex = self->type_index();
+  if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
+    return fglobal_key_[tindex](self);
+  } else {
+    return std::string();
+  }
+}
+
+}  // namespace tvm
+#endif  // TVM_NODE_REFLECTION_H_
diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h
new file mode 100644 (file)
index 0000000..ac67594
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Utility functions for serialization.
+ * \file tvm/node/serialization.h
+ */
+#ifndef TVM_NODE_SERIALIZATION_H_
+#define TVM_NODE_SERIALIZATION_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <string>
+
+namespace tvm {
+/*!
+ * \brief save the node as well as all the node it depends on as json.
+ *  This can be used to serialize any TVM object
+ *
+ * \return the string representation of the node.
+ */
+TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node);
+
+/*!
+ * \brief Internal implementation of LoadJSON
+ * Load tvm Node object from json and return a shared_ptr of Node.
+ * \param json_str The json string to load from.
+ *
+ * \return The shared_ptr of the Node.
+ */
+TVM_DLL runtime::ObjectRef LoadJSON(std::string json_str);
+
+}  // namespace tvm
+#endif  // TVM_NODE_SERIALIZATION_H_
index b942464..f53c1ce 100644 (file)
@@ -188,7 +188,7 @@ class PlaceholderOpNode : public OperationNode {
       const std::unordered_map<IterVar, Range>& dom_map,
       bool debug_keep_trivial_loop) const final;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -259,7 +259,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
       bool debug_keep_trivial_loop) const final;
   size_t num_schedulable_dims() const final;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -312,7 +312,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
       bool debug_keep_trivial_loop) const final;
   size_t num_schedulable_dims() const final;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("axis", &axis);
@@ -394,7 +394,7 @@ class ScanOpNode : public OperationNode {
       const std::unordered_map<IterVar, Range>& dom_map,
       bool debug_keep_trivial_loop) const final;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -461,7 +461,7 @@ class ExternOpNode : public OperationNode {
       const std::unordered_map<IterVar, Range>& dom_map,
       bool debug_keep_trivial_loop) const final;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
@@ -529,7 +529,7 @@ class HybridOpNode : public OperationNode {
       const std::unordered_map<IterVar, Range>& dom_map,
       bool debug_keep_trivial_loop) const final;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("tag", &tag);
     v->Visit("attrs", &attrs);
index 48d46fd..71f8f55 100644 (file)
@@ -20,7 +20,7 @@
 /*!
  * \file tvm/packed_func_ext.h
  * \brief Extension package to PackedFunc
- *   This enales pass NodeRef types into/from PackedFunc.
+ *   This enales pass ObjectRef types into/from PackedFunc.
  */
 #ifndef TVM_PACKED_FUNC_EXT_H_
 #define TVM_PACKED_FUNC_EXT_H_
@@ -129,18 +129,18 @@ inline std::string ObjectTypeName() {
 
 // extensions for tvm arg value
 
-template<typename TNodeRef>
-inline TNodeRef TVMArgValue::AsNodeRef() const {
+template<typename TObjectRef>
+inline TObjectRef TVMArgValue::AsObjectRef() const {
   static_assert(
-      std::is_base_of<NodeRef, TNodeRef>::value,
-      "Conversion only works for NodeRef");
-  if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
+      std::is_base_of<ObjectRef, TObjectRef>::value,
+      "Conversion only works for ObjectRef");
+  if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
   TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
   Object* ptr = static_cast<Object*>(value_.v_handle);
-  CHECK(ObjectTypeChecker<TNodeRef>::Check(ptr))
-      << "Expected type " << ObjectTypeName<TNodeRef>()
+  CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
+      << "Expected type " << ObjectTypeName<TObjectRef>()
       << " but get " << ptr->GetTypeKey();
-  return TNodeRef(ObjectPtr<Node>(ptr));
+  return TObjectRef(ObjectPtr<Node>(ptr));
 }
 
 inline TVMArgValue::operator tvm::Expr() const {
@@ -184,28 +184,28 @@ inline TVMArgValue::operator tvm::Integer() const {
   return Integer(ObjectPtr<Node>(ptr));
 }
 
-template<typename TNodeRef, typename>
+template<typename TObjectRef, typename>
 inline bool TVMPODValue_::IsObjectRef() const {
   TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
   Object* ptr = static_cast<Object*>(value_.v_handle);
-  return ObjectTypeChecker<TNodeRef>::Check(ptr);
+  return ObjectTypeChecker<TObjectRef>::Check(ptr);
 }
 
 // extensions for TVMRetValue
-template<typename TNodeRef>
-inline TNodeRef TVMRetValue::AsNodeRef() const {
+template<typename TObjectRef>
+inline TObjectRef TVMRetValue::AsObjectRef() const {
   static_assert(
-      std::is_base_of<NodeRef, TNodeRef>::value,
-      "Conversion only works for NodeRef");
-  if (type_code_ == kNull) return TNodeRef();
+      std::is_base_of<ObjectRef, TObjectRef>::value,
+      "Conversion only works for ObjectRef");
+  if (type_code_ == kNull) return TObjectRef();
   TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
 
   Object* ptr = static_cast<Object*>(value_.v_handle);
 
-  CHECK(ObjectTypeChecker<TNodeRef>::Check(ptr))
-      << "Expected type " << ObjectTypeName<TNodeRef>()
+  CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
+      << "Expected type " << ObjectTypeName<TObjectRef>()
       << " but get " << ptr->GetTypeKey();
-  return TNodeRef(ObjectPtr<Object>(ptr));
+  return TObjectRef(ObjectPtr<Object>(ptr));
 }
 
 // type related stuffs
index e54d88d..a743532 100644 (file)
@@ -66,7 +66,7 @@ class PatternWildcardNode : public PatternNode {
 
   TVM_DLL static PatternWildcard make();
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("span", &span);
   }
 
@@ -88,7 +88,7 @@ class PatternVarNode : public PatternNode {
 
   TVM_DLL static PatternVar make(tvm::relay::Var var);
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("var", &var);
     v->Visit("span", &span);
   }
@@ -122,7 +122,7 @@ class ConstructorNode : public ExprNode {
                                   tvm::Array<Type> inputs,
                                   GlobalTypeVar belong_to);
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name_hint", &name_hint);
     v->Visit("inputs", &inputs);
     v->Visit("belong_to", &belong_to);
@@ -151,7 +151,7 @@ class PatternConstructorNode : public PatternNode {
 
   TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("constructor", &constructor);
     v->Visit("patterns", &patterns);
     v->Visit("span", &span);
@@ -175,7 +175,7 @@ class PatternTupleNode : public PatternNode {
 
   TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("patterns", &patterns);
     v->Visit("span", &span);
   }
@@ -213,7 +213,7 @@ class TypeDataNode : public TypeNode {
   /*! \brief The constructors. */
   tvm::Array<Constructor> constructors;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("header", &header);
     v->Visit("type_vars", &type_vars);
     v->Visit("constructors", &constructors);
@@ -240,7 +240,7 @@ class ClauseNode : public Node {
   /*! \brief The resulting value. */
   Expr rhs;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("lhs", &lhs);
     v->Visit("rhs", &rhs);
   }
@@ -269,7 +269,7 @@ class MatchNode : public ExprNode {
    */
   bool complete;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("data", &data);
     v->Visit("clauses", &clauses);
     v->Visit("complete", &complete);
index 15330b0..5a2326e 100644 (file)
@@ -107,7 +107,7 @@ class SourceNameNode : public Node {
   /*! \brief The source name. */
   std::string name;
   // override attr visitor
-  void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
+  void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
 
   static constexpr const char* _type_key = "relay.SourceName";
   TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
@@ -160,7 +160,7 @@ class SpanNode : public Node {
   /*! \brief column offset */
   int col_offset;
   // override attr visitor
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("source", &source);
     v->Visit("lineno", &lineno);
     v->Visit("col_offset", &col_offset);
@@ -204,7 +204,7 @@ class IdNode : public Node {
    */
   std::string name_hint;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name_hint", &name_hint);
   }
 
index 281b992..6df4273 100644 (file)
@@ -95,7 +95,7 @@ class ConstantNode : public ExprNode {
     return data->ndim == 0;
   }
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("data", &data);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -117,7 +117,7 @@ class TupleNode : public ExprNode {
   /*! \brief the fields of the tuple */
   tvm::Array<relay::Expr> fields;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("fields", &fields);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -165,7 +165,7 @@ class VarNode : public ExprNode {
     return vid->name_hint;
   }
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("vid", &vid);
     v->Visit("type_annotation", &type_annotation);
     v->Visit("span", &span);
@@ -197,7 +197,7 @@ class GlobalVarNode : public ExprNode {
   /*! \brief The name of the variable, this only acts as a hint. */
   std::string name_hint;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name_hint", &name_hint);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -243,7 +243,7 @@ class FunctionNode : public ExprNode {
    */
   tvm::Attrs attrs;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("params", &params);
     v->Visit("body", &body);
     v->Visit("ret_type", &ret_type);
@@ -327,7 +327,7 @@ class CallNode : public ExprNode {
    */
   tvm::Array<Type> type_args;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("op", &op);
     v->Visit("args", &args);
     v->Visit("attrs", &attrs);
@@ -369,7 +369,7 @@ class LetNode : public ExprNode {
   /*! \brief The body of the let binding */
   Expr body;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("var", &var);
     v->Visit("value", &value);
     v->Visit("body", &body);
@@ -407,7 +407,7 @@ class IfNode : public ExprNode {
   /*! \brief The expression evaluated when condition is false */
   Expr false_branch;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("cond", &cond);
     v->Visit("true_branch", &true_branch);
     v->Visit("false_branch", &false_branch);
@@ -432,7 +432,7 @@ class TupleGetItemNode : public ExprNode {
   /*! \brief which value to get */
   int index;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("tuple_value", &tuple);
     v->Visit("index", &index);
     v->Visit("span", &span);
@@ -454,7 +454,7 @@ class RefCreateNode : public ExprNode {
   /*! \brief The initial value of the Reference. */
   Expr value;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("value", &value);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -475,7 +475,7 @@ class RefReadNode : public ExprNode {
   /*! \brief The Reference Expression. */
   Expr ref;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("ref", &ref);
     v->Visit("span", &span);
     v->Visit("_checked_type_", &checked_type_);
@@ -498,7 +498,7 @@ class RefWriteNode : public ExprNode {
   /*! \brief The value to write into. */
   Expr value;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("ref", &ref);
     v->Visit("value", &value);
     v->Visit("span", &span);
index f0b1e7c..3bdc125 100644 (file)
@@ -106,7 +106,7 @@ class ClosureNode : public ValueNode {
 
   ClosureNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("env", &env);
     v->Visit("func", &func);
   }
@@ -154,7 +154,7 @@ struct TupleValueNode : ValueNode {
 
   TupleValueNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
 
   TVM_DLL static TupleValue make(tvm::Array<Value> value);
 
@@ -173,7 +173,7 @@ struct TensorValueNode : ValueNode {
 
   TensorValueNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); }
 
   /*! \brief Build a value from an NDArray. */
   TVM_DLL static TensorValue make(runtime::NDArray data);
@@ -192,7 +192,7 @@ struct RefValueNode : ValueNode {
 
   RefValueNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("value", &value);
   }
 
@@ -215,7 +215,7 @@ struct ConstructorValueNode : ValueNode {
   /*! \brief Optional field tracking ADT constructor. */
   Constructor constructor;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("tag", &tag);
     v->Visit("fields", &fields);
     v->Visit("constructor", &constructor);
index 10d7234..160ae5d 100644 (file)
@@ -68,7 +68,7 @@ class ModuleNode : public RelayNode {
 
   ModuleNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("functions", &functions);
     v->Visit("type_definitions", &type_definitions);
     v->Visit("global_var_map_", &global_var_map_);
index 572c194..7d2a1f6 100644 (file)
@@ -24,6 +24,8 @@
 #ifndef TVM_RELAY_OP_H_
 #define TVM_RELAY_OP_H_
 
+#include <dmlc/registry.h>
+
 #include <functional>
 #include <limits>
 #include <string>
@@ -82,7 +84,7 @@ class OpNode : public relay::ExprNode {
    */
   int32_t support_level = 10;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("op_type", &op_type);
     v->Visit("description", &description);
index 08ea307..82144d7 100644 (file)
@@ -101,7 +101,7 @@ class PassContextNode : public RelayNode {
 
   PassContextNode() = default;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("opt_level", &opt_level);
     v->Visit("fallback_device", &fallback_device);
     v->Visit("required_pass", &required_pass);
@@ -196,7 +196,7 @@ class PassInfoNode : public RelayNode {
 
   PassInfoNode() = default;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("opt_level", &opt_level);
     v->Visit("name", &name);
     v->Visit("required", &required);
@@ -221,6 +221,7 @@ class Pass;
  */
 class PassNode : public RelayNode {
  public:
+  virtual ~PassNode() {}
   /*!
    * \brief Get the pass information/meta data. */
   virtual PassInfo Info() const = 0;
@@ -247,7 +248,7 @@ class PassNode : public RelayNode {
   virtual Module operator()(const Module& mod,
                             const PassContext& pass_ctx) const = 0;
 
-  void VisitAttrs(tvm::AttrVisitor* v) override {}
+  void VisitAttrs(tvm::AttrVisitor* v) {}
 
   static constexpr const char* _type_key = "relay.Pass";
   TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
index a5cc3c8..e0c056c 100644 (file)
@@ -96,7 +96,7 @@ class TensorTypeNode : public BaseTensorTypeNode {
   /*! \brief The content data type */
   DataType dtype;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("shape", &shape);
     v->Visit("dtype", &dtype);
     v->Visit("span", &span);
@@ -159,7 +159,7 @@ class TypeVarNode : public TypeNode {
   /*! \brief The kind of type parameter */
   Kind kind;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("var", &var);
     v->Visit("kind", &kind);
     v->Visit("span", &span);
@@ -188,7 +188,7 @@ class GlobalTypeVarNode : public TypeNode {
   /*! \brief The kind of type parameter */
   Kind kind;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("var", &var);
     v->Visit("kind", &kind);
     v->Visit("span", &span);
@@ -216,7 +216,7 @@ class TypeCallNode : public TypeNode {
   /*! \brief The arguments. */
   tvm::Array<Type> args;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("args", &args);
     v->Visit("span", &span);
@@ -245,7 +245,7 @@ class IncompleteTypeNode : public TypeNode {
  public:
   Kind kind;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("kind", &kind);
     v->Visit("span", &span);
   }
@@ -297,7 +297,7 @@ class FuncTypeNode : public TypeNode {
    */
   tvm::Array<TypeConstraint> type_constraints;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("arg_types", &arg_types);
     v->Visit("ret_type", &ret_type);
     v->Visit("type_params", &type_params);
@@ -330,7 +330,7 @@ class TupleTypeNode : public TypeNode {
 
   TupleTypeNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("fields", &fields);
     v->Visit("span", &span);
   }
@@ -357,7 +357,7 @@ class RefTypeNode : public TypeNode {
 
   RefTypeNode() {}
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("value", &value);
     v->Visit("span", &span);
   }
@@ -417,7 +417,7 @@ class TypeReporterNode : public Node {
   TVM_DLL virtual Module GetModule() = 0;
 
   // solver is not serializable.
-  void VisitAttrs(tvm::AttrVisitor* v) final {}
+  void VisitAttrs(tvm::AttrVisitor* v) {}
 
   static constexpr const char* _type_key = "relay.TypeReporter";
   TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node);
@@ -488,7 +488,7 @@ class TypeRelationNode : public TypeConstraintNode {
   /*! \brief Attributes to the relation function */
   Attrs attrs;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("func", &func);
     v->Visit("args", &args);
     v->Visit("num_inputs", &num_inputs);
index 68029c1..bb362dc 100644 (file)
@@ -230,6 +230,7 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) {  // NOLINT(*)
   os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")";
   return os;
 }
+
 #endif
 }  // namespace runtime
 }  // namespace tvm
index 01c08d3..d28552e 100644 (file)
@@ -82,6 +82,8 @@ class SimpleObjAllocator :
   template<typename T>
   class Handler {
    public:
+    using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
+
     template<typename... Args>
     static T* New(SimpleObjAllocator*, Args&&... args) {
       // NOTE: the first argument is not needed for SimpleObjAllocator
@@ -91,7 +93,15 @@ class SimpleObjAllocator :
       // In the case of an object pool, an allocator needs to create
       // a special chunk memory that hides reference to the allocator
       // and call allocator's release function in the deleter.
-      return new T(std::forward<Args>(args)...);
+
+      // NOTE2: Use inplace new to allocate
+      // This is used to get rid of warning when deleting a virtual
+      // class with non-virtual destructor.
+      // We are fine here as we captured the right deleter during construction.
+      // This is also the right way to get storage type for an object pool.
+      StorageType* data = new StorageType();
+      new (data) T(std::forward<Args>(args)...);
+      return reinterpret_cast<T*>(data);
     }
 
     static Object::FDeleter Deleter() {
@@ -99,8 +109,17 @@ class SimpleObjAllocator :
     }
 
    private:
-    static void Deleter_(Object* ptr) {
-      delete static_cast<T*>(ptr);
+    static void Deleter_(Object* objptr) {
+      // NOTE: this is important to cast back to T*
+      // because objptr and tptr may not be the same
+      // depending on how sub-class allocates the space.
+      T* tptr = static_cast<T*>(objptr);
+      // It is important to do tptr->T::~T(),
+      // so that we explicitly call the specific destructor
+      // instead of tptr->~T(), which could mean the intention
+      // call a virtual destructor(which may not be available and is not required).
+      tptr->T::~T();
+      delete reinterpret_cast<StorageType*>(tptr);
     }
   };
 };
index 143f3bb..cc4a295 100644 (file)
@@ -23,6 +23,7 @@
 #ifndef TVM_RUNTIME_OBJECT_H_
 #define TVM_RUNTIME_OBJECT_H_
 
+#include <dmlc/logging.h>
 #include <type_traits>
 #include <string>
 #include <utility>
@@ -189,7 +190,7 @@ class Object {
    * \param key The type key.
    * \return the result.
    */
-  TVM_DLL static uint32_t TypeKey2Index(const char* key);
+  TVM_DLL static uint32_t TypeKey2Index(const std::string& key);
 
 #if TVM_OBJECT_ATOMIC_REF_COUNTER
   using RefCounterType = std::atomic<int32_t>;
@@ -197,18 +198,24 @@ class Object {
   using RefCounterType = int32_t;
 #endif
 
-  // Object type properties
   static constexpr const char* _type_key = "Object";
-  static constexpr bool _type_final = false;
-  static constexpr uint32_t _type_child_slots = 0;
-  static constexpr bool _type_child_slots_can_overflow = true;
+
   static uint32_t _GetOrAllocRuntimeTypeIndex() {
-    return 0;
+    return TypeIndex::kRoot;
   }
   static uint32_t RuntimeTypeIndex() {
-    return 0;
+    return TypeIndex::kRoot;
   }
 
+  // Default object type properties for sub-classes
+  static constexpr bool _type_final = false;
+  static constexpr uint32_t _type_child_slots = 0;
+  static constexpr bool _type_child_slots_can_overflow = true;
+  // NOTE: the following field is not type index of Object
+  // but was intended to be used by sub-classes as default value.
+  // The type index of Object is TypeIndex::kRoot
+  static constexpr uint32_t _type_index = TypeIndex::kDynamic;
+
   // Default constructor and copy constructor
   Object() {}
   // Override the copy and assign constructors to do nothing.
@@ -262,13 +269,12 @@ class Object {
    * \return The allocated type index.
    */
   TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
-      const char* key,
+      const std::string& key,
       uint32_t static_tindex,
       uint32_t parent_tindex,
       uint32_t type_child_slots,
       bool type_child_slots_can_overflow);
 
- private:
   // reference counter related operations
   /*! \brief developer function, increases reference counter. */
   inline void IncRef();
@@ -621,8 +627,8 @@ struct ObjectEqual {
  */
 #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)              \
   static const uint32_t RuntimeTypeIndex()  {                           \
-    if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) {           \
-      return _type_index;                                               \
+    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
+      return TypeName::_type_index;                                     \
     }                                                                   \
     return _GetOrAllocRuntimeTypeIndex();                               \
   }                                                                     \
index 649a505..a42946a 100644 (file)
@@ -51,8 +51,6 @@ namespace tvm {
 class Integer;
 class DataType;
 class Expr;
-class Node;
-class NodeRef;
 
 namespace runtime {
 
@@ -516,9 +514,9 @@ class TVMPODValue_ {
     CHECK_LT(type_code_, kExtEnd);
     return static_cast<TExtension*>(value_.v_handle)[0];
   }
-  template<typename TNodeRef,
+  template<typename TObjectRef,
            typename = typename std::enable_if<
-             std::is_class<TNodeRef>::value>::type>
+             std::is_class<TObjectRef>::value>::type>
   inline bool IsObjectRef() const;
   int type_code() const {
     return type_code_;
@@ -620,8 +618,8 @@ class TVMArgValue : public TVMPODValue_ {
     return value_;
   }
   // Deferred extension handler.
-  template<typename TNodeRef>
-  inline TNodeRef AsNodeRef() const;
+  template<typename TObjectRef>
+  inline TObjectRef AsObjectRef() const;
   template<typename T,
            typename = typename std::enable_if<
            std::is_class<T>::value>::type>
@@ -834,13 +832,13 @@ class TVMRetValue : public TVMPODValue_ {
           type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
     return value_;
   }
-  // NodeRef related extenstions: in tvm/packed_func_ext.h
+  // ObjectRef related extenstions: in tvm/packed_func_ext.h
   template<typename T,
            typename = typename std::enable_if<
              std::is_class<T>::value>::type>
   inline operator T() const;
-  template<typename TNodeRef>
-  inline TNodeRef AsNodeRef() const;
+  template<typename TObjectRef>
+  inline TObjectRef AsObjectRef() const;
   // type related
   inline operator tvm::DataType() const;
   inline TVMRetValue& operator=(const tvm::DataType& other);
@@ -1306,7 +1304,7 @@ template<typename T, typename TSrc, bool is_ext, bool is_nd>
 struct TVMValueCast {
   static T Apply(const TSrc* self) {
     static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
-    return self->template AsNodeRef<T>();
+    return self->template AsObjectRef<T>();
   }
 };
 
index 40e1a52..d668984 100644 (file)
@@ -91,7 +91,7 @@ class Registry {
    *        Note that this will ignore default arg values and always require all arguments to be provided.
    *
    * \code
-   * 
+   *
    * int multiply(int x, int y) {
    *   return x * y;
    * }
@@ -115,7 +115,7 @@ class Registry {
    *        Note that this will ignore default arg values and always require all arguments to be provided.
    *
    * \code
-   * 
+   *
    * // node subclass:
    * struct Example {
    *    int doThing(int x);
@@ -143,7 +143,7 @@ class Registry {
    *        Note that this will ignore default arg values and always require all arguments to be provided.
    *
    * \code
-   * 
+   *
    * // node subclass:
    * struct Example {
    *    int doThing(int x);
@@ -168,22 +168,22 @@ class Registry {
 
   /*!
    * \brief set the body of the function to be the passed method pointer.
-   *        Used when calling a method on a Node subclass through a NodeRef subclass.
+   *        Used when calling a method on a Node subclass through a ObjectRef subclass.
    *        Note that this will ignore default arg values and always require all arguments to be provided.
    *
    * \code
-   * 
+   *
    * // node subclass:
    * struct ExampleNode: BaseNode {
    *    int doThing(int x);
    * }
-   * 
+   *
    * // noderef subclass
-   * struct Example; 
+   * struct Example;
    *
    * TVM_REGISTER_API("Example_doThing")
    * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
-   * 
+   *
    * // note that just doing:
    * // .set_body_method(&ExampleNode::doThing);
    * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
@@ -191,15 +191,15 @@ class Registry {
    * \endcode
    *
    * \param f the method pointer to forward to.
-   * \tparam TNodeRef the node reference type to call the method on
+   * \tparam TObjectRef the node reference type to call the method on
    * \tparam TNode the node type containing the method (inferred).
    * \tparam R the return type of the function (inferred).
    * \tparam Args the argument types of the function (inferred).
    */
-  template<typename TNodeRef, typename TNode, typename R, typename ...Args,
-    typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
+  template<typename TObjectRef, typename TNode, typename R, typename ...Args,
+    typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
   Registry& set_body_method(R (TNode::*f)(Args...)) {
-    return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+    return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
       TNode* target = ref.operator->();
       // call method pointer
       return (target->*f)(params...);
@@ -208,22 +208,22 @@ class Registry {
 
   /*!
    * \brief set the body of the function to be the passed method pointer.
-   *        Used when calling a method on a Node subclass through a NodeRef subclass.
+   *        Used when calling a method on a Node subclass through a ObjectRef subclass.
    *        Note that this will ignore default arg values and always require all arguments to be provided.
    *
    * \code
-   * 
+   *
    * // node subclass:
    * struct ExampleNode: BaseNode {
    *    int doThing(int x);
    * }
-   * 
+   *
    * // noderef subclass
-   * struct Example; 
+   * struct Example;
    *
    * TVM_REGISTER_API("Example_doThing")
    * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
-   * 
+   *
    * // note that just doing:
    * // .set_body_method(&ExampleNode::doThing);
    * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
@@ -231,15 +231,15 @@ class Registry {
    * \endcode
    *
    * \param f the method pointer to forward to.
-   * \tparam TNodeRef the node reference type to call the method on
+   * \tparam TObjectRef the node reference type to call the method on
    * \tparam TNode the node type containing the method (inferred).
    * \tparam R the return type of the function (inferred).
    * \tparam Args the argument types of the function (inferred).
    */
-  template<typename TNodeRef, typename TNode, typename R, typename ...Args,
-    typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
+  template<typename TObjectRef, typename TNode, typename R, typename ...Args,
+    typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
   Registry& set_body_method(R (TNode::*f)(Args...) const) {
-    return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
+    return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
       const TNode* target = ref.operator->();
       // call method pointer
       return (target->*f)(params...);
index 3626566..3f4ee38 100644 (file)
@@ -495,7 +495,7 @@ class StageNode : public Node {
   /*! \brief Number of direct child stages, only used for group stage.*/
   int num_child_stages{0};
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("op", &op);
     v->Visit("origin_op", &origin_op);
     v->Visit("all_iter_vars", &all_iter_vars);
@@ -540,7 +540,7 @@ class ScheduleNode : public Node {
    */
   std::unordered_map<const Node*, Stage> op2stage_cache_;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("outputs", &outputs);
     v->Visit("stages", &stages);
     v->Visit("groups", &groups);
@@ -617,7 +617,7 @@ class IterVarAttrNode : public Node {
    */
   Array<Expr> pragma_values;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("iter_type", &iter_type);
     v->Visit("bind_thread", &bind_thread);
     v->Visit("prefetch_data", &prefetch_data);
@@ -657,7 +657,7 @@ class SplitNode : public IterVarRelationNode {
   /*! \brief Number of parts, only factor or nparts can be given */
   Expr nparts;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("parent", &parent);
     v->Visit("outer", &outer);
     v->Visit("inner", &inner);
@@ -687,7 +687,7 @@ class FuseNode : public IterVarRelationNode {
   /*! \brief The target domain */
   IterVar fused;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("outer", &outer);
     v->Visit("inner", &inner);
     v->Visit("fused", &fused);
@@ -712,7 +712,7 @@ class RebaseNode : public IterVarRelationNode {
   /*! \brief The inner domain */
   IterVar rebased;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("parent", &parent);
     v->Visit("rebased", &rebased);
   }
@@ -732,7 +732,7 @@ class SingletonNode : public IterVarRelationNode {
   /*! \brief The singleton iterator */
   IterVar iter;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("iter", &iter);
   }
 
index 1e3a768..86cb0e2 100644 (file)
@@ -47,7 +47,7 @@ struct MemoryInfoNode : public Node {
    */
   Expr head_address;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("unit_bits", &unit_bits);
     v->Visit("max_num_bits", &max_num_bits);
     v->Visit("max_simd_bits", &max_simd_bits);
index 6471c9c..599d6ff 100644 (file)
@@ -171,7 +171,7 @@ class TensorNode : public Node {
   /*! \brief constructor */
   TensorNode() {}
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("shape", &shape);
     v->Visit("dtype", &dtype);
     v->Visit("op", &op);
index 152a27f..0d4795a 100644 (file)
@@ -87,7 +87,7 @@ class TensorIntrinNode : public Node {
   /*! \brief constructor */
   TensorIntrinNode() {}
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("op", &op);
     v->Visit("inputs", &inputs);
@@ -152,7 +152,7 @@ class TensorIntrinCallNode : public Node {
   /*! \brief scalar expression inputs */
   Array<Expr> scalar_inputs;
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("intrin", &intrin);
     v->Visit("tensors", &tensors);
     v->Visit("regions", &regions);
index e8d33cb..ec9a13b 100644 (file)
@@ -55,7 +55,7 @@ struct GraphFuncNode : public tvm::Node {
   /*! \brief The lowered functions */
   tvm::Array<tvm::LoweredFunc> funcs;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("target", &target);
     v->Visit("func_name", &func_name);
     v->Visit("inputs", &inputs);
@@ -78,7 +78,7 @@ struct GraphCacheEntryNode : public tvm::Node {
   /*! \brief Index of the master node for calling schedule*/
   int master_idx;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("graph_func", &graph_func);
     v->Visit("use_count", &use_count);
     v->Visit("master_idx", &master_idx);
index aed3462..6966a15 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -48,7 +48,7 @@ struct GraphKeyNode : public tvm::Node {
   // The graph hash key is ensured always not to be 0
   mutable size_t cache_hash_key_{0};
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("inputs", &inputs);
     v->Visit("target", &target);
   }
index 3bfebe3..d8ff3bf 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  */
 
 /*!
- * Copyright (c) 2017 by Contributors
  * \file graph_runtime.cc
  * \brief Interface code with TVM graph runtime.
 */
 #include <dmlc/memory_io.h>
+#include <tvm/runtime/registry.h>
+
 #include <utility>
 #include "graph_runtime.h"
 
index 7b324ba..770c98e 100644 (file)
@@ -61,13 +61,13 @@ struct NDArrayWrapperNode : public ::tvm::Node {
   std::string name;
   tvm::runtime::NDArray array;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("array", &array);
   }
 
   static constexpr const char* _type_key = "NDArrayWrapper";
-  TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node);
+  TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, tvm::Node);
 };
 
 TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode);
index 0c6f30a..599f41d 100644 (file)
@@ -22,6 +22,8 @@ There can be internal header files within each module that sit in src.
 
 ## Modules
 - common: Internal common utilities.
+- runtime: Minimum runtime related codes.
+- node: base infra for IR/AST nodes that is dialect independent.
 - api: API function registration.
 - lang: The definition of DSL related data structure.
 - arithmetic: Arithmetic expression and set simplification.
@@ -29,7 +31,6 @@ There can be internal header files within each module that sit in src.
 - schedule: The operations on the schedule graph before converting to IR.
 - pass: The optimization pass on the IR structure.
 - codegen: The code generator.
-- runtime: Minimum runtime related codes.
 - autotvm: The auto-tuning module.
 - relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks.
 - contrib: Contrib extension libraries.
index c25c35f..42367ef 100644 (file)
@@ -26,6 +26,7 @@
 #include <tvm/expr.h>
 #include <tvm/tensor.h>
 #include <tvm/api_registry.h>
+#include <tvm/node/serialization.h>
 
 namespace tvm {
 TVM_REGISTER_API("_format_str")
@@ -43,10 +44,10 @@ TVM_REGISTER_API("_raw_ptr")
   });
 
 TVM_REGISTER_API("_save_json")
-.set_body_typed<std::string(NodeRef)>(SaveJSON);
+.set_body_typed<std::string(ObjectRef)>(SaveJSON);
 
 TVM_REGISTER_API("_load_json")
-.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
+.set_body_typed<ObjectRef(std::string)>(LoadJSON);
 
 TVM_REGISTER_API("_TVMSetStream")
 .set_body_typed(TVMSetStream);
diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc
deleted file mode 100644 (file)
index 64805c9..0000000
+++ /dev/null
@@ -1,190 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Implementation of DSL API
- * \file dsl_api.cc
- */
-#include <dmlc/logging.h>
-#include <tvm/api_registry.h>
-#include <tvm/attrs.h>
-#include <tvm/expr.h>
-#include <vector>
-#include <string>
-
-namespace tvm {
-namespace runtime {
-
-struct APIAttrGetter : public AttrVisitor {
-  std::string skey;
-  TVMRetValue* ret;
-  bool found_ref_object{false};
-
-  void Visit(const char* key, double* value) final {
-    if (skey == key) *ret = value[0];
-  }
-  void Visit(const char* key, int64_t* value) final {
-    if (skey == key) *ret = value[0];
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
-        << "cannot return too big constant";
-    if (skey == key) *ret = static_cast<int64_t>(value[0]);
-  }
-  void Visit(const char* key, int* value) final {
-    if (skey == key) *ret = static_cast<int64_t>(value[0]);
-  }
-  void Visit(const char* key, bool* value) final {
-    if (skey == key) *ret = static_cast<int64_t>(value[0]);
-  }
-  void Visit(const char* key, void** value) final {
-    if (skey == key) *ret = static_cast<void*>(value[0]);
-  }
-  void Visit(const char* key, Type* value) final {
-    if (skey == key) *ret = value[0];
-  }
-  void Visit(const char* key, std::string* value) final {
-    if (skey == key) *ret = value[0];
-  }
-  void Visit(const char* key, NodeRef* value) final {
-    if (skey == key) {
-      *ret = value[0];
-      found_ref_object = true;
-    }
-  }
-  void Visit(const char* key, runtime::NDArray* value) final {
-    if (skey == key) {
-      *ret = value[0];
-      found_ref_object = true;
-    }
-  }
-  void Visit(const char* key, runtime::ObjectRef* value) final {
-    if (skey == key) {
-      *ret = value[0];
-      found_ref_object = true;
-    }
-  }
-};
-
-struct APIAttrDir : public AttrVisitor {
-  std::vector<std::string>* names;
-
-  void Visit(const char* key, double* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, bool* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, void** value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, Type* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, std::string* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, NodeRef* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, runtime::NDArray* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, runtime::ObjectRef* value) final {
-    names->push_back(key);
-  }
-};
-
-struct NodeAPI {
-  static void GetAttr(TVMArgs args, TVMRetValue* ret) {
-    NodeRef ref = args[0];
-    Node* tnode = const_cast<Node*>(ref.get());
-    APIAttrGetter getter;
-    getter.skey = args[1].operator std::string();
-    getter.ret = ret;
-
-    bool success;
-    if (getter.skey == "type_key") {
-      *ret = tnode->GetTypeKey();
-      success = true;
-    } else if (!tnode->IsInstance<DictAttrsNode>()) {
-      tnode->VisitAttrs(&getter);
-      success = getter.found_ref_object || ret->type_code() != kNull;
-    } else {
-      // specially handle dict attr
-      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
-      auto it = dnode->dict.find(getter.skey);
-      if (it != dnode->dict.end()) {
-        success = true;
-        *ret = (*it).second;
-      } else {
-        success = false;
-      }
-    }
-    if (!success) {
-      LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey()
-                 << " object has no attributed " << getter.skey;
-    }
-  }
-
-  static void ListAttrNames(TVMArgs args, TVMRetValue* ret) {
-    NodeRef ref = args[0];
-    Node* tnode = const_cast<Node*>(ref.get());
-    auto names = std::make_shared<std::vector<std::string> >();
-    APIAttrDir dir;
-    dir.names = names.get();
-
-    if (!tnode->IsInstance<DictAttrsNode>()) {
-      tnode->VisitAttrs(&dir);
-    } else {
-      // specially handle dict attr
-      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode);
-      for (const auto& kv : dnode->dict) {
-        names->push_back(kv.first);
-      }
-    }
-
-    *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
-        int64_t i = args[0];
-        if (i == -1) {
-          *rv = static_cast<int64_t>(names->size());
-        } else {
-          *rv = (*names)[i];
-        }
-      });
-  }
-};
-
-TVM_REGISTER_GLOBAL("_NodeGetAttr")
-.set_body(NodeAPI::GetAttr);
-
-TVM_REGISTER_GLOBAL("_NodeListAttrNames")
-.set_body(NodeAPI::ListAttrNames);
-
-}  // namespace runtime
-}  // namespace tvm
index 6f7b4d7..9c3a706 100644 (file)
@@ -53,17 +53,17 @@ class VariablePathFinder: public IRVisitor {
     if (!found_) path_.pop_back();
   }
 
-  std::vector<const Node*> path_;
+  std::vector<const Object*> path_;
 
  private:
   bool found_{false};
   Expr target_;
-  std::unordered_set<const Node*> visited_;
+  std::unordered_set<const Object*> visited_;
 };
 
 // get the path to the variable,
 // return empty vector to represent failure
-std::vector<const Node*> GetPath(Expr target, Expr expr) {
+std::vector<const Object*> GetPath(Expr target, Expr expr) {
   VariablePathFinder v(target);
   v.Visit(expr);
   return v.path_;
@@ -189,7 +189,7 @@ class BoundDeducer: public IRVisitor {
   const std::unordered_map<const Variable*, IntSet>& hint_map_;
   const std::unordered_map<const Variable*, IntSet>& relax_map_;
   ExprIntSetMap expr_map_;
-  std::vector<const Node*> path_;
+  std::vector<const Object*> path_;
   size_t iter_{0};
   // internal analzyer
   Analyzer analyzer_;
index 02e8079..1b576a6 100644 (file)
@@ -43,6 +43,7 @@ class SplitExpr;
  */
 class CanonicalExprNode : public BaseExprNode {
  public:
+  virtual ~CanonicalExprNode() {}
   /*!
    * \brief Return the normal Expr that is equivalent to self.
    * \note Can mutate the internal data structure.
@@ -51,7 +52,7 @@ class CanonicalExprNode : public BaseExprNode {
   virtual Expr Normalize() const = 0;
 
   // overrides
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
   }
 
   static constexpr const char* _type_key = "arith.CanonicalExpr";
@@ -485,7 +486,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
    * \return Normalized expr.
    */
   Expr Normalize(Expr expr) {
-    if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
+    if (const auto* op = expr.as<CanonicalExprNode>()) {
       return op->Normalize();
     } else {
       return expr;
@@ -503,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
     if (const auto* op = expr.as<SumExprNode>()) {
       if (op->base == 0 && op->args.size() == 1) return op->args[0];
     }
-    if (const auto* op = expr.as_derived<CanonicalExprNode>()) {
+    if (const auto* op = expr.as<CanonicalExprNode>()) {
       expr = op->Normalize();
     }
     NodePtr<SplitExprNode> n = make_node<SplitExprNode>();
index 313b34d..4094775 100644 (file)
@@ -807,6 +807,8 @@ IntSet EvalSet(Range r,
   return EvalSet(r, ConvertDomMap(dom_map));
 }
 
+TVM_REGISTER_NODE_TYPE(IntervalSetNode);
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
     p->stream << "IntervalSet"
index 3063618..831b444 100644 (file)
@@ -47,7 +47,7 @@ class IntervalSetNode : public IntSetNode {
   Expr max_value;
 
   // visitor overload.
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("min_value", &min_value);
     v->Visit("max_value", &max_value);
   }
index a046cc4..fca9aa2 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,9 +18,9 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file intrin_rule_spirv.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/ir.h>
 #include <GLSL.std.450.h>
index e041f3a..cd3d43b 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -62,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc")
 
 TVM_REGISTER_NODE_TYPE(EnvFuncNode)
 .set_creator(CreateEnvNode)
-.set_global_key([](const Node* n) {
+.set_global_key([](const Object* n) {
     return static_cast<const EnvFuncNode*>(n)->name;
   });
 
index 48b486a..04e04ae 100644 (file)
@@ -1150,6 +1150,8 @@ TVM_REGISTER_NODE_TYPE(Select);
 TVM_REGISTER_NODE_TYPE(Load);
 TVM_REGISTER_NODE_TYPE(Ramp);
 TVM_REGISTER_NODE_TYPE(Broadcast);
+TVM_REGISTER_NODE_TYPE(Shuffle);
+TVM_REGISTER_NODE_TYPE(Prefetch);
 TVM_REGISTER_NODE_TYPE(Call);
 TVM_REGISTER_NODE_TYPE(Let);
 TVM_REGISTER_NODE_TYPE(LetStmt);
index ff6a352..481a926 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,9 +18,9 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file target_info.cc
  */
+#include <tvm/runtime/registry.h>
 #include <tvm/target_info.h>
 #include <tvm/packed_func_ext.h>
 
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
new file mode 100644 (file)
index 0000000..e92ca92
--- /dev/null
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Reflection utilities.
+ * \file node/reflection.cc
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/node/node.h>
+#include <tvm/node/container.h>
+#include <tvm/node/reflection.h>
+#include <tvm/attrs.h>
+
+namespace tvm {
+
+// Attr getter.
+class AttrGetter : public AttrVisitor {
+ public:
+  const std::string& skey;
+  TVMRetValue* ret;
+
+  AttrGetter(const std::string &skey,
+             TVMRetValue* ret)
+      : skey(skey), ret(ret) {}
+
+  bool found_ref_object{false};
+
+  void Visit(const char* key, double* value) final {
+    if (skey == key) *ret = value[0];
+  }
+  void Visit(const char* key, int64_t* value) final {
+    if (skey == key) *ret = value[0];
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
+        << "cannot return too big constant";
+    if (skey == key) *ret = static_cast<int64_t>(value[0]);
+  }
+  void Visit(const char* key, int* value) final {
+    if (skey == key) *ret = static_cast<int64_t>(value[0]);
+  }
+  void Visit(const char* key, bool* value) final {
+    if (skey == key) *ret = static_cast<int64_t>(value[0]);
+  }
+  void Visit(const char* key, void** value) final {
+    if (skey == key) *ret = static_cast<void*>(value[0]);
+  }
+  void Visit(const char* key, Type* value) final {
+    if (skey == key) *ret = value[0];
+  }
+  void Visit(const char* key, std::string* value) final {
+    if (skey == key) *ret = value[0];
+  }
+
+  void Visit(const char* key, runtime::NDArray* value) final {
+    if (skey == key) {
+      *ret = value[0];
+      found_ref_object = true;
+    }
+  }
+  void Visit(const char* key, runtime::ObjectRef* value) final {
+    if (skey == key) {
+      *ret = value[0];
+      found_ref_object = true;
+    }
+  }
+};
+
+runtime::TVMRetValue ReflectionVTable::GetAttr(
+    Object* self, const std::string& field_name) const {
+  runtime::TVMRetValue ret;
+  AttrGetter getter(field_name, &ret);
+
+  bool success;
+  if (getter.skey == "type_key") {
+    ret = self->GetTypeKey();
+    success = true;
+  } else if (!self->IsInstance<DictAttrsNode>()) {
+    VisitAttrs(self, &getter);
+    success = getter.found_ref_object || ret.type_code() != kNull;
+  } else {
+    // specially handle dict attr
+    DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
+    auto it = dnode->dict.find(getter.skey);
+    if (it != dnode->dict.end()) {
+      success = true;
+      ret = (*it).second;
+    } else {
+      success = false;
+    }
+  }
+  if (!success) {
+      LOG(FATAL) << "AttributeError: " << self->GetTypeKey()
+                 << " object has no attributed " << getter.skey;
+  }
+  return ret;
+}
+
+// List names;
+class AttrDir : public AttrVisitor {
+ public:
+  std::vector<std::string>* names;
+
+  void Visit(const char* key, double* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, int64_t* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, bool* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, int* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, void** value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, Type* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, std::string* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, runtime::NDArray* value) final {
+    names->push_back(key);
+  }
+  void Visit(const char* key, runtime::ObjectRef* value) final {
+    names->push_back(key);
+  }
+};
+
+std::vector<std::string>
+ReflectionVTable::ListAttrNames(Object* self) const {
+  std::vector<std::string> names;
+  AttrDir dir;
+  dir.names = &names;
+
+  if (!self->IsInstance<DictAttrsNode>()) {
+    VisitAttrs(self, &dir);
+  } else {
+    // specially handle dict attr
+    DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
+    for (const auto& kv : dnode->dict) {
+      names.push_back(kv.first);
+    }
+  }
+  return names;
+}
+
+ReflectionVTable* ReflectionVTable::Global() {
+  static ReflectionVTable inst;
+  return &inst;
+}
+
+ObjectPtr<Object>
+ReflectionVTable::CreateInitObject(const std::string& type_key,
+                                   const std::string& global_key) const {
+  uint32_t tindex = Object::TypeKey2Index(type_key);
+  if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
+    LOG(FATAL) << "TypeError: " << type_key
+               << " is not registered via TVM_REGISTER_NODE_TYPE";
+  }
+  return fcreate_[tindex](global_key);
+}
+
+class NodeAttrSetter : public AttrVisitor {
+ public:
+  std::string type_key;
+  std::unordered_map<std::string, runtime::TVMArgValue> attrs;
+
+  void Visit(const char* key, double* value) final {
+    *value = GetAttr(key).operator double();
+  }
+  void Visit(const char* key, int64_t* value) final {
+    *value = GetAttr(key).operator int64_t();
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    *value = GetAttr(key).operator uint64_t();
+  }
+  void Visit(const char* key, int* value) final {
+    *value = GetAttr(key).operator int();
+  }
+  void Visit(const char* key, bool* value) final {
+    *value = GetAttr(key).operator bool();
+  }
+  void Visit(const char* key, std::string* value) final {
+    *value = GetAttr(key).operator std::string();
+  }
+  void Visit(const char* key, void** value) final {
+    *value = GetAttr(key).operator void*();
+  }
+  void Visit(const char* key, DataType* value) final {
+    *value = GetAttr(key).operator DataType();
+  }
+  void Visit(const char* key, runtime::NDArray* value) final {
+    *value = GetAttr(key).operator runtime::NDArray();
+  }
+  void Visit(const char* key, ObjectRef* value) final {
+    *value = GetAttr(key).operator ObjectRef();
+  }
+
+ private:
+  runtime::TVMArgValue GetAttr(const char* key) {
+    auto it = attrs.find(key);
+    if (it == attrs.end()) {
+      LOG(FATAL) << type_key << ": require field " << key;
+    }
+    runtime::TVMArgValue v = it->second;
+    attrs.erase(it);
+    return v;
+  }
+};
+
+void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
+  NodeAttrSetter setter;
+  setter.type_key = n->GetTypeKey();
+  CHECK_EQ(args.size() % 2, 0);
+  for (int i = 0; i < args.size(); i += 2) {
+    setter.attrs.emplace(args[i].operator std::string(),
+                         args[i + 1]);
+  }
+  auto* reflection = ReflectionVTable::Global();
+  reflection->VisitAttrs(n, &setter);
+
+  if (setter.attrs.size() != 0) {
+    std::ostringstream os;
+    os << setter.type_key << " does not contain field ";
+    for (const auto &kv : setter.attrs) {
+      os << " " << kv.first;
+    }
+    LOG(FATAL) << os.str();
+  }
+}
+
+// Expose to FFI APIs.
+void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
+  CHECK_EQ(args[0].type_code(), kObjectHandle);
+  Object* self = static_cast<Object*>(args[0].value().v_handle);
+  *ret = ReflectionVTable::Global()->GetAttr(self, args[1]);
+}
+
+void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
+  CHECK_EQ(args[0].type_code(), kObjectHandle);
+  Object* self = static_cast<Object*>(args[0].value().v_handle);
+
+  auto names = std::make_shared<std::vector<std::string> >(
+      ReflectionVTable::Global()->ListAttrNames(self));
+
+  *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) {
+      int64_t i = args[0];
+      if (i == -1) {
+        *rv = static_cast<int64_t>(names->size());
+      } else {
+        *rv = (*names)[i];
+      }
+    });
+}
+
+// API function to make node.
+// args format:
+//   key1, value1, ..., key_n, value_n
+void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
+  std::string type_key = args[0];
+  std::string empty_str;
+  TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
+  auto* reflection = ReflectionVTable::Global();
+  ObjectPtr<Object> n = reflection->CreateInitObject(type_key);
+  if (n->IsInstance<BaseAttrsNode>()) {
+    static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
+  } else {
+    InitNodeByPackedArgs(n.get(), kwargs);
+  }
+  *rv = ObjectRef(n);
+}
+
+
+TVM_REGISTER_GLOBAL("_NodeGetAttr")
+.set_body(NodeGetAttr);
+
+TVM_REGISTER_GLOBAL("_NodeListAttrNames")
+.set_body(NodeListAttrNames);
+
+TVM_REGISTER_GLOBAL("make._Node")
+.set_body(MakeNode);
+
+}  // namespace tvm
similarity index 64%
rename from src/lang/reflection.cc
rename to src/node/serialization.cc
index 8e2c3fe..d270e72 100644 (file)
  */
 
 /*!
- * \file reflection.cc
- * \brief Utilities to save/load/construct TVM objects
+ * \file node/serialization.cc
+ * \brief Utilities to serialize TVM AST/IR objects.
  */
-#include <tvm/base.h>
-#include <tvm/expr.h>
-#include <tvm/attrs.h>
-#include <tvm/node/container.h>
-#include <tvm/packed_func_ext.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/packed_func.h>
 #include <dmlc/json.h>
 #include <dmlc/memory_io.h>
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/node/container.h>
+#include <tvm/node/reflection.h>
+#include <tvm/node/serialization.h>
+#include <tvm/attrs.h>
+
 #include <string>
-#include "../common/base64.h"
+#include <map>
 
-namespace dmlc {
-DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
-}  // namespace dmlc
+#include "../common/base64.h"
 
 namespace tvm {
 
-::dmlc::Registry<NodeFactoryReg>* NodeFactoryReg::Registry() {
-  return ::dmlc::Registry<NodeFactoryReg>::Get();
-}
-
-inline std::string Type2String(const Type& t) {
+inline std::string Type2String(const DataType& t) {
   return runtime::TVMType2String(Type2TVMType(t));
 }
 
-
 inline Type String2Type(std::string s) {
   return TVMType2Type(runtime::String2TVMType(s));
 }
 
-using runtime::Object;
-using runtime::ObjectRef;
-
 // indexer to index all the nodes
 class NodeIndexer : public AttrVisitor {
  public:
-  std::unordered_map<Object*, size_t> node_index{{nullptr, 0}};
-  std::vector<Object*> node_list{nullptr};
-  std::unordered_map<DLTensor*, size_t> tensor_index;
-  std::vector<DLTensor*> tensor_list;
+  std::unordered_map<Object*, size_t> node_index_{{nullptr, 0}};
+  std::vector<Object*> node_list_{nullptr};
+  std::unordered_map<DLTensor*, size_t> tensor_index_;
+  std::vector<DLTensor*> tensor_list_;
+  ReflectionVTable* reflection_ = ReflectionVTable::Global();
 
   void Visit(const char* key, double* value) final {}
   void Visit(const char* key, int64_t* value) final {}
@@ -70,17 +62,14 @@ class NodeIndexer : public AttrVisitor {
   void Visit(const char* key, bool* value) final {}
   void Visit(const char* key, std::string* value) final {}
   void Visit(const char* key, void** value) final {}
-  void Visit(const char* key, Type* value) final {}
-  void Visit(const char* key, NodeRef* value) final {
-    MakeIndex(const_cast<Node*>(value->get()));
-  }
+  void Visit(const char* key, DataType* value) final {}
 
   void Visit(const char* key, runtime::NDArray* value) final {
     DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
-    if (tensor_index.count(ptr)) return;
-    CHECK_EQ(tensor_index.size(), tensor_list.size());
-    tensor_index[ptr] = tensor_list.size();
-    tensor_list.push_back(ptr);
+    if (tensor_index_.count(ptr)) return;
+    CHECK_EQ(tensor_index_.size(), tensor_list_.size());
+    tensor_index_[ptr] = tensor_list_.size();
+    tensor_list_.push_back(ptr);
   }
 
   void Visit(const char* key, ObjectRef* value) final {
@@ -88,15 +77,14 @@ class NodeIndexer : public AttrVisitor {
   }
 
   // make index of all the children of node
-  void MakeIndex(Object* ptr) {
-    if (ptr == nullptr) return;
-    CHECK(ptr->IsInstance<Node>());
-    auto* node = static_cast<Node*>(ptr);
+  void MakeIndex(Object* node) {
+    if (node == nullptr) return;
+    CHECK(node->IsInstance<Node>());
 
-    if (node_index.count(node)) return;
-    CHECK_EQ(node_index.size(), node_list.size());
-    node_index[node] = node_list.size();
-    node_list.push_back(node);
+    if (node_index_.count(node)) return;
+    CHECK_EQ(node_index_.size(), node_list_.size());
+    node_index_[node] = node_list_.size();
+    node_list_.push_back(node);
 
     if (node->IsInstance<ArrayNode>()) {
       ArrayNode* n = static_cast<ArrayNode*>(node);
@@ -115,7 +103,7 @@ class NodeIndexer : public AttrVisitor {
         MakeIndex(const_cast<Object*>(kv.second.get()));
       }
     } else {
-      static_cast<Node*>(node)->VisitAttrs(this);
+      reflection_->VisitAttrs(node, this);
     }
   }
 };
@@ -123,17 +111,17 @@ class NodeIndexer : public AttrVisitor {
 // use map so attributes are ordered.
 using AttrMap = std::map<std::string, std::string>;
 
-// A Node structure for JSON node.
+/*! \brief Node structure for json format. */
 struct JSONNode {
-  // The type key of the data
+  /*! \brief The type of key of the object. */
   std::string type_key;
-  // The global key for global object
+  /*! \brief The global key for global object. */
   std::string global_key;
-  // the attributes
+  /*! \brief the attributes */
   AttrMap attrs;
-  // container keys
+  /*! \brief keys of a map. */
   std::vector<std::string> keys;
-  // container data
+  /*! \brief values of a map or array. */
   std::vector<size_t> data;
 
   void Save(dmlc::JSONWriter *writer) const {
@@ -169,11 +157,14 @@ struct JSONNode {
   }
 };
 
+// Helper class to populate the json node
+// using the existing index.
 class JSONAttrGetter : public AttrVisitor {
  public:
   const std::unordered_map<Object*, size_t>* node_index_;
   const std::unordered_map<DLTensor*, size_t>* tensor_index_;
   JSONNode* node_;
+  ReflectionVTable* reflection_ = ReflectionVTable::Global();
 
   void Visit(const char* key, double* value) final {
     node_->attrs[key] = std::to_string(*value);
@@ -196,40 +187,36 @@ class JSONAttrGetter : public AttrVisitor {
   void Visit(const char* key, void** value) final {
     LOG(FATAL) << "not allowed to serialize a pointer";
   }
-  void Visit(const char* key, Type* value) final {
+  void Visit(const char* key, DataType* value) final {
     node_->attrs[key] = Type2String(*value);
   }
-  void Visit(const char* key, NodeRef* value) final {
-    node_->attrs[key] = std::to_string(
-        node_index_->at(const_cast<Node*>(value->get())));
-  }
+
   void Visit(const char* key, runtime::NDArray* value) final {
     node_->attrs[key] = std::to_string(
         tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
   }
+
   void Visit(const char* key, ObjectRef* value) final {
-    LOG(FATAL) << "Do not support json serialize non-node object";
+    node_->attrs[key] = std::to_string(
+        node_index_->at(const_cast<Object*>(value->get())));
   }
+
   // Get the node
-  void Get(Object* ptr) {
-    if (ptr == nullptr) {
+  void Get(Object* node) {
+    if (node == nullptr) {
       node_->type_key.clear();
       return;
     }
-    CHECK(ptr->IsInstance<Node>());
-    auto* node = static_cast<Node*>(ptr);
     node_->type_key = node->GetTypeKey();
+    node_->global_key = reflection_->GetGlobalKey(node);
+    // No need to recursively visit fields of global singleton
+    // They are registered via the environment.
+    if (node_->global_key.length() != 0) return;
 
-    // sepcially handle global object
-    auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
-    CHECK(f != nullptr)
-        << "Node type \'" << node_->type_key << "\' is not registered in TVM";
-    if (f->fglobal_key != nullptr) {
-      node_->global_key = f->fglobal_key(node);
-      return;
-    }
+    // populates the fields.
     node_->attrs.clear();
     node_->data.clear();
+
     if (node->IsInstance<ArrayNode>()) {
       ArrayNode* n = static_cast<ArrayNode*>(node);
       for (size_t i = 0; i < n->data.size(); ++i) {
@@ -252,23 +239,22 @@ class JSONAttrGetter : public AttrVisitor {
             node_index_->at(const_cast<Object*>(kv.second.get())));
       }
     } else {
-      // do not need to recover content of global singleton object
-      // they are registered via the environment
-      auto* f = dmlc::Registry<NodeFactoryReg>::Find(node->GetTypeKey());
-      if (f != nullptr && f->fglobal_key != nullptr) return;
       // recursively index normal object.
-      node->VisitAttrs(this);
+      reflection_->VisitAttrs(node, this);
     }
   }
 };
 
+// Helper class to set the attributes of a node
+// from given json node.
 class JSONAttrSetter : public AttrVisitor {
  public:
   const std::vector<ObjectPtr<Object> >* node_list_;
   const std::vector<runtime::NDArray>* tensor_list_;
-
   JSONNode* node_;
 
+  ReflectionVTable* reflection_ = ReflectionVTable::Global();
+
   std::string GetValue(const char* key) const {
     auto it = node_->attrs.find(key);
     if (it == node_->attrs.end()) {
@@ -305,16 +291,10 @@ class JSONAttrSetter : public AttrVisitor {
   void Visit(const char* key, void** value) final {
     LOG(FATAL) << "not allowed to deserialize a pointer";
   }
-  void Visit(const char* key, Type* value) final {
+  void Visit(const char* key, DataType* value) final {
     std::string stype = GetValue(key);
     *value = String2Type(stype);
   }
-  void Visit(const char* key, NodeRef* value) final {
-    size_t index;
-    ParseValue(key, &index);
-    CHECK_LE(index, node_list_->size());
-    *value = NodeRef(node_list_->at(index));
-  }
   void Visit(const char* key, runtime::NDArray* value) final {
     size_t index;
     ParseValue(key, &index);
@@ -322,14 +302,15 @@ class JSONAttrSetter : public AttrVisitor {
     *value = tensor_list_->at(index);
   }
   void Visit(const char* key, ObjectRef* value) final {
-    LOG(FATAL) << "Do not support json serialize non-node object";
+    size_t index;
+    ParseValue(key, &index);
+    CHECK_LE(index, node_list_->size());
+    *value = ObjectRef(node_list_->at(index));
   }
   // set node to be current JSONNode
-  void Set(Object* ptr) {
-    if (ptr == nullptr) return;
+  void Set(Object* node) {
+    if (node == nullptr) return;
 
-    CHECK(ptr->IsInstance<Node>());
-    auto* node = static_cast<Node*>(ptr);
     if (node->IsInstance<ArrayNode>()) {
       ArrayNode* n = static_cast<ArrayNode*>(node);
       n->data.clear();
@@ -351,7 +332,7 @@ class JSONAttrSetter : public AttrVisitor {
             = ObjectRef(node_list_->at(node_->data[i]));
       }
     } else {
-      node->VisitAttrs(this);
+      reflection_->VisitAttrs(node, this);
     }
   }
 };
@@ -393,18 +374,18 @@ struct JSONGraph {
     NodeIndexer indexer;
     indexer.MakeIndex(const_cast<Object*>(root.get()));
     JSONAttrGetter getter;
-    getter.node_index_ = &indexer.node_index;
-    getter.tensor_index_ = &indexer.tensor_index;
-    for (Object* n : indexer.node_list) {
+    getter.node_index_ = &indexer.node_index_;
+    getter.tensor_index_ = &indexer.tensor_index_;
+    for (Object* n : indexer.node_list_) {
       JSONNode jnode;
       getter.node_ = &jnode;
       getter.Get(n);
       g.nodes.emplace_back(std::move(jnode));
     }
     g.attrs["tvm_version"] = TVM_VERSION;
-    g.root = indexer.node_index.at(const_cast<Object*>(root.get()));
+    g.root = indexer.node_index_.at(const_cast<Object*>(root.get()));
     // serialize tensor
-    for (DLTensor* tensor : indexer.tensor_list) {
+    for (DLTensor* tensor : indexer.tensor_list_) {
       std::string blob;
       dmlc::MemoryStringStream mstrm(&blob);
       common::Base64OutStream b64strm(&mstrm);
@@ -416,7 +397,7 @@ struct JSONGraph {
   }
 };
 
-std::string SaveJSON(const NodeRef& n) {
+std::string SaveJSON(const ObjectRef& n) {
   auto jgraph = JSONGraph::Create(n);
   std::ostringstream os;
   dmlc::JSONWriter writer(&os);
@@ -424,8 +405,7 @@ std::string SaveJSON(const NodeRef& n) {
   return os.str();
 }
 
-ObjectPtr<Object> LoadJSON_(std::string json_str) {
-  LOG(INFO) << json_str;
+ObjectRef LoadJSON(std::string json_str) {
   std::istringstream is(json_str);
   dmlc::JSONReader reader(&is);
   JSONGraph jgraph;
@@ -442,16 +422,18 @@ ObjectPtr<Object> LoadJSON_(std::string json_str) {
     CHECK(temp.Load(&b64strm));
     tensors.emplace_back(temp);
   }
+  ReflectionVTable* reflection = ReflectionVTable::Global();
+
   // node 0 is always null
   nodes.reserve(jgraph.nodes.size());
+
   for (const JSONNode& jnode : jgraph.nodes) {
     if (jnode.type_key.length() != 0) {
-      auto* f = dmlc::Registry<NodeFactoryReg>::Find(jnode.type_key);
-      CHECK(f != nullptr)
-          << "Node type \'" << jnode.type_key << "\' is not registered in TVM";
-      nodes.emplace_back(f->fcreator(jnode.global_key));
+      ObjectPtr<Object> node =
+          reflection->CreateInitObject(jnode.type_key, jnode.global_key);
+      nodes.emplace_back(node);
     } else {
-      nodes.emplace_back(NodePtr<Node>());
+      nodes.emplace_back(ObjectPtr<Object>());
     }
   }
   CHECK_EQ(nodes.size(), jgraph.nodes.size());
@@ -467,101 +449,6 @@ ObjectPtr<Object> LoadJSON_(std::string json_str) {
       setter.Set(nodes[i].get());
     }
   }
-  return nodes.at(jgraph.root);
+  return ObjectRef(nodes.at(jgraph.root));
 }
-
-class NodeAttrSetter : public AttrVisitor {
- public:
-  std::string type_key;
-  std::unordered_map<std::string, runtime::TVMArgValue> attrs;
-
-  void Visit(const char* key, double* value) final {
-    *value = GetAttr(key).operator double();
-  }
-  void Visit(const char* key, int64_t* value) final {
-    *value = GetAttr(key).operator int64_t();
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    *value = GetAttr(key).operator uint64_t();
-  }
-  void Visit(const char* key, int* value) final {
-    *value = GetAttr(key).operator int();
-  }
-  void Visit(const char* key, bool* value) final {
-    *value = GetAttr(key).operator bool();
-  }
-  void Visit(const char* key, std::string* value) final {
-    *value = GetAttr(key).operator std::string();
-  }
-  void Visit(const char* key, void** value) final {
-    *value = GetAttr(key).operator void*();
-  }
-  void Visit(const char* key, Type* value) final {
-    *value = GetAttr(key).operator Type();
-  }
-  void Visit(const char* key, NodeRef* value) final {
-    *value = GetAttr(key).operator NodeRef();
-  }
-  void Visit(const char* key, runtime::NDArray* value) final {
-    *value = GetAttr(key).operator runtime::NDArray();
-  }
-  void Visit(const char* key, ObjectRef* value) final {
-    *value = GetAttr(key).operator ObjectRef();
-  }
-
- private:
-  runtime::TVMArgValue GetAttr(const char* key) {
-    auto it = attrs.find(key);
-    if (it == attrs.end()) {
-      LOG(FATAL) << type_key << ": require field " << key;
-    }
-    runtime::TVMArgValue v = it->second;
-    attrs.erase(it);
-    return v;
-  }
-};
-
-
-void InitNodeByPackedArgs(Node* n, const TVMArgs& args) {
-  NodeAttrSetter setter;
-  setter.type_key = n->GetTypeKey();
-  CHECK_EQ(args.size() % 2, 0);
-  for (int i = 0; i < args.size(); i += 2) {
-    setter.attrs.emplace(args[i].operator std::string(),
-                         args[i + 1]);
-  }
-  n->VisitAttrs(&setter);
-  if (setter.attrs.size() != 0) {
-    std::ostringstream os;
-    os << setter.type_key << " does not contain field ";
-    for (const auto &kv : setter.attrs) {
-      os << " " << kv.first;
-    }
-    LOG(FATAL) << os.str();
-  }
-}
-
-// API function to make node.
-// args format:
-//   key1, value1, ..., key_n, value_n
-void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
-  std::string type_key = args[0];
-  std::string empty_str;
-  auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key);
-  CHECK(f != nullptr)
-      << "Node type \'" << type_key << "\' is not registered in TVM";
-  TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
-  CHECK(f->fglobal_key == nullptr)
-      << "Cannot make node type \'" << type_key << "\' with global_key.";
-  NodePtr<Node> n = f->fcreator(empty_str);
-  if (n->IsInstance<BaseAttrsNode>()) {
-    static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
-  } else {
-    InitNodeByPackedArgs(n.get(), kwargs);
-  }
-  *rv = NodeRef(n);
-}
-
-TVM_REGISTER_GLOBAL("make._Node")
-.set_body(MakeNode);
 }  // namespace tvm
index e09ae06..65f5eed 100644 (file)
@@ -59,7 +59,7 @@ struct CachedFuncNode : public Node {
   /*! \brief Parameter usage states in the shape function. */
   tvm::Array<Integer> shape_func_param_states;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("target", &target);
     v->Visit("func_name", &func_name);
     v->Visit("inputs", &inputs);
@@ -84,7 +84,7 @@ class CCacheKeyNode : public Node {
   /*! \brief The hardware target.*/
   Target target;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("source_func", &source_func);
     v->Visit("target", &target);
   }
@@ -141,7 +141,7 @@ class CCacheValueNode : public Node {
   /*! \brief usage statistics */
   int use_count{0};
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("cached_func", &cached_func);
     v->Visit("use_count", &use_count);
   }
@@ -191,7 +191,7 @@ class CompileEngineNode : public Node {
   virtual void Clear() = 0;
 
   // VisitAttrs
-  void VisitAttrs(AttrVisitor*) final {}
+  void VisitAttrs(AttrVisitor*) {}
 
   static constexpr const char* _type_key = "relay.CompileEngine";
   TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
index 2703b1c..8c6dace 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file src/tvm/relay/interpreter.cc
  * \brief An interpreter for the Relay IR.
  */
@@ -116,6 +115,8 @@ RefValue RefValueNode::make(Value value) {
 TVM_REGISTER_API("relay._make.RefValue")
 .set_body_typed(RefValueNode::make);
 
+TVM_REGISTER_NODE_TYPE(RefValueNode);
+
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<RefValueNode>([](const RefValueNode* node,
                                tvm::IRPrinter* p) {
@@ -135,6 +136,8 @@ ConstructorValue ConstructorValueNode::make(int32_t tag,
 TVM_REGISTER_API("relay._make.ConstructorValue")
 .set_body_typed(ConstructorValueNode::make);
 
+TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
+
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
                                        tvm::IRPrinter* p) {
@@ -207,7 +210,7 @@ class InterpreterStateNode : public Node {
   /*! \brief The call stack of the interpreter. */
   Stack stack;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("current_expr", &current_expr);
     v->Visit("stack", &stack);
   }
index 0b9a299..9bde3a0 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file param_dict.cc
  * \brief Implementation and registration of parameter dictionary
  * serializing/deserializing functions.
  */
-#include "param_dict.h"
-
+#include <tvm/runtime/registry.h>
 #include <dmlc/memory_io.h>
 
 #include <string>
 #include <vector>
 #include <utility>
 
+#include "param_dict.h"
+
+
+
 namespace tvm {
 namespace relay {
 
index 296c71c..e7695dc 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -45,7 +45,7 @@ struct NamedNDArrayNode : public ::tvm::Node {
   std::string name;
   tvm::runtime::NDArray array;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("array", &array);
   }
index 9c670bf..12cebe5 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file src/tvm/ir/adt.cc
  * \brief AST nodes for Relay algebraic data types (ADTs).
  */
index 2032112..80f0790 100644 (file)
@@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(SourceNameNode)
 .set_creator(GetSourceNameNode)
-.set_global_key([](const Node* n) {
+.set_global_key([](const Object* n) {
     return static_cast<const SourceNameNode*>(n)->name;
   });
 
@@ -88,7 +88,7 @@ TVM_REGISTER_NODE_TYPE(IdNode);
 
 TVM_REGISTER_API("relay._base.set_span")
 .set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
-    auto rn = node_ref.as_derived<RelayNode>();
+    auto rn = node_ref.as<RelayNode>();
     CHECK(rn);
     rn->span = sp;
 });
index b0f889c..7bfe41c 100644 (file)
@@ -195,7 +195,7 @@ NodePtr<Node> CreateOp(const std::string& name) {
 
 TVM_REGISTER_NODE_TYPE(OpNode)
 .set_creator(CreateOp)
-.set_global_key([](const Node* n) {
+.set_global_key([](const Object* n) {
     return static_cast<const OpNode*>(n)->name;
   });
 
index 394ec7e..b2a8396 100644 (file)
@@ -32,7 +32,7 @@
  *  - Otherwise, inline if the node is at the end of a scope and is used at most once.
  */
 
-#include <dmlc/json.h>
+#include <tvm/node/serialization.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/module.h>
 #include <tvm/relay/pattern_functor.h>
@@ -214,7 +214,7 @@ class PrettyPrinter :
   }
 
   Doc PrintFinal(const NodeRef& node) {
-    if (node.as_derived<ExprNode>()) {
+    if (node.as<ExprNode>()) {
       Expr expr = Downcast<Expr>(node);
       dg_ = DependencyGraph::Create(&arena_, expr);
     }
@@ -237,13 +237,13 @@ class PrettyPrinter :
   std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
 
   Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
-    if (node.as_derived<ExprNode>()) {
+    if (node.as<ExprNode>()) {
       return PrintExpr(Downcast<Expr>(node), meta, try_inline);
-    } else if (node.as_derived<TypeNode>()) {
+    } else if (node.as<TypeNode>()) {
       return PrintType(Downcast<Type>(node), meta);
-    } else if (node.as_derived<PatternNode>()) {
+    } else if (node.as<PatternNode>()) {
       return PrintPattern(Downcast<Pattern>(node), meta);
-    } else if (node.as_derived<ModuleNode>()) {
+    } else if (node.as<ModuleNode>()) {
       return PrintMod(Downcast<Module>(node));
     } else {
       Doc doc;
@@ -924,14 +924,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
   void Visit(const char* key, DataType* value) final {
     PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
   }
-  void Visit(const char* key, NodeRef* value) final {
-    PrintKV(key, parent_->PrintAttr(*value));
-  }
   void Visit(const char* key, runtime::NDArray* value) final {
     LOG(FATAL) << "do not allow NDarray as argument";
   }
   void Visit(const char* key, runtime::ObjectRef* obj) final {
-    LOG(FATAL) << "do not allow Object as argument";
+    PrintKV(key, parent_->PrintAttr(*obj));
   }
 
  private:
index cde68c5..b93d9cc 100644 (file)
@@ -132,7 +132,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
     if (const TypeVarNode* tin = new_type_param.as<TypeVarNode>()) {
       type_params.push_back(GetRef<TypeVar>(tin));
     } else {
-      LOG(FATAL) << new_type_param << std::endl;
+      LOG(FATAL) << new_type_param;
     }
   }
 
@@ -141,10 +141,10 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) {
     auto new_type_cs = VisitType(type_cs);
     changed = changed || !new_type_cs.same_as(type_cs);
     if (const TypeConstraintNode* tin =
-        new_type_cs.as_derived<TypeConstraintNode>()) {
+        new_type_cs.as<TypeConstraintNode>()) {
       type_constraints.push_back(GetRef<TypeConstraint>(tin));
     } else {
-      LOG(FATAL) << new_type_cs << std::endl;
+      LOG(FATAL) << new_type_cs;
     }
   }
 
index 9143ae3..bbfb97c 100644 (file)
@@ -140,7 +140,7 @@ class LayoutAlternatedExprNode : public TempExprNode {
     return tmp_memorizer.Transform(value, new_layout, old_layout);
   }
 
-  void VisitAttrs(AttrVisitor *v) final {
+  void VisitAttrs(AttrVisitor *v) {
     v->Visit("value", &value);
     v->Visit("old_layout", &old_layout);
     v->Visit("new_layout", &new_layout);
index 94d09b7..21992ab 100644 (file)
@@ -18,8 +18,6 @@
  */
 
 /*!
- * Copyright (c) 2018 by Contributors
- *
  * \file deivce_annotation.cc
  * \brief Passes to rewrite annotated program and retrieve the device allocation
  * of expression.
@@ -46,13 +44,15 @@ namespace relay {
 namespace {
 
 bool IsOnDeviceNode(const ExprNode* node) {
-  const auto* call_node = dynamic_cast<const CallNode*>(node);
-  return call_node != nullptr && call_node->attrs.as<OnDeviceAttrs>();
+  if (!node->IsInstance<CallNode>()) return false;
+  const auto* call_node = static_cast<const CallNode*>(node);
+  return call_node->attrs.as<OnDeviceAttrs>();
 }
 
 bool IsDeviceCopyNode(const ExprNode* node) {
-  const auto* call_node = dynamic_cast<const CallNode*>(node);
-  return call_node != nullptr && call_node->attrs.as<DeviceCopyAttrs>();
+  if (!node->IsInstance<CallNode>()) return false;
+  const auto* call_node = static_cast<const CallNode*>(node);
+  return call_node->attrs.as<DeviceCopyAttrs>();
 }
 
 }  // namespace
@@ -447,7 +447,8 @@ class DeviceInfo {
   static const ExprNode* GetDeviceCopyNode(const ExprNode* node) {
     if (IsDeviceCopyNode(node)) {
       return node;
-    } else if (const auto* call_node = dynamic_cast<const CallNode*>(node)) {
+    } else if (node->IsInstance<CallNode>()) {
+      const auto* call_node = static_cast<const CallNode*>(node);
       if (const auto* fn = call_node->op.as<FunctionNode>()) {
         const ExprNode* body = fn->body.operator->();
         if (IsDeviceCopyNode(body)) {
@@ -472,7 +473,8 @@ class DeviceInfo {
     for (auto it = post_visitor_.post_dfs_order_.crbegin();
          it != post_visitor_.post_dfs_order_.crend(); ++it) {
       if (const auto* node = GetDeviceCopyNode(it->first)) {
-        last_copy_node = dynamic_cast<const CallNode*>(node);
+        CHECK(node->IsInstance<CallNode>());
+        last_copy_node = static_cast<const CallNode*>(node);
         const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
         cur_dev_type = attrs->src_dev_type;
         if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
index 612abab..a5d0487 100644 (file)
@@ -37,14 +37,14 @@ Expr EtaExpand(const Expr& e, const Module& mod) {
   Type ret_type;
 
   if (e->IsInstance<GlobalVarNode>()) {
-    auto gvar_node = e.as_derived<GlobalVarNode>();
+    auto gvar_node = e.as<GlobalVarNode>();
     auto func = mod->Lookup(GetRef<GlobalVar>(gvar_node));
     original_params = func->params;
     original_type_params = func->type_params;
     ret_type = func->ret_type;
   } else {
     CHECK(e->IsInstance<FunctionNode>());
-    auto func = GetRef<Function>(e.as_derived<FunctionNode>());
+    auto func = GetRef<Function>(e.as<FunctionNode>());
     original_params = func->params;
     original_type_params = func->type_params;
     ret_type = func->ret_type;
index 6defa35..e13a50a 100644 (file)
@@ -176,7 +176,7 @@ class ScaledExprNode : public TempExprNode {
     return value;
   }
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("value", &value);
     v->Visit("axes", &axes);
     v->Visit("scale", &scale);
@@ -664,7 +664,7 @@ class BackwardTransformerNode :
   }
 
   // solver is not serializable.
-  void VisitAttrs(tvm::AttrVisitor* v) final {}
+  void VisitAttrs(tvm::AttrVisitor* v) {}
 
   static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer";
   TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node);
index 6c66d6e..f7d463a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -47,7 +47,7 @@ class TempRealizer : private ExprMutator {
       return it->second;
     } else {
       Expr res;
-      if (const auto* temp = expr.as_derived<TempExprNode>()) {
+      if (const auto* temp = expr.as<TempExprNode>()) {
         res = temp->Realize();
 
       } else {
index 928d8bd..d268862 100644 (file)
@@ -102,7 +102,7 @@ class ModulePassNode : public PassNode {
 
   ModulePassNode() = default;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("pass_info", &pass_info);
   }
 
@@ -156,7 +156,7 @@ class FunctionPassNode : public PassNode {
 
   FunctionPassNode() = default;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("pass_info", &pass_info);
   }
 
@@ -211,7 +211,7 @@ class SequentialNode : public PassNode {
   /*! \brief A list of passes that used to compose a sequential pass. */
   tvm::Array<Pass> passes;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("pass_info", &pass_info);
     v->Visit("passes", &passes);
   }
index 38ffd9b..31e95fc 100644 (file)
@@ -41,7 +41,7 @@ class QAnnotateExprNode : public TempExprNode {
   Expr expr;
   QAnnotateKind kind;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("expr", &expr);
     v->Visit("kind", &kind);
   }
index 6c7dc50..f66aed3 100644 (file)
@@ -42,7 +42,7 @@ class QPartitionExprNode : public TempExprNode {
   /*! \brief The original expression */
   Expr expr;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("expr", &expr);
   }
 
index dafbc1d..3d0e71e 100644 (file)
@@ -18,8 +18,6 @@
  */
 
 /*!
- * Copyright (c) 2018 by Contributors
- *
  * \file quantize.cc
  *
  * \brief transform a graph to a low-bit graph
index f193f9a..412bce0 100644 (file)
@@ -76,7 +76,7 @@ class QConfigNode : public Node {
   bool round_for_shift = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
 
-  void VisitAttrs(AttrVisitor* v) final {
+  void VisitAttrs(AttrVisitor* v) {
     v->Visit("nbit_input", &nbit_input);
     v->Visit("nbit_weight", &nbit_weight);
     v->Visit("nbit_activation", &nbit_activation);
index cd367fd..bdd0d73 100644 (file)
@@ -56,7 +56,7 @@ class QRealizeIntExprNode : public QRealizeExprNode {
   Expr dom_scale;
   DataType dtype;
 
-  void VisitAttrs(tvm::AttrVisitor* v) final {
+  void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("data", &data);
     v->Visit("dom_scale", &dom_scale);
     v->Visit("dtype", &dtype);
index f2bf46a..6035790 100644 (file)
@@ -153,7 +153,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
   // default: unify only if alpha-equal
   Type VisitTypeDefault_(const Node* op, const Type& tn) final {
     NodeRef nr = GetRef<NodeRef>(op);
-    Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
+    Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
     if (!AlphaEqual(t1, tn)) {
       return Type(nullptr);
     }
@@ -411,7 +411,7 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
 
   void VisitTypeDefault_(const Node* op) override {
     NodeRef nr = GetRef<NodeRef>(op);
-    Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
+    Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
     UpdateRelSet(t);
   }
 
@@ -495,7 +495,7 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
 
   void VisitTypeDefault_(const Node* op) override {
     NodeRef nr = GetRef<NodeRef>(op);
-    Type t = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
+    Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
     TransferLinks(t);
   }
 
index 90c3de8..fe1cc14 100644 (file)
@@ -280,7 +280,7 @@ TVM_REGISTER_API("relay._analysis.free_vars")
 TVM_REGISTER_API("relay._analysis.bound_vars")
   .set_body([](TVMArgs args, TVMRetValue* ret) {
       NodeRef x = args[0];
-      if (x.as_derived<ExprNode>()) {
+      if (x.as<ExprNode>()) {
         *ret = BoundVars(Downcast<Expr>(x));
       } else {
         *ret = BoundVars(Downcast<Pattern>(x));
@@ -294,7 +294,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     NodeRef x = args[0];
     Module mod = args[1];
-    if (x.as_derived<TypeNode>()) {
+    if (x.as<TypeNode>()) {
       *ret = FreeTypeVars(Downcast<Type>(x), mod);
     } else {
       *ret = FreeTypeVars(Downcast<Expr>(x), mod);
@@ -305,7 +305,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars")
   .set_body([](TVMArgs args, TVMRetValue* ret) {
       NodeRef x = args[0];
       Module mod = args[1];
-      if (x.as_derived<TypeNode>()) {
+      if (x.as<TypeNode>()) {
         *ret = BoundTypeVars(Downcast<Type>(x), mod);
       } else {
         *ret = BoundTypeVars(Downcast<Expr>(x), mod);
@@ -316,7 +316,7 @@ TVM_REGISTER_API("relay._analysis.all_type_vars")
   .set_body([](TVMArgs args, TVMRetValue* ret) {
       NodeRef x = args[0];
       Module mod = args[1];
-      if (x.as_derived<TypeNode>()) {
+      if (x.as<TypeNode>()) {
         *ret = AllTypeVars(Downcast<Type>(x), mod);
       } else {
         *ret = AllTypeVars(Downcast<Expr>(x), mod);
index d07612f..5d71c2f 100644 (file)
@@ -73,13 +73,12 @@ class TypeContext {
     return child_tindex == parent_tindex;
   }
 
-  uint32_t GetOrAllocRuntimeTypeIndex(const char* key,
+  uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey,
                                       uint32_t static_tindex,
                                       uint32_t parent_tindex,
                                       uint32_t num_child_slots,
                                       bool child_slots_can_overflow) {
     std::lock_guard<std::mutex> lock(mutex_);
-    std::string skey = key;
     auto it = type_key2index_.find(skey);
     if (it != type_key2index_.end()) {
       return it->second;
@@ -106,7 +105,7 @@ class TypeContext {
           << "Conflicting static index " << static_tindex
           << " between " << type_table_[allocated_tindex].name
           << " and "
-          << key;
+          << skey;
     } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
       // allocate the slot from parent's reserved pool
       allocated_tindex = parent_tindex + pinfo.allocated_slots;
@@ -152,11 +151,10 @@ class TypeContext {
     return type_table_[tindex].name_hash;
   }
 
-  uint32_t TypeKey2Index(const char* key) {
-    std::string skey = key;
+  uint32_t TypeKey2Index(const std::string& skey) {
     auto it = type_key2index_.find(skey);
     CHECK(it != type_key2index_.end())
-        << "Cannot find type " << key;
+        << "Cannot find type " << skey;
     return it->second;
   }
 
@@ -176,7 +174,7 @@ class TypeContext {
   std::unordered_map<std::string, uint32_t> type_key2index_;
 };
 
-uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key,
+uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key,
                                             uint32_t static_tindex,
                                             uint32_t parent_tindex,
                                             uint32_t num_child_slots,
@@ -198,7 +196,7 @@ size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
   return TypeContext::Global()->TypeIndex2KeyHash(tindex);
 }
 
-uint32_t Object::TypeKey2Index(const char* key) {
+uint32_t Object::TypeKey2Index(const std::string& key) {
   return TypeContext::Global()->TypeKey2Index(key);
 }
 
@@ -210,7 +208,7 @@ class TVMObjectCAPI {
     }
   }
 
-  static uint32_t TypeKey2Index(const char* type_key) {
+  static uint32_t TypeKey2Index(const std::string& type_key) {
     return Object::TypeKey2Index(type_key);
   }
 };
index a7237db..6e43b40 100644 (file)
@@ -21,6 +21,7 @@
 #include <gtest/gtest.h>
 #include <topi/cuda/injective.h>
 #include <tvm/operation.h>
+#include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/build_module.h>
 
index 4baf649..70a4c32 100644 (file)
@@ -20,6 +20,7 @@
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/ir.h>