[NODE][REFACTOR] Rename IRFunctor->NodeFunctor, use func pointer (#4247)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 1 Nov 2019 23:34:42 +0000 (16:34 -0700)
committerGitHub <noreply@github.com>
Fri, 1 Nov 2019 23:34:42 +0000 (16:34 -0700)
* [NODE][REFACTOR] Rename IRFunctor->NodeFunctor, use function pointer for dispatching.

Previously we used std::function for the functor dispatching.
It introduces additional overhead and problems during dll destruction(of std::function).

This PR changes the std::function to function pointers.
This change a bit restrictions around the set_dispatch that we can get around,
but will improve the general efficiency by reducing one level of indirection in the std::function.
We also no longer need special marcos to register functions to the Functor.

50 files changed:
include/tvm/expr.h
include/tvm/ir_functor_ext.h
include/tvm/ir_mutator.h
include/tvm/ir_visitor.h
include/tvm/node/functor.h [new file with mode: 0644]
include/tvm/node/ir_functor.h [deleted file]
include/tvm/node/reflection.h
include/tvm/relay/expr_functor.h
include/tvm/relay/pattern_functor.h
nnvm/src/compiler/compile_engine.cc
nnvm/src/compiler/graph_hash.cc
nnvm/src/compiler/graph_runtime.cc
src/arithmetic/const_int_bound.cc
src/arithmetic/int_set.cc
src/arithmetic/modular_set.cc
src/codegen/build_module.cc
src/lang/api_registry.cc
src/lang/attr_functor.h
src/lang/attrs.cc
src/lang/buffer.cc
src/lang/channel.cc
src/lang/data_layout.cc
src/lang/expr.cc
src/lang/ir.cc
src/lang/lowered_func.cc
src/lang/target_info.cc
src/lang/tensor.cc
src/op/compute_op.cc
src/op/extern_op.cc
src/op/hybrid_op.cc
src/op/placeholder_op.cc
src/op/scan_op.cc
src/op/tensor_compute_op.cc
src/pass/ir_mutator.cc
src/pass/ir_visitor.cc
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/interpreter.cc
src/relay/ir/adt.cc
src/relay/ir/base.cc
src/relay/ir/expr.cc
src/relay/ir/module.cc
src/relay/ir/op.cc
src/relay/ir/type.cc
src/relay/ir/type_functor.h
src/relay/pass/pass_manager.cc
src/relay/pass/quantize/quantize.cc
src/schedule/schedule_lang.cc
tests/cpp/attrs_test.cc
tests/cpp/ir_functor_test.cc
tests/cpp/ir_mutator_test.cc

index ea57815..fc52421 100644 (file)
@@ -32,7 +32,7 @@
 #include "dtype.h"
 #include "node/node.h"
 #include "node/container.h"
-#include "node/ir_functor.h"
+#include "node/functor.h"
 #include "runtime/c_runtime_api.h"
 
 namespace tvm {
@@ -487,7 +487,7 @@ class IRPrinter {
   /*! \brief Print indent to the stream */
   TVM_DLL void PrintIndent();
   // Allow registration to be printer.
-  using FType = IRFunctor<void(const ObjectRef&, IRPrinter *)>;
+  using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
   TVM_DLL static FType& vtable();
 };
 
index 54a5eff..04ce793 100644 (file)
@@ -24,8 +24,9 @@
 #ifndef TVM_IR_FUNCTOR_EXT_H_
 #define TVM_IR_FUNCTOR_EXT_H_
 
-#include "tvm/node/ir_functor.h"
-#include "ir.h"
+#include <tvm/node/functor.h>
+#include <tvm/ir.h>
+
 #include <utility>
 
 namespace tvm {
@@ -104,7 +105,7 @@ template<typename R, typename ...Args>
 class ExprFunctor<R(const Expr& n, Args...)> {
  private:
   using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
-  using FType = IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
   /*! \brief the result type of this functor */
@@ -213,7 +214,7 @@ template<typename R, typename ...Args>
 class StmtFunctor<R(const Stmt& n, Args... args)> {
  private:
   using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
-  using FType = IRFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
+  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;
 
  public:
   /*! \brief the result type of this functor */
index c910a48..5460ae0 100644 (file)
@@ -28,7 +28,7 @@
 #include <utility>
 #include "expr.h"
 #include "ir.h"
-#include "tvm/node/ir_functor.h"
+#include "tvm/node/functor.h"
 
 namespace tvm {
 namespace ir {
@@ -36,13 +36,13 @@ namespace ir {
  * \brief a base class for mutator to iterative mutate the IR
  *
  *  This IRMutator is implemented via Visitor Pattern.
- *  Also you can implement via IRFunctor.
+ *  Also you can implement via NodeFunctor.
  *  This enables easy extensions of possible new Node.
  *  It also makes changing return types easier.
  *
  * \note If you want to return a different type other than Expr and Stmt,
  *       Simply following the same pattern as IRMutator and create a seperate class.
- * \sa IRFunctor
+ * \sa NodeFunctor
  */
 class TVM_DLL IRMutator {
  public:
@@ -65,9 +65,9 @@ class TVM_DLL IRMutator {
   /*! \brief destructor */
   virtual ~IRMutator() {}
   /*! \brief functor type of expr mutation */
-  using FMutateExpr = IRFunctor<Expr(const ObjectRef&, const Expr&, IRMutator*)>;
+  using FMutateExpr = NodeFunctor<Expr(const ObjectRef&, const Expr&, IRMutator*)>;
   /*! \brief functor type of stmt mutation */
-  using FMutateStmt = IRFunctor<Stmt(const ObjectRef&, const Stmt&, IRMutator*)>;
+  using FMutateStmt = NodeFunctor<Stmt(const ObjectRef&, const Stmt&, IRMutator*)>;
   /*! \return internal vtable of expr */
   static FMutateExpr& vtable_expr();  // NOLINT(*)
   /*! \return internal stmt of expr */
index bebf945..b85cf23 100644 (file)
@@ -25,7 +25,7 @@
 #define TVM_IR_VISITOR_H_
 
 #include "ir.h"
-#include "tvm/node/ir_functor.h"
+#include "tvm/node/functor.h"
 
 namespace tvm {
 namespace ir {
@@ -33,7 +33,7 @@ namespace ir {
 /*!
  * \brief a base class for visitor to iterative traverse the IR
  *
- *  This IRVisitor is implemented via IRFunctor
+ *  This IRVisitor is implemented via NodeFunctor
  *  This enables extensions of possible new Node.
  *
  * \sa ExprFunctor, StmtFunctor, PostOrderVisit
@@ -94,7 +94,7 @@ class TVM_DLL IRVisitor {
   /*! \brief destructor */
   virtual ~IRVisitor() {}
   /*! \brief functor type of visitor */
-  using FVisit = IRFunctor<void(const ObjectRef&, IRVisitor*)>;
+  using FVisit = NodeFunctor<void(const ObjectRef&, IRVisitor*)>;
   /*! \return internal vtable*/
   static FVisit& vtable();
   // overloadable visit function.
diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h
new file mode 100644 (file)
index 0000000..d56fb8d
--- /dev/null
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/functor.h
+ * \brief Defines the Functor data structures.
+ */
+#ifndef TVM_NODE_FUNCTOR_H_
+#define TVM_NODE_FUNCTOR_H_
+
+#include <dmlc/logging.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/node/node.h>
+
+#include <vector>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+/*!
+ * \brief A dynamically dispatched functor on the type of the first argument.
+ *
+ * This is a class that is useful to construct polymorphic dispatching
+ * base on the AST/IR node's type.
+ *
+ * \code
+ *   NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
+ *   tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
+ *     return prefix + "Add";
+ *   });
+ *   tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
+ *     return prefix + "IntImm"
+ *   });
+ *
+ *   Expr x = make_const(1);
+ *   Expr y = x + x;
+ *   // dispatch to IntImm, outputs "MyIntImm"
+ *   LOG(INFO) << tostr(x, "My");
+ *   // dispatch to IntImm, outputs "MyAdd"
+ *   LOG(INFO) << tostr(y, "My");
+ * \endcode
+ *
+ * \tparam FType function signiture
+ *  This type if only defined for FType with function signature
+ */
+template<typename FType>
+class NodeFunctor;
+
+template<typename R, typename ...Args>
+class NodeFunctor<R(const ObjectRef& n, Args...)> {
+ private:
+  /*! \brief internal function pointer type */
+  typedef R (*FPointer)(const ObjectRef&n, Args...);
+  /*! \brief refer to itself. */
+  using TSelf = NodeFunctor<R (const ObjectRef& n, Args...)>;
+  /*! \brief internal function table */
+  std::vector<FPointer> func_;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*!
+   * \brief Whether the functor can dispatch the corresponding Node
+   * \param n The node to be dispatched
+   * \return Whether dispatching function is registered for n's type.
+   */
+  bool can_dispatch(const ObjectRef& n) const {
+    uint32_t type_index = n->type_index();
+    return type_index < func_.size() && func_[type_index] != nullptr;
+  }
+  /*!
+   * \brief invoke the functor, dispatch on type of n
+   * \param n The Node argument
+   * \param args The additional arguments
+   * \return The result.
+   */
+  R operator()(const ObjectRef& n, Args... args) const {
+    CHECK(can_dispatch(n))
+        << "NodeFunctor calls un-registered function on type "
+        << n->GetTypeKey();
+    return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief set the dispacher for type TNode
+   * \param f The function to be set.
+   * \tparam TNode the type of Node to be dispatched.
+   * \return reference to self.
+   */
+  template<typename TNode>
+  TSelf& set_dispatch(FPointer f) {  // NOLINT(*)
+    uint32_t tindex = TNode::RuntimeTypeIndex();
+    if (func_.size() <= tindex) {
+      func_.resize(tindex + 1, nullptr);
+    }
+    CHECK(func_[tindex] == nullptr)
+        << "Dispatch for " << TNode::_type_key
+        << " is already set";
+    func_[tindex] = f;
+    return *this;
+  }
+  /*!
+  * \brief unset the dispacher for type TNode
+  *
+  * \tparam TNode the type of Node to be dispatched.
+  * \return reference to self.
+  */
+  template<typename TNode>
+  TSelf& clear_dispatch() {  // NOLINT(*)
+    uint32_t tindex = TNode::RuntimeTypeIndex();
+    CHECK_LT(tindex, func_.size())
+        << "clear_dispatch: index out of range";
+    func_[tindex] = nullptr;
+    return *this;
+  }
+};
+
+
+#define TVM_REG_FUNC_VAR_DEF(ClsName)                                 \
+  static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
+
+/*!
+ * \brief Useful macro to set NodeFunctor dispatch in a global static field.
+ *
+ * \code
+ *  // Use NodeFunctor to implement IRPrinter similar to Visitor Pattern.
+ *  // vtable allows easy patch of new Node types, without changing
+ *  // interface of IRPrinter.
+ *
+ *  class IRPrinter {
+ *   public:
+ *    std::ostream& stream;
+ *    // the dispatch function.
+ *    void print(Expr e) {
+ *      const static FType& f = *vtable();
+ *      f(e, this);
+ *    }
+ *
+ *    using FType = NodeFunctor<void (const ObjectRef&, IRPrinter *)>;
+ *    // function to return global function table
+ *    static FType& vtable();
+ *  };
+ *
+ *  // in cpp/cc file
+ *  IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
+ *    static FType inst; return inst;
+ *  }
+ *
+ *  TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+ *  .set_dispatch<Add>([](const ObjectRef& ref, IRPrinter* p) {
+ *    auto* n = static_cast<const Add*>(ref.get());
+ *    p->print(n->a);
+ *    p->stream << '+'
+ *    p->print(n->b);
+ *  });
+ *
+ *
+ * \endcode
+ *
+ * \param ClsName The name of the class
+ * \param FField The static function that returns a singleton of NodeFunctor.
+ */
+#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                       \
+  TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__)  =      \
+      ClsName::FField()
+}  // namespace tvm
+#endif  // TVM_NODE_FUNCTOR_H_
diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h
deleted file mode 100644 (file)
index e902e8f..0000000
+++ /dev/null
@@ -1,282 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file tvm/node/ir_functor.h
- * \brief Defines the IRFunctor data structures.
- */
-#ifndef TVM_NODE_IR_FUNCTOR_H_
-#define TVM_NODE_IR_FUNCTOR_H_
-
-#include <dmlc/logging.h>
-#include <string>
-#include <vector>
-#include <memory>
-#include <type_traits>
-#include <utility>
-#include <functional>
-#include "node.h"
-
-namespace tvm {
-/*!
- * \brief A dynamically dispatched functor on ObjectRef in the first argument.
- *
- * \code
- *   IRFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
- *   tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
- *     return prefix + "Add";
- *   });
- *   tostr.set_dispatch<IntImm>([](const IntImm* op) {
- *     return prefix + "IntImm"
- *   });
- *
- *   Expr x = make_const(1);
- *   Expr y = x + x;
- *   // dispatch to IntImm, outputs "MyIntImm"
- *   LOG(INFO) << tostr(x, "My");
- *   // dispatch to IntImm, outputs "MyAdd"
- *   LOG(INFO) << tostr(y, "My");
- * \endcode
- *
- * \tparam FType function signiture
- *  This type if only defined for FType with function signature
- */
-template<typename FType>
-class IRFunctor;
-
-template<typename R, typename ...Args>
-class IRFunctor<R(const ObjectRef& n, Args...)> {
- private:
-  using Function = std::function<R (const ObjectRef&n, Args...)>;
-  using TSelf = IRFunctor<R (const ObjectRef& n, Args...)>;
-  /*! \brief internal function table */
-  std::vector<Function> func_;
-
- public:
-  /*! \brief the result type of this functor */
-  using result_type = R;
-  /*!
-   * \brief Whether the functor can dispatch the corresponding Node
-   * \param n The node to be dispatched
-   * \return Whether dispatching function is registered for n's type.
-   */
-  inline bool can_dispatch(const ObjectRef& n) const {
-    uint32_t type_index = n->type_index();
-    return type_index < func_.size() && func_[type_index] != nullptr;
-  }
-  /*!
-   * \brief invoke the functor , dispatch on type of n
-   * \param n The Node argument
-   * \param args The additional arguments
-   * \return The result.
-   */
-  inline R operator()(const ObjectRef& n, Args... args) const {
-    uint32_t type_index = n->type_index();
-    CHECK(type_index < func_.size() &&
-          func_[type_index] != nullptr)
-        << "IRFunctor calls un-registered function on type "
-        << n->GetTypeKey();
-    return func_[type_index](n, std::forward<Args>(args)...);
-  }
-  /*!
-   * \brief set the dispacher for type TNode
-   * \param f The function to be set.
-   * \tparam TNode the type of Node to be dispatched.
-   * \return reference to self.
-   */
-  template<typename TNode>
-  inline TSelf& set_dispatch(Function f) {  // NOLINT(*)
-    uint32_t tindex = TNode::RuntimeTypeIndex();
-    if (func_.size() <= tindex) {
-      func_.resize(tindex + 1, nullptr);
-    }
-    CHECK(func_[tindex] == nullptr)
-        << "Dispatch for " << TNode::_type_key
-        << " is already set";
-    func_[tindex] = f;
-    return *this;
-  }
-  /*!
-   * \brief set the dispacher for type TNode
-   *  This allows f to used detailed const Node pointer to replace ObjectRef
-   *
-   * \param f The function to be set.
-   * \tparam TNode the type of Node to be dispatched.
-   * \return reference to self.
-   */
-  template<typename TNode>
-  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
-    Function fun = [f](const ObjectRef& n, Args... args) {
-      return f(static_cast<const TNode*>(n.get()),
-               std::forward<Args>(args)...);
-    };
-    return this->set_dispatch<TNode>(fun);
-  }
-  /*!
-  * \brief unset the dispacher for type TNode
-  *
-  * \tparam TNode the type of Node to be dispatched.
-  * \return reference to self.
-  */
-  template<typename TNode>
-  inline TSelf& clear_dispatch() {  // NOLINT(*)
-    uint32_t tindex = TNode::RuntimeTypeIndex();
-    CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
-    func_[tindex] = nullptr;
-    return *this;
-  }
-};
-
-#if defined(__GNUC__)
-#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
-#else
-#define TVM_ATTRIBUTE_UNUSED
-#endif
-
-/*! \brief helper macro to generate string concat */
-#define TVM_STR_CONCAT_(__x, __y) __x##__y
-#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
-
-#define TVM_REGISTER_VAR_DEF(ClsName)                                 \
-  static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
-
-/*!
- * \brief Useful macro to set IRFunctor dispatch in a global static field.
- *
- * \code
- *  // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
- *  // vtable allows easy patch in of new Node types, without changing
- *  // interface of IRPrinter.
- *
- *  class IRPrinter {
- *   public:
- *    std::ostream& stream;
- *    // the dispatch function.
- *    void print(Expr e) {
- *      const static FType& f = *vtable();
- *      f(e, this);
- *    }
- *
- *    using FType = IRFunctor<void (const ObjectRef&, IRPrinter *)>;
- *    // function to return global function table
- *    static FType& vtable();
- *  };
- *
- *  // in cpp/cc file
- *  IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
- *    static FType inst; return inst;
- *  }
- *
- *  TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
- *  .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
- *    p->print(n->a);
- *    p->stream << '+'
- *    p->print(n->b);
- *  });
- *
- *
- * \endcode
- *
- * \param ClsName The name of the class
- * \param FField The static function that returns a singleton of IRFunctor.
- */
-#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                       \
-  TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__)  =      \
-                              ClsName::FField()
-
- /*!
- * \brief A container for a list of callbacks. All callbacks are invoked when
- * the object is destructed.
- */
-class IRFunctorCleanList {
- public:
-  ~IRFunctorCleanList() {
-    for (auto &f : clean_items) {
-      f();
-    }
-  }
-
-  void append(std::function<void()> func) {
-    clean_items.push_back(func);
-  }
-
- private:
-  std::vector< std::function<void()> > clean_items;
-};
-
-/*!
-* \brief A wrapper around IRFunctor that will record calls to set_dispatch
-* and make a corresponding call to clear_dispatch when the last copy of
-* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
-* this can be used by NNVM and other libraries to unregister callbacks when
-* the library is unloaded. This prevents crashes when the underlying IRFunctor
-* is destructed as it will no longer contain std::function instances allocated
-* by a library that has been unloaded.
-*/
-template<typename FType>
-class IRFunctorStaticRegistry;
-
-template<typename R, typename ...Args>
-class IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)> {
- private:
-  IRFunctor<R(const ObjectRef& n, Args...)> *irf_;
-  std::shared_ptr<IRFunctorCleanList> free_list;
-
-  using TSelf = IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)>;
-
- public:
-  IRFunctorStaticRegistry(IRFunctor<R(const ObjectRef& n, Args...)> *irf) {
-    irf_ = irf;
-    free_list = std::make_shared<IRFunctorCleanList>();
-  }
-
-  template<typename TNode>
-  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) {  // NOLINT(*)
-    irf_->template set_dispatch<TNode>(f);
-    auto irf_copy = irf_;
-    free_list.get()->append([irf_copy] {
-      irf_copy->template clear_dispatch<TNode>();
-      });
-    return *this;
-  }
-};
-
-/*!
-* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
-* the compiler to deduce the template types.
-*/
-template<typename R, typename ...Args>
-IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)> MakeIRFunctorStaticRegistry(
-  IRFunctor<R(const ObjectRef& n, Args...)> *irf) {
-  return IRFunctorStaticRegistry<R(const ObjectRef& n, Args...)>(irf);
-}
-
-#define TVM_AUTO_REGISTER_VAR_DEF(ClsName)                        \
-  static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
-
-/*!
-* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
-* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
-* TVM_STATIC_IR_FUNCTOR.
-*/
-#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField)                  \
-  TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__)  = \
-                        MakeIRFunctorStaticRegistry(&ClsName::FField())
-
-}  // namespace tvm
-#endif  // TVM_NODE_IR_FUNCTOR_H_
index e6caa44..35a8e1d 100644 (file)
@@ -48,20 +48,20 @@ using runtime::ObjectRef;
  *  Each objects that wants reflection will need to implement
  *  a VisitAttrs function and call visitor->Visit on each of its field.
  */
-class TVM_DLL AttrVisitor {
+class AttrVisitor {
  public:
 //! \cond Doxygen_Suppress
-  virtual ~AttrVisitor() = default;
-  virtual void Visit(const char* key, double* value) = 0;
-  virtual void Visit(const char* key, int64_t* value) = 0;
-  virtual void Visit(const char* key, uint64_t* value) = 0;
-  virtual void Visit(const char* key, int* value) = 0;
-  virtual void Visit(const char* key, bool* value) = 0;
-  virtual void Visit(const char* key, std::string* value) = 0;
-  virtual void Visit(const char* key, void** value) = 0;
-  virtual void Visit(const char* key, DataType* value) = 0;
-  virtual void Visit(const char* key, runtime::NDArray* value) = 0;
-  virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
+  TVM_DLL virtual ~AttrVisitor() = default;
+  TVM_DLL virtual void Visit(const char* key, double* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, int* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, void** value) = 0;
+  TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
+  TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
   template<typename ENum,
            typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
   void Visit(const char* key, ENum* ptr) {
@@ -93,13 +93,13 @@ class ReflectionVTable {
    *        If this is not empty then FGlobalKey must be defined for the object.
    * \return The created function.
    */
-  using FCreate = std::function<ObjectPtr<Object>(const std::string& global_key)>;
+  typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key);
   /*!
    * \brief Global key function, only needed by global objects.
    * \param node The node pointer.
    * \return node The global key to the node.
    */
-  using FGlobalKey = std::function<std::string(const Object* self)>;
+  typedef std::string (*FGlobalKey)(const Object* self);
   /*!
    * \brief Dispatch the VisitAttrs function.
    * \param self The pointer to the object.
@@ -193,7 +193,7 @@ class ReflectionVTable::Registry {
   static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry &      \
   __make_Node ## _ ## TypeName ## __ =                                  \
       ::tvm::ReflectionVTable::Global()->Register<TypeName>()           \
-      .set_creator([](const std::string&) {                             \
+      .set_creator([](const std::string&) -> ObjectPtr<Object> {        \
           return ::tvm::runtime::make_object<TypeName>();               \
         })
 
index 8bc87a2..722f73f 100644 (file)
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
 #define TVM_RELAY_EXPR_FUNCTOR_H_
 
-#include <tvm/node/ir_functor.h>
+#include <tvm/node/functor.h>
 #include <string>
 #include <utility>
 #include <unordered_map>
@@ -66,7 +66,7 @@ template <typename R, typename... Args>
 class ExprFunctor<R(const Expr& n, Args...)> {
  private:
   using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
-  using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
   /*! \brief the result type of this functor */
index c15523c..d84d43a 100644 (file)
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
 #define TVM_RELAY_PATTERN_FUNCTOR_H_
 
-#include <tvm/node/ir_functor.h>
+#include <tvm/node/functor.h>
 #include <string>
 #include <utility>
 #include <unordered_map>
@@ -66,7 +66,7 @@ template <typename R, typename... Args>
 class PatternFunctor<R(const Pattern& n, Args...)> {
  private:
   using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
-  using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
   /*! \brief the result type of this functor */
index 5ce78d1..cd84f92 100644 (file)
@@ -391,8 +391,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
 TVM_REGISTER_NODE_TYPE(GraphFuncNode);
 TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<GraphFuncNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* op = static_cast<const GraphFuncNode*>(ref.get());
     p->stream << "GraphFunc(name=" << op->func_name
               << ", addr=" << op << ")";
 });
index b76f99f..bbbc0db 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -101,8 +101,9 @@ GraphKey GraphKeyNode::make(Graph graph,
   return GraphKey(n);
 }
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<GraphKeyNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* op = static_cast<const GraphKeyNode*>(ref.get());
     p->stream << "GraphKeyNode("<< op << ")";
 });
 
index d8ff3bf..a4b398c 100644 (file)
@@ -30,6 +30,8 @@
 namespace nnvm {
 namespace compiler {
 
+using tvm::Object;
+using tvm::ObjectPtr;
 using tvm::runtime::TVMArgs;
 using tvm::runtime::TVMRetValue;
 using tvm::runtime::PackedFunc;
index 168486e..d494a50 100644 (file)
@@ -53,7 +53,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ConstIntBoundNode>([](const ConstIntBoundNode* op, IRPrinter* p) {
+.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ConstIntBoundNode*>(node.get());
     p->stream << "ConstIntBound[";
     PrintBoundValue(p->stream, op->min_value);
     p->stream << ',';
index 4094775..9f8effb 100644 (file)
@@ -810,7 +810,8 @@ IntSet EvalSet(Range r,
 TVM_REGISTER_NODE_TYPE(IntervalSetNode);
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
+.set_dispatch<IntervalSetNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const IntervalSetNode*>(node.get());
     p->stream << "IntervalSet"
               << "[" << op->min_value << ", "
               << op->max_value << ']';
index 9e363e7..25c7391 100644 (file)
@@ -45,7 +45,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ModularSetNode>([](const ModularSetNode *op, IRPrinter *p) {
+.set_dispatch<ModularSetNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const ModularSetNode*>(node.get());
     p->stream << "ModularSet("
               << "coeff=" << op->coeff << ", base="
               << op->base << ')';
index cfcb060..3f279f8 100644 (file)
@@ -37,8 +37,9 @@ TVM_REGISTER_NODE_TYPE(TargetNode);
 TVM_REGISTER_NODE_TYPE(GenericFuncNode);
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<TargetNode>([](const TargetNode *op, IRPrinter *p) {
-  p->stream << op->str();
+.set_dispatch<TargetNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const TargetNode*>(node.get());
+    p->stream << op->str();
   });
 
 
@@ -654,7 +655,8 @@ tvm::BuildConfig BuildConfig::Current() {
 TVM_REGISTER_NODE_TYPE(BuildConfigNode);
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<BuildConfigNode>([](const BuildConfigNode *op, IRPrinter *p) {
+.set_dispatch<BuildConfigNode>([](const ObjectRef& node, IRPrinter *p) {
+  auto* op = static_cast<const BuildConfigNode*>(node.get());
   p->stream << "build_config(";
   p->stream << "data_alignment=" << op->data_alignment << ", ";
   p->stream << "offset_factor=" << op->offset_factor << ", ";
index cd3d43b..3c48676 100644 (file)
 namespace tvm {
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<EnvFuncNode>([](const EnvFuncNode *op, IRPrinter *p) {
+.set_dispatch<EnvFuncNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const EnvFuncNode*>(node.get());
     p->stream << "EnvFunc(" << op->name << ")";
 });
 
-NodePtr<EnvFuncNode> CreateEnvNode(const std::string& name) {
+ObjectPtr<Object> CreateEnvNode(const std::string& name) {
   auto* f = runtime::Registry::Get(name);
   CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
   NodePtr<EnvFuncNode> n = make_node<EnvFuncNode>();
@@ -62,7 +63,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc")
 
 TVM_REGISTER_NODE_TYPE(EnvFuncNode)
 .set_creator(CreateEnvNode)
-.set_global_key([](const Object* n) {
+.set_global_key([](const Object* n) -> std::string {
     return static_cast<const EnvFuncNode*>(n)->name;
   });
 
index b9391e4..51b355e 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file attr_functor.h
  * \brief A way to define arbitrary function signature
  *        with dispatch on common attributes.
@@ -31,6 +30,7 @@
 #ifndef TVM_LANG_ATTR_FUNCTOR_H_
 #define TVM_LANG_ATTR_FUNCTOR_H_
 
+#include <tvm/node/functor.h>
 #include <utility>
 
 namespace tvm {
@@ -54,7 +54,7 @@ template <typename R, typename... Args>
 class AttrFunctor<R(const ObjectRef& n, Args...)> {
  private:
   using TSelf = AttrFunctor<R(const ObjectRef& n, Args...)>;
-  using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
   /*! \brief the result type of this functor */
index a299e17..0b036c3 100644 (file)
@@ -61,7 +61,8 @@ Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<DictAttrsNode>([](const DictAttrsNode *op, IRPrinter *p) {
+.set_dispatch<DictAttrsNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const DictAttrsNode*>(node.get());
     p->stream << op->dict;
 });
 
index 689b291..bc14e2b 100644 (file)
@@ -452,7 +452,8 @@ Buffer BufferNode::make(Var data,
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<BufferNode>([](const BufferNode *op, IRPrinter *p) {
+.set_dispatch<BufferNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const BufferNode*>(node.get());
     p->stream << "buffer(" << op->name << ", " << op << ")";
 });
 
index 6746a3c..c564d61 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -33,7 +33,8 @@ Channel ChannelNode::make(Var handle_var, Type dtype) {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ChannelNode>([](const ChannelNode *op, IRPrinter *p) {
+.set_dispatch<ChannelNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const ChannelNode*>(node.get());
     p->stream << "channel(" << op->handle_var << ", " << op->dtype << ")";
 });
 
index 3686d5f..7c76e40 100644 (file)
@@ -196,7 +196,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<LayoutNode>([](const LayoutNode* l, IRPrinter* p) {
+.set_dispatch<LayoutNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* l = static_cast<const LayoutNode*>(node.get());
     p->stream << "Layout(" << l->name << ")";
   });
 
@@ -352,7 +353,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<BijectiveLayoutNode>([](const BijectiveLayoutNode* b, IRPrinter* p) {
+.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
     p->stream << "BijectiveLayout(" << b->src_layout.name()
               << "->" << b->dst_layout.name() << ")";
   });
index 31ade90..6a69fda 100644 (file)
@@ -182,7 +182,8 @@ IRPrinter::FType& IRPrinter::vtable() {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IntImm>([](const IntImm *op, IRPrinter *p) {
+.set_dispatch<IntImm>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const IntImm*>(node.get());
     if (op->type == Int(32)) {
       p->stream << op->value;
     } else {
@@ -191,7 +192,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
+.set_dispatch<IterVarNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const IterVarNode*>(node.get());
     p->stream << "iter_var(";
     if (op->var->name_hint.length() != 0) {
       p->stream  << op->var->name_hint << ", ";
@@ -206,7 +208,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<RangeNode>([](const RangeNode* op, IRPrinter* p) {
+.set_dispatch<RangeNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const RangeNode*>(node.get());
     p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
   });
 
index 04e04ae..bb8401d 100644 (file)
@@ -553,12 +553,14 @@ Stmt Evaluate::make(Expr value) {
 
 // Printers
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<UIntImm>([](const UIntImm* op, IRPrinter* p) {
+.set_dispatch<UIntImm>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const UIntImm*>(node.get());
     p->stream << "(" << op->type << ")" << op->value;
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<FloatImm>([](const FloatImm* op, IRPrinter* p) {
+.set_dispatch<FloatImm>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const FloatImm*>(node.get());
     auto& stream = p->stream;
     switch (op->type.bits()) {
       case 64:
@@ -576,7 +578,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<StringImm>([](const StringImm* op, IRPrinter* p) {
+.set_dispatch<StringImm>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const StringImm*>(node.get());
     auto& stream = p->stream;
     stream << '"';
     for (size_t i = 0; i < op->value.size(); ++i) {
@@ -611,101 +614,116 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Cast>([](const Cast* op, IRPrinter* p) {
+.set_dispatch<Cast>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Cast*>(node.get());
     p->stream << op->type << '(';
     p->Print(op->value);
     p->stream << ')';
   })
-.set_dispatch<Variable>([](const Variable* op, IRPrinter* p) {
+.set_dispatch<Variable>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Variable*>(node.get());
     // omit the type
     // stream << op->name << "." << op->type;
     p->stream << op->name_hint;
   })
-.set_dispatch<Add>([](const Add* op, IRPrinter* p) {
+.set_dispatch<Add>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Add*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " + ";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Sub>([](const Sub* op, IRPrinter* p) {
+.set_dispatch<Sub>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Sub*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " - ";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Mul>([](const Mul* op, IRPrinter* p) {
+.set_dispatch<Mul>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Mul*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << "*";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Div>([](const Div* op, IRPrinter* p) {
+.set_dispatch<Div>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Div*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << "/";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Mod>([](const Mod* op, IRPrinter* p) {
+.set_dispatch<Mod>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Mod*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " % ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<Min>([](const Min* op, IRPrinter* p) {
+.set_dispatch<Min>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Min*>(node.get());
     p->stream << "min(";
     p->Print(op->a);
     p->stream << ", ";
     p->Print(op->b);
     p->stream << ")";
 })
-.set_dispatch<Max>([](const Max* op, IRPrinter* p) {
+.set_dispatch<Max>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Max*>(node.get());
     p->stream << "max(";
     p->Print(op->a);
     p->stream << ", ";
     p->Print(op->b);
     p->stream << ")";
 })
-.set_dispatch<EQ>([](const EQ* op, IRPrinter* p) {
+.set_dispatch<EQ>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const EQ*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " == ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<NE>([](const NE* op, IRPrinter* p) {
+.set_dispatch<NE>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const NE*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " != ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<LT>([](const LT* op, IRPrinter* p) {
+.set_dispatch<LT>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const LT*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " < ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<LE>([](const LE* op, IRPrinter* p) {
+.set_dispatch<LE>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const LE*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " <= ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<GT>([](const GT* op, IRPrinter* p) {
+.set_dispatch<GT>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const GT*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " > ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<GE>([](const GE* op, IRPrinter* p) {
+.set_dispatch<GE>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const GE*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " >= ";
@@ -714,17 +732,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<FloorDiv>([](const FloorDiv* op, IRPrinter *p) {
+.set_dispatch<FloorDiv>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const FloorDiv*>(node.get());
   p->stream << "floordiv(" << op->a << ", " << op->b << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<FloorMod>([](const FloorMod* op, IRPrinter *p) {
+.set_dispatch<FloorMod>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const FloorMod*>(node.get());
   p->stream << "floormod(" << op->a << ", " << op->b << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<And>([](const And* op, IRPrinter* p) {
+.set_dispatch<And>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const And*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " && ";
@@ -733,7 +754,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Or>([](const Or* op, IRPrinter* p) {
+.set_dispatch<Or>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Or*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " || ";
@@ -742,13 +764,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Not>([](const Not* op, IRPrinter* p) {
+.set_dispatch<Not>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Not*>(node.get());
     p->stream << '!';
     p->Print(op->a);
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Select>([](const Select* op, IRPrinter* p) {
+.set_dispatch<Select>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Select*>(node.get());
     p->stream << "select(";
     p->Print(op->condition);
     p->stream << ", ";
@@ -759,7 +783,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Load>([](const Load* op, IRPrinter* p) {
+.set_dispatch<Load>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Load*>(node.get());
     p->stream << op->buffer_var << "[";
     p->Print(op->index);
     p->stream << "]";
@@ -770,7 +795,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Ramp>([](const Ramp* op, IRPrinter* p) {
+.set_dispatch<Ramp>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Ramp*>(node.get());
     p->stream << "ramp(";
     p->Print(op->base);
     p->stream << ", ";
@@ -779,14 +805,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Broadcast>([](const Broadcast* op, IRPrinter* p) {
+.set_dispatch<Broadcast>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Broadcast*>(node.get());
     p->stream << "x" << op->lanes << "(";
     p->Print(op->value);
     p->stream << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Call>([](const Call* op, IRPrinter* p) {
+.set_dispatch<Call>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Call*>(node.get());
     p->stream << op->name << "(";
     for (size_t i = 0; i < op->args.size(); ++i) {
       p->Print(op->args[i]);
@@ -798,7 +826,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Let>([](const Let* op, IRPrinter* p) {
+.set_dispatch<Let>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Let*>(node.get());
     p->stream << "(let " << op->var << " = ";
     p->Print(op->value);
     p->stream << " in ";
@@ -807,7 +836,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<LetStmt>([](const LetStmt* op, IRPrinter* p) {
+.set_dispatch<LetStmt>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const LetStmt*>(node.get());
     p->PrintIndent();
     p->stream << "let " << op->var << " = ";
     p->Print(op->value);
@@ -816,7 +846,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<AttrStmt>([](const AttrStmt* op, IRPrinter* p) {
+.set_dispatch<AttrStmt>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const AttrStmt*>(node.get());
     p->PrintIndent();
     p->stream << "// attr [";
     p->Print(op->node);
@@ -828,7 +859,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<AssertStmt>([](const AssertStmt* op, IRPrinter* p) {
+.set_dispatch<AssertStmt>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const AssertStmt*>(node.get());
     p->PrintIndent();
     p->stream << "assert(";
     p->Print(op->condition);
@@ -839,7 +871,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ProducerConsumer>([](const ProducerConsumer* op, IRPrinter* p) {
+.set_dispatch<ProducerConsumer>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ProducerConsumer*>(node.get());
     if (op->is_producer) {
       p->PrintIndent();
       p->stream << "produce " << op->func->func_name() << " {\n";
@@ -872,7 +905,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<For>([](const For* op, IRPrinter* p) {
+.set_dispatch<For>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const For*>(node.get());
     p->PrintIndent();
     p->stream << op->for_type << " (" << op->loop_var << ", ";
     p->Print(op->min);
@@ -889,7 +923,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Store>([](const Store* op, IRPrinter* p) {
+.set_dispatch<Store>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Store*>(node.get());
     p->PrintIndent();
     p->stream << op->buffer_var << "[";
     p->Print(op->index);
@@ -903,7 +938,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Provide>([](const Provide* op, IRPrinter* p) {
+.set_dispatch<Provide>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Provide*>(node.get());
     p->PrintIndent();
     p->stream << op->func->func_name() << "(";
     for (size_t i = 0; i < op->args.size(); ++i) {
@@ -920,7 +956,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Allocate>([](const Allocate* op, IRPrinter* p) {
+.set_dispatch<Allocate>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Allocate*>(node.get());
     p->PrintIndent();
     p->stream << "allocate " << op->buffer_var << "[" << op->type;
     for (size_t i = 0; i < op->extents.size(); ++i) {
@@ -937,14 +974,16 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Free>([](const Free* op, IRPrinter* p) {
+.set_dispatch<Free>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Free*>(node.get());
     p->PrintIndent();
     p->stream << "free " << op->buffer_var;
     p->stream << '\n';
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Realize>([](const Realize* op, IRPrinter* p) {
+.set_dispatch<Realize>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Realize*>(node.get());
     p->PrintIndent();
     p->stream << "realize " << op->func->func_name() << "(";
     for (size_t i = 0; i < op->bounds.size(); ++i) {
@@ -974,7 +1013,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Prefetch>([](const Prefetch* op, IRPrinter* p) {
+.set_dispatch<Prefetch>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Prefetch*>(node.get());
     p->PrintIndent();
     p->stream << "prefetch " << op->func->func_name() << "(";
     for (size_t i = 0; i < op->bounds.size(); ++i) {
@@ -992,13 +1032,15 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Block>([](const Block* op, IRPrinter* p) {
+.set_dispatch<Block>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Block*>(node.get());
     p->Print(op->first);
     if (op->rest.defined()) p->Print(op->rest);
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IfThenElse>([](const IfThenElse* op, IRPrinter* p) {
+.set_dispatch<IfThenElse>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const IfThenElse*>(node.get());
     p->PrintIndent();
     while (true) {
       p->stream << "if (" << op->condition << ") {\n";
@@ -1028,7 +1070,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Evaluate>([](const Evaluate* op, IRPrinter* p) {
+.set_dispatch<Evaluate>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Evaluate*>(node.get());
     p->PrintIndent();
     p->Print(op->value);
     p->stream << "\n";
@@ -1045,7 +1088,8 @@ void PrintList(const Array<T> &exprs, IRPrinter* p) {
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Shuffle>([](const Shuffle* op, IRPrinter* p) {
+.set_dispatch<Shuffle>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Shuffle*>(node.get());
     p->stream << "shuffle(";
     PrintList(op->vectors, p);
     p->stream << ", ";
@@ -1055,7 +1099,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 // Container printer
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ArrayNode>([](const ArrayNode* op, IRPrinter* p) {
+.set_dispatch<ArrayNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ArrayNode*>(node.get());
     p->stream << '[';
     for (size_t i = 0 ; i < op->data.size(); ++i) {
       if (i != 0) {
@@ -1067,7 +1112,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<MapNode>([](const MapNode* op, IRPrinter* p) {
+.set_dispatch<MapNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const MapNode*>(node.get());
     p->stream << '{';
     for (auto it = op->data.begin(); it != op->data.end(); ++it) {
       if (it != op->data.begin()) {
@@ -1081,7 +1127,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<StrMapNode>([](const StrMapNode* op, IRPrinter* p) {
+.set_dispatch<StrMapNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const StrMapNode*>(node.get());
     p->stream << '{';
     for (auto it = op->data.begin(); it != op->data.end(); ++it) {
       if (it != op->data.begin()) {
@@ -1094,7 +1141,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Reduce>([](const Reduce* op, IRPrinter* p) {
+.set_dispatch<Reduce>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const Reduce*>(node.get());
     p->stream << "reduce(combiner="
               << op->combiner;
     p->stream << ", source=" << op->source;
@@ -1105,7 +1153,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<CommReducerNode>([](const CommReducerNode* op, IRPrinter* p) {
+.set_dispatch<CommReducerNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const CommReducerNode*>(node.get());
     p->stream << "comm_reducer(result=" << op->result
               << ", lhs=" << op->lhs
               << ", rhs=" << op->rhs
@@ -1114,8 +1163,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
-  p->stream << "?";
+.set_dispatch<Any>([](const ObjectRef& node, IRPrinter* p) {
+    p->stream << "?";
 });
 
 TVM_REGISTER_NODE_TYPE(CommReducerNode);
index 626b9f7..cb1ee05 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -26,7 +26,8 @@
 namespace tvm {
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) {
+.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const LoweredFuncNode*>(node.get());
     p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
 });
 
index 481a926..8c45a19 100644 (file)
@@ -27,7 +27,8 @@
 namespace tvm {
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<MemoryInfoNode>([](const MemoryInfoNode *op, IRPrinter *p) {
+.set_dispatch<MemoryInfoNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* op = static_cast<const MemoryInfoNode*>(node.get());
     p->stream << "mem-info("
               << "unit_bits=" << op->unit_bits << ", "
               << "max_num_bits=" << op->max_num_bits << ", "
index 1ac5642..db90e4e 100644 (file)
@@ -69,7 +69,8 @@ Tensor TensorNode::make(Array<Expr> shape,
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<TensorNode>([](const TensorNode *t, IRPrinter *p) {
+.set_dispatch<TensorNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* t = static_cast<const TensorNode*>(node.get());
     p->stream << "Tensor(shape=" << t->shape
               << ", op.name=" << t->op->name << ')';
   });
@@ -100,8 +101,9 @@ TensorIntrin TensorIntrinNode::make(std::string name,
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<TensorIntrinNode>([](const TensorIntrinNode *n, IRPrinter *p) {
-    p->stream << "TensorIntrin(name=" << n->name << ", " << n << ")";
+.set_dispatch<TensorIntrinNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const TensorIntrinNode*>(node.get());
+    p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
   });
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
@@ -124,7 +126,8 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
+.set_dispatch<TensorIntrinCallNode>([](const ObjectRef& node, IRPrinter *p) {
+    auto* n = static_cast<const TensorIntrinCallNode*>(node.get());
     p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
   });
 
index 6958942..5f5d2d4 100644 (file)
@@ -40,7 +40,8 @@ namespace tvm {
 using namespace ir;
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
+.set_dispatch<ComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ComputeOpNode*>(node.get());
     p->stream << "compute(" << op->name << ", " << op << ")";
 });
 
index 9f33415..35fe469 100644 (file)
@@ -31,7 +31,8 @@ namespace tvm {
 using namespace ir;
 // ExternOpNode
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ExternOpNode>([](const ExternOpNode *op, IRPrinter *p) {
+.set_dispatch<ExternOpNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ExternOpNode*>(node.get());
     p->stream << "extern(" << op->name << ", " << op << ")";
   });
 
index e6a46fe..7a99ea1 100644 (file)
@@ -37,7 +37,8 @@ namespace tvm {
 using namespace ir;
 // HybridOpNode
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<HybridOpNode>([](const HybridOpNode *op, IRPrinter *p) {
+.set_dispatch<HybridOpNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const HybridOpNode*>(node.get());
     p->stream << "hybrid(" << op->name << ", " << op << ")";
   });
 
index 97d01ca..4d08fa3 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,7 +28,8 @@ namespace tvm {
 
 // PlaceholderOpNode
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<PlaceholderOpNode>([](const PlaceholderOpNode *op, IRPrinter *p) {
+.set_dispatch<PlaceholderOpNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const PlaceholderOpNode*>(node.get());
     p->stream << "placeholder(" << op->name << ", " << op << ")";
 });
 
index 7b7a47c..b02073b 100644 (file)
@@ -32,7 +32,8 @@ namespace tvm {
 using namespace ir;
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<ScanOpNode>([](const ScanOpNode *op, IRPrinter *p) {
+.set_dispatch<ScanOpNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ScanOpNode*>(node.get());
     p->stream << "scan(" << op->name << ", " << op << ")";
 });
 TVM_REGISTER_NODE_TYPE(ScanOpNode);
index d333461..6533b0e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -36,8 +36,8 @@ namespace tvm {
 using namespace ir;
 // TensorComputeOpNode
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op,
-                                      IRPrinter *p) {
+.set_dispatch<TensorComputeOpNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const TensorComputeOpNode*>(node.get());
     p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
   });
 
index fda1237..c8e46c9 100644 (file)
@@ -118,9 +118,9 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
 
 // Mutate Stmt
 
-#define DISPATCH_TO_MUTATE_STMT(OP)                                 \
-  set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) {  \
-      return m->Mutate_(op, s);                                     \
+#define DISPATCH_TO_MUTATE_STMT(OP)                                     \
+  set_dispatch<OP>([](const ObjectRef& node, const Stmt& s, IRMutator* m) { \
+      return m->Mutate_(static_cast<const OP*>(node.get()), s);         \
     })
 
 Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
@@ -344,9 +344,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
 
 // Mutate Expr
 
-#define DISPATCH_TO_MUTATE_EXPR(OP)                                 \
-  set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) {  \
-      return m->Mutate_(op, e);                                     \
+#define DISPATCH_TO_MUTATE_EXPR(OP)                                         \
+  set_dispatch<OP>([](const ObjectRef& node, const Expr& e, IRMutator* m) { \
+      return m->Mutate_(static_cast<const OP*>(node.get()), e);             \
     })
 
 Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
index fde183e..38c8490 100644 (file)
@@ -237,9 +237,9 @@ DEFINE_OP_NO_VISIT_(UIntImm)
 DEFINE_OP_NO_VISIT_(FloatImm)
 DEFINE_OP_NO_VISIT_(StringImm)
 
-#define DISPATCH_TO_VISIT(OP)                       \
-  set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
-      v->Visit_(op);                                \
+#define DISPATCH_TO_VISIT(OP)                                \
+  set_dispatch<OP>([](const ObjectRef& node, IRVisitor* v) { \
+      v->Visit_(static_cast<const OP*>(node.get()));         \
     })
 
 TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
index 7ec287b..0342aa6 100644 (file)
@@ -24,7 +24,6 @@
 
 #include <dmlc/any.h>
 #include <dmlc/json.h>
-#include <tvm/node/ir_functor.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/runtime/device_api.h>
 
index 962728e..01693e5 100644 (file)
@@ -53,8 +53,9 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
 TVM_REGISTER_API("relay._make.Closure")
 .set_body_typed(ClosureNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const ClosureNode*>(ref.get());
     p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
   });
 
@@ -71,10 +72,11 @@ RecClosure RecClosureNode::make(Closure clos, Var bind) {
 TVM_REGISTER_API("relay._make.RecClosure")
 .set_body_typed(RecClosureNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
-                                p->stream << "RecClosureNode(" << node->clos << ")";
-                              });
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<RecClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const RecClosureNode*>(ref.get());
+    p->stream << "RecClosureNode(" << node->clos << ")";
+  });
 
 TupleValue TupleValueNode::make(tvm::Array<Value> value) {
   NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
@@ -85,8 +87,9 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
 TVM_REGISTER_API("relay._make.TupleValue")
 .set_body_typed(TupleValueNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TupleValueNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const TupleValueNode*>(ref.get());
     p->stream << "TupleValueNode(" << node->fields << ")";
   });
 
@@ -96,8 +99,9 @@ TensorValue TensorValueNode::make(runtime::NDArray data) {
   return TensorValue(n);
 }
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TensorValueNode>([](const TensorValueNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const TensorValueNode*>(ref.get());
     auto to_str = GetPackedFunc("relay._tensor_value_repr");
     std::string data_str = to_str(GetRef<TensorValue>(node));
     p->stream << "TensorValueNode(" << data_str << ")";
@@ -117,11 +121,11 @@ TVM_REGISTER_API("relay._make.RefValue")
 
 TVM_REGISTER_NODE_TYPE(RefValueNode);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<RefValueNode>([](const RefValueNode* node,
-                               tvm::IRPrinter* p) {
-                              p->stream << "RefValueNode(" << node->value << ")";
-                            });
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<RefValueNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const RefValueNode*>(ref.get());
+    p->stream << "RefValueNode(" << node->value << ")";
+  });
 
 ConstructorValue ConstructorValueNode::make(int32_t tag,
                                             tvm::Array<Value> fields,
@@ -138,9 +142,9 @@ TVM_REGISTER_API("relay._make.ConstructorValue")
 
 TVM_REGISTER_NODE_TYPE(ConstructorValueNode);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
-                                       tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const ConstructorValueNode*>(ref.get());
   p->stream << "ConstructorValueNode(" << node->tag << ","
             << node->fields << ")";
 });
index 12cebe5..1f51ecc 100644 (file)
@@ -37,9 +37,8 @@ TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
 TVM_REGISTER_API("relay._make.PatternWildcard")
 .set_body_typed(PatternWildcardNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
-                                      tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, IRPrinter* p) {
   p->stream << "PatternWildcardNode()";
 });
 
@@ -54,9 +53,9 @@ TVM_REGISTER_NODE_TYPE(PatternVarNode);
 TVM_REGISTER_API("relay._make.PatternVar")
 .set_body_typed(PatternVarNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<PatternVarNode>([](const PatternVarNode* node,
-                                 tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<PatternVarNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const PatternVarNode*>(ref.get());
   p->stream << "PatternVarNode(" << node->var << ")";
 });
 
@@ -73,9 +72,9 @@ TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
 TVM_REGISTER_API("relay._make.PatternConstructor")
 .set_body_typed(PatternConstructorNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
-                                         tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const PatternConstructorNode*>(ref.get());
   p->stream << "PatternConstructorNode(" << node->constructor
             << ", " << node->patterns << ")";
 });
@@ -91,9 +90,9 @@ TVM_REGISTER_NODE_TYPE(PatternTupleNode);
 TVM_REGISTER_API("relay._make.PatternTuple")
 .set_body_typed(PatternTupleNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node,
-                                   tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const PatternTupleNode*>(ref.get());
   p->stream << "PatternTupleNode(" << node->patterns << ")";
 });
 
@@ -112,9 +111,9 @@ TVM_REGISTER_NODE_TYPE(ConstructorNode);
 TVM_REGISTER_API("relay._make.Constructor")
 .set_body_typed(ConstructorNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ConstructorNode>([](const ConstructorNode* node,
-                                  tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const ConstructorNode*>(ref.get());
   p->stream << "ConstructorNode(" << node->name_hint << ", "
             << node->inputs << ", " << node->belong_to << ")";
 });
@@ -134,9 +133,9 @@ TVM_REGISTER_NODE_TYPE(TypeDataNode);
 TVM_REGISTER_API("relay._make.TypeData")
 .set_body_typed(TypeDataNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TypeDataNode>([](const TypeDataNode* node,
-                               tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TypeDataNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const TypeDataNode*>(ref.get());
   p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
             << node->constructors << ")";
 });
@@ -153,9 +152,9 @@ TVM_REGISTER_NODE_TYPE(ClauseNode);
 TVM_REGISTER_API("relay._make.Clause")
 .set_body_typed(ClauseNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ClauseNode>([](const ClauseNode* node,
-                             tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ClauseNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const ClauseNode*>(ref.get());
   p->stream << "ClauseNode(" << node->lhs << ", "
             << node->rhs << ")";
   });
@@ -173,9 +172,9 @@ TVM_REGISTER_NODE_TYPE(MatchNode);
 TVM_REGISTER_API("relay._make.Match")
 .set_body_typed(MatchNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<MatchNode>([](const MatchNode* node,
-                            tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<MatchNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const MatchNode*>(ref.get());
   p->stream << "MatchNode(" << node->data << ", "
             << node->clauses << ", " << node->complete << ")";
 });
index 80f0790..3bc916d 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file base.cc
  * \brief The core base types for Relay.
  */
@@ -31,7 +30,7 @@ namespace relay {
 using tvm::IRPrinter;
 using namespace tvm::runtime;
 
-NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
+ObjectPtr<Object> GetSourceNameNode(const std::string& name) {
   // always return pointer as the reference can change as map re-allocate.
   // or use another level of indirection by creating a unique_ptr
   static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map;
@@ -54,8 +53,9 @@ SourceName SourceName::Get(const std::string& name) {
 TVM_REGISTER_API("relay._make.SourceName")
 .set_body_typed(SourceName::Get);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<SourceNameNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
+    auto* node = static_cast<const SourceNameNode*>(ref.get());
     p->stream << "SourceName(" << node->name << ", " << node << ")";
   });
 
@@ -78,8 +78,9 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
 TVM_REGISTER_API("relay._make.Span")
 .set_body_typed(SpanNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<SpanNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
+    auto* node = static_cast<const SpanNode*>(ref.get());
     p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
               << node->col_offset << ")";
   });
index 672cdab..47e735f 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file src/tvm/ir/expr.cc
  * \brief The expression AST nodes of Relay.
  */
@@ -41,8 +40,9 @@ TVM_REGISTER_NODE_TYPE(ConstantNode);
 TVM_REGISTER_API("relay._make.Constant")
 .set_body_typed(ConstantNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ConstantNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const ConstantNode*>(ref.get());
     const PackedFunc* fprint = Registry::Get("relay._constant_repr");
     CHECK(fprint) << "unable to find printing function for constants";
     std::string data = (*fprint)(GetRef<Constant>(node));
@@ -73,8 +73,9 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
 TVM_REGISTER_API("relay._make.Tuple")
 .set_body_typed(TupleNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TupleNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const TupleNode*>(ref.get());
     p->stream << "Tuple(" << node->fields << ")";
   });
 
@@ -97,8 +98,9 @@ TVM_REGISTER_NODE_TYPE(VarNode);
 TVM_REGISTER_API("relay._make.Var")
 .set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<VarNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const VarNode*>(ref.get());
     p->stream << "Var(" << node->name_hint();
     if (node->type_annotation.defined()) {
       p->stream << ", ty=";
@@ -118,8 +120,9 @@ TVM_REGISTER_NODE_TYPE(GlobalVarNode);
 TVM_REGISTER_API("relay._make.GlobalVar")
 .set_body_typed(GlobalVarNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const GlobalVarNode*>(ref.get());
     p->stream << "GlobalVar(" << node->name_hint << ")";
   });
 
@@ -217,12 +220,12 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
 TVM_REGISTER_API("relay._make.Function")
 .set_body_typed(FunctionNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<FunctionNode>([](const FunctionNode* node,
-                                   tvm::IRPrinter* p) {
-      p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
-                << ", " << node->body << ", " << node->type_params << ", "
-                << node->attrs << ")";
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<FunctionNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const FunctionNode*>(ref.get());
+  p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
+            << ", " << node->body << ", " << node->type_params << ", "
+            << node->attrs << ")";
 });
 
 Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
@@ -240,11 +243,12 @@ TVM_REGISTER_NODE_TYPE(CallNode);
 TVM_REGISTER_API("relay._make.Call")
 .set_body_typed(CallNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
-    p->stream << "CallNode(" << node->op << ", " << node->args << ", "
-              << node->attrs << ", " << node->type_args << ")";
-});
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<CallNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const CallNode*>(ref.get());
+  p->stream << "CallNode(" << node->op << ", " << node->args << ", "
+            << node->attrs << ", " << node->type_args << ")";
+  });
 
 Let LetNode::make(Var var, Expr value, Expr body) {
   NodePtr<LetNode> n = make_node<LetNode>();
@@ -259,8 +263,9 @@ TVM_REGISTER_NODE_TYPE(LetNode);
 TVM_REGISTER_API("relay._make.Let")
 .set_body_typed(LetNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<LetNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const LetNode*>(ref.get());
   p->stream << "LetNode(" << node->var << ", " << node->value
             << ", " << node->body << ")";
 });
@@ -278,8 +283,9 @@ TVM_REGISTER_NODE_TYPE(IfNode);
 TVM_REGISTER_API("relay._make.If")
 .set_body_typed(IfNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<IfNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const IfNode*>(ref.get());
   p->stream << "IfNode(" << node->cond << ", " << node->true_branch
             << ", " << node->false_branch << ")";
 });
@@ -296,8 +302,9 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
 TVM_REGISTER_API("relay._make.TupleGetItem")
 .set_body_typed(TupleGetItemNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const TupleGetItemNode*>(ref.get());
   p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
 });
 
@@ -312,8 +319,9 @@ TVM_REGISTER_NODE_TYPE(RefCreateNode);
 TVM_REGISTER_API("relay._make.RefCreate")
 .set_body_typed(RefCreateNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<RefCreateNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const RefCreateNode*>(ref.get());
   p->stream << "RefCreateNode(" << node->value << ")";
 });
 
@@ -328,8 +336,9 @@ TVM_REGISTER_NODE_TYPE(RefReadNode);
 TVM_REGISTER_API("relay._make.RefRead")
 .set_body_typed(RefReadNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<RefReadNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const RefReadNode*>(ref.get());
   p->stream << "RefReadNode(" << node->ref << ")";
 });
 
@@ -345,8 +354,9 @@ TVM_REGISTER_NODE_TYPE(RefWriteNode);
 TVM_REGISTER_API("relay._make.RefWrite")
 .set_body_typed(RefWriteNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<RefWriteNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const RefWriteNode*>(ref.get());
   p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
 });
 
index 8a90f14..960c28f 100644 (file)
@@ -414,9 +414,9 @@ TVM_REGISTER_API("relay._module.Module_ImportFromStd")
   mod->ImportFromStd(path);
 });;
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ModuleNode>(
-  [](const ModuleNode *node, tvm::IRPrinter *p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ModuleNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const ModuleNode*>(ref.get());
     p->stream << "ModuleNode( " << node->functions << ")";
 });
 
index 7bfe41c..c4557ac 100644 (file)
@@ -199,8 +199,9 @@ TVM_REGISTER_NODE_TYPE(OpNode)
     return static_cast<const OpNode*>(n)->name;
   });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<OpNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const OpNode*>(ref.get());
     p->stream << "Op(" << node->name << ")";
   });
 
index 2604896..471b369 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file src/tvm/ir/type.cc
  * \brief The type system AST nodes of Relay.
  */
@@ -58,9 +57,9 @@ TVM_REGISTER_NODE_TYPE(TensorTypeNode);
 TVM_REGISTER_API("relay._make.TensorType")
 .set_body_typed(TensorTypeNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
-                                 tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const TensorTypeNode*>(ref.get());
   p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
 });
 
@@ -78,9 +77,9 @@ TVM_REGISTER_API("relay._make.TypeVar")
     return TypeVarNode::make(name, static_cast<Kind>(kind));
     });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
-                                    tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const TypeVarNode*>(ref.get());
   p->stream << "TypeVarNode(" << node->var->name_hint << ", "
     << node->kind << ")";
 });
@@ -99,9 +98,9 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar")
     return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
     });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
-                                    tvm::IRPrinter *p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
   p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
             << node->kind << ")";
 });
@@ -118,9 +117,9 @@ TVM_REGISTER_NODE_TYPE(TypeCallNode);
 TVM_REGISTER_API("relay._make.TypeCall")
 .set_body_typed(TypeCallNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TypeCallNode>([](const TypeCallNode* node,
-                               tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TypeCallNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const TypeCallNode*>(ref.get());
   p->stream << "TypeCallNode(" << node->func << ", "
             << node->args << ")";
 });
@@ -138,12 +137,11 @@ TVM_REGISTER_API("relay._make.IncompleteType")
     return IncompleteTypeNode::make(static_cast<Kind>(kind));
   });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<IncompleteTypeNode>(
-    [](const IncompleteTypeNode* node,
-       tvm::IRPrinter* p) {
-      p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
-    });
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
+    p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
+  });
 
 FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
                             Type ret_type,
@@ -162,9 +160,9 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode);
 TVM_REGISTER_API("relay._make.FuncType")
 .set_body_typed(FuncTypeNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
-                                   tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const FuncTypeNode*>(ref.get());
   p->stream << "FuncTypeNode(" << node->type_params << ", "
             << node->arg_types << ", " << node->ret_type << ", "
             << node->type_constraints << ")";
@@ -187,8 +185,9 @@ TVM_REGISTER_NODE_TYPE(TypeRelationNode);
 TVM_REGISTER_API("relay._make.TypeRelation")
 .set_body_typed(TypeRelationNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, IRPrinter* p) {
+    auto* node = static_cast<const TypeRelationNode*>(ref.get());
     p->stream << "TypeRelationNode("
               << node->func->name
               << ", " << node->args << ")";
@@ -205,9 +204,9 @@ TVM_REGISTER_NODE_TYPE(TupleTypeNode);
 TVM_REGISTER_API("relay._make.TupleType")
 .set_body_typed(TupleTypeNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
-                                tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TupleTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const TupleTypeNode*>(ref.get());
   p->stream << "TupleTypeNode(" << node->fields << ")";
 });
 
@@ -222,9 +221,9 @@ TVM_REGISTER_API("relay._make.RefType")
 
 TVM_REGISTER_NODE_TYPE(RefTypeNode);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<RefTypeNode>([](const RefTypeNode* node,
-                              tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<RefTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const RefTypeNode*>(ref.get());
   p->stream << "RefTypeNode(" << node->value << ")";
 });
 
index bd9e649..67c1391 100644 (file)
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file type_functor.h
  * \brief A way to defined arbitrary function signature with dispatch on types.
  */
 #ifndef TVM_RELAY_IR_TYPE_FUNCTOR_H_
 #define TVM_RELAY_IR_TYPE_FUNCTOR_H_
 
-#include <tvm/node/ir_functor.h>
+#include <tvm/node/functor.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/adt.h>
 #include <string>
@@ -54,7 +53,7 @@ template <typename R, typename... Args>
 class TypeFunctor<R(const Type& n, Args...)> {
  private:
   using TSelf = TypeFunctor<R(const Type& n, Args...)>;
-  using FType = tvm::IRFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
   /*! \brief the result type of this functor */
index dbecc6a..b025d37 100644 (file)
@@ -449,9 +449,9 @@ TVM_REGISTER_API("relay._transform.Info")
   *ret = pass->Info();
 });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<PassInfoNode>([](const PassInfoNode* node,
-                                tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::IRPrinter* p) {
+  auto* node = static_cast<const PassInfoNode*>(ref.get());
   p->stream << "The meta data of the pass: ";
   p->stream << "pass name: " << node->name;
   p->stream << "opt_level: " << node->opt_level;
@@ -475,9 +475,9 @@ TVM_REGISTER_API("relay._transform.RunPass")
   *ret = pass(mod);
 });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<ModulePassNode>([](const ModulePassNode* node,
-                                 tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ModulePassNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const ModulePassNode*>(ref.get());
   const PassInfo info = node->Info();
   p->stream << "Run Module pass: " << info->name
             << " at the optimization level " << info->opt_level;
@@ -488,9 +488,9 @@ TVM_REGISTER_NODE_TYPE(FunctionPassNode);
 TVM_REGISTER_API("relay._transform.MakeFunctionPass")
 .set_body_typed(FunctionPassNode::make);
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
-                                   tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const FunctionPassNode*>(ref.get());
   const PassInfo info = node->Info();
   p->stream << "Run Function pass: " << info->name
             << " at the optimization level " << info->opt_level;
@@ -508,9 +508,9 @@ TVM_REGISTER_API("relay._transform.Sequential")
   *ret = Sequential(passes, pass_info);
 });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<SequentialNode>([](const SequentialNode* node,
-                                 tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<SequentialNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const SequentialNode*>(ref.get());
   const PassInfo info = node->Info();
   p->stream << "Run Sequential pass: " << info->name
             << " at the optimization level " << info->opt_level << ". ";
@@ -538,9 +538,9 @@ TVM_REGISTER_API("relay._transform.PassContext")
   *ret = pctx;
 });
 
-TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
-.set_dispatch<PassContextNode>([](const PassContextNode* node,
-                               tvm::IRPrinter* p) {
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<PassContextNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* node = static_cast<const PassContextNode*>(ref.get());
   p->stream << "Pass context information: " << "\n";
   p->stream << "\topt_level: " << node->opt_level << "\n";
   p->stream << "\tfallback device: "
index d564d2e..2793577 100644 (file)
@@ -117,7 +117,8 @@ QConfig& QConfig::Current() {
 TVM_REGISTER_NODE_TYPE(QConfigNode);
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<QConfigNode>([](const QConfigNode *op, IRPrinter *p) {
+.set_dispatch<QConfigNode>([](const ObjectRef& ref, IRPrinter* p) {
+  auto* op = static_cast<const QConfigNode*>(ref.get());
   p->stream << "qconfig(";
   p->stream << "nbit_input=" << op->nbit_input << ", ";
   p->stream << "nbit_weight=" << op->nbit_weight << ", ";
index 407729d..54503fc 100644 (file)
@@ -800,17 +800,20 @@ TVM_REGISTER_NODE_TYPE(ScheduleNode);
 
 // Printer
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
+.set_dispatch<StageNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const StageNode*>(node.get());
     if (op->op.defined()) {
       p->stream << "stage(" << op->origin_op->name << ", " << op << ")";
     } else {
       p->stream << "group-stage(" << op << ")";
     }
 })
-.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
+.set_dispatch<IterVarAttrNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const IterVarAttrNode*>(node.get());
     p->stream << IterVarType2String(op->iter_type);
 })
-.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
+.set_dispatch<SplitNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const SplitNode*>(node.get());
     p->stream << "split(parent=";
     p->Print(op->parent);
     p->stream << ", outer=";
@@ -819,7 +822,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     p->Print(op->inner);
     p->stream << ')';
 })
-.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
+.set_dispatch<FuseNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const FuseNode*>(node.get());
     p->stream << "split(";
     p->stream << "outer=";
     p->Print(op->outer);
@@ -829,7 +833,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     p->Print(op->fused);
     p->stream << ')';
 })
-.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
+.set_dispatch<RebaseNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const RebaseNode*>(node.get());
     p->stream << "rebase(";
     p->stream << "parent=";
     p->Print(op->parent);
@@ -837,12 +842,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
     p->Print(op->rebased);
     p->stream << ')';
 })
-.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
+.set_dispatch<SingletonNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const SingletonNode*>(node.get());
     p->stream << "singleton(";
     p->Print(op->iter);
     p->stream << ')';
 })
-.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
+.set_dispatch<ScheduleNode>([](const ObjectRef& node, IRPrinter* p) {
+    auto* op = static_cast<const ScheduleNode*>(node.get());
     p->stream << "schedule(" << op << ")";
   });
 }  // namespace tvm
index 038ad6f..9ccb9c9 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -53,7 +53,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
 TEST(Attrs, Basic) {
   using namespace tvm;
   using namespace tvm::test;
-  std::shared_ptr<TestAttrs> n = std::make_shared<TestAttrs>();
+  ObjectPtr<TestAttrs> n = make_object<TestAttrs>();
   try {
     n->InitBySeq("axis", 10);
     LOG(FATAL) << "bad";
index fef43f9..5636958 100644 (file)
@@ -21,7 +21,7 @@
 #include <gtest/gtest.h>
 #include <tvm/ir.h>
 #include <tvm/expr_operator.h>
-#include <tvm/node/ir_functor.h>
+#include <tvm/node/functor.h>
 #include <tvm/ir_functor_ext.h>
 
 TEST(IRF, Basic) {
@@ -30,12 +30,12 @@ TEST(IRF, Basic) {
   Var x("x");
   auto z = x + 1;
 
-  IRFunctor<int(const ObjectRef& n, int b)> f;
+  NodeFunctor<int(const ObjectRef& n, int b)> f;
   LOG(INFO) << "x";
-  f.set_dispatch<Variable>([](const Variable* n, int b) {
+  f.set_dispatch<Variable>([](const ObjectRef& n, int b) {
       return b;
     });
-  f.set_dispatch<Add>([](const Add* n, int b) {
+  f.set_dispatch<Add>([](const ObjectRef& n, int b) {
       return b + 2;
     });
   CHECK_EQ(f(x, 2),  2);
index 30972e7..1b3296d 100644 (file)
@@ -45,7 +45,7 @@ IRMutator::FMutateExpr &IRVar2Const::vtable_expr() {  // NOLINT(*)
 }
 
 TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
-.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
+.set_dispatch<Variable>([](const ObjectRef& ref, const Expr &e, IRMutator* m) {
     IRVar2Const* vm = static_cast<IRVar2Const*>(m);
     if (e.same_as(vm->var)) {
       return Expr(IntImm::make(Int(32), vm->int_val));