[NODE] Macro to define NodeRef methods, constructor style example (#3224)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 23 May 2019 17:17:03 +0000 (10:17 -0700)
committerGitHub <noreply@github.com>
Thu, 23 May 2019 17:17:03 +0000 (10:17 -0700)
include/tvm/arithmetic.h
include/tvm/base.h
src/api/api_arith.cc
src/arithmetic/const_int_bound.cc
src/arithmetic/modular_set.cc

index 9a8d9d3..6eec767 100644 (file)
@@ -48,11 +48,7 @@ namespace arith {
 
 // Forward declare Analyzer
 class Analyzer;
-/*!
- * \brief reference class to ConstIntBoundNode
- * \sa ConstIntBoundNode
- */
-class ConstIntBound;
+
 /*!
  * \brief Constant integer up and lower bound(inclusive).
  *  Useful for value bound analysis.
@@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
     v->Visit("max_value", &max_value);
   }
 
-  TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);
-
   /*! \brief Number to represent +inf */
   static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
   /*!
@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
   TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
 };
 
-TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);
+/*!
+ * \brief reference class to ConstIntBoundNode
+ * \sa ConstIntBoundNode
+ */
+class ConstIntBound : public NodeRef {
+ public:
+  /*!
+   * \brief constructor by fields.
+   * \param min_value The mininum value.
+   * \param max_value The maximum value.
+   */
+  TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
+
+  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
+  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
+  TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
+};
 
 /*!
  * \brief Analyzer to get constant integer bound over expression.
@@ -134,11 +144,6 @@ class ConstIntBoundAnalyzer {
 };
 
 /*!
- * \brief reference of ModularSetNode
- * \sa ModularSetNode
- */
-class ModularSet;
-/*!
  * \brief Range of a linear integer function.
  *  Use to do specify the possible index values.
  *
@@ -162,13 +167,20 @@ class ModularSetNode : public Node {
     v->Visit("base", &base);
   }
 
-  TVM_DLL static ModularSet make(int64_t coeff, int64_t base);
-
   static constexpr const char* _type_key = "arith.ModularSet";
   TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
 };
 
-TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);
+/*!
+ * \brief reference of ModularSetNode
+ * \sa ModularSetNode
+ */
+class ModularSet : public NodeRef {
+ public:
+  TVM_DLL ModularSet(int64_t coeff, int64_t base);
+
+  TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
+};
 
 /*!
  * \brief Analyzer to get modular information over expression.
index ae2d91f..049a427 100644 (file)
@@ -39,21 +39,24 @@ using ::tvm::Node;
 using ::tvm::NodeRef;
 using ::tvm::AttrVisitor;
 
-/*! \brief Macro to make it easy to define node ref type given node */
-#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                  \
-  class TypeName : public ::tvm::NodeRef {                       \
-   public:                                                       \
-    TypeName() {}                                                 \
-    explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {}     \
-    const NodeName* operator->() const {                          \
-      return static_cast<const NodeName*>(node_.get());           \
-    }                                                             \
-    using ContainerType = NodeName;                               \
-  };                                                              \
+/*!
+ * \brief Macro to define common node ref methods.
+ * \param TypeName The name of the NodeRef.
+ * \param BaseTypeName The Base type.
+ * \param NodeName The node container type.
+ */
+#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName)   \
+  TypeName() {}                                                         \
+  explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
+  const NodeName* operator->() const {                                  \
+    return static_cast<const NodeName*>(node_.get());                   \
+  }                                                                     \
+  operator bool() const { return this->defined(); }                     \
+  using ContainerType = NodeName;
 
 /*!
- * \brief Macro to make it easy to define node ref type that
- *  has a CopyOnWrite member function.
+ * \brief Macro to define CopyOnWrite function in a NodeRef.
+ * \param NodeName The Type of the Node.
  *
  *  CopyOnWrite will generate a unique copy of the internal node.
  *  The node will be copied if it is referenced by multiple places.
@@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
  *
  * \endcode
  */
-#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName)           \
-  class TypeName : public BaseType {                                    \
-   public:                                                              \
-    TypeName() {}                                                       \
-    explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {}   \
-    const NodeName* operator->() const {                                \
-      return static_cast<const NodeName*>(node_.get());                 \
-    }                                                                   \
-    inline NodeName* CopyOnWrite() {                                    \
+#define TVM_DEFINE_NODE_REF_COW(NodeName)                               \
+  NodeName* CopyOnWrite() {                                             \
       CHECK(node_ != nullptr);                                          \
       if (!node_.unique())  {                                           \
         NodePtr<NodeName> n = make_node<NodeName>(*(operator->()));     \
         NodePtr<Node>(std::move(n)).swap(node_);                        \
       }                                                                 \
       return static_cast<NodeName*>(node_.get());                       \
-    }                                                                   \
-    using ContainerType = NodeName;                                     \
-  };
+    }
 
+/*! \brief Macro to make it easy to define node ref type given node */
+#define TVM_DEFINE_NODE_REF(TypeName, NodeName)                      \
+  class TypeName : public ::tvm::NodeRef {                           \
+   public:                                                           \
+    TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
+  };                                                                 \
+
+/*!
+ * \brief Macro to make it easy to define node ref type that
+ *  has a CopyOnWrite member function.
+ */
+#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName)           \
+  class TypeName : public BaseType {                                    \
+   public:                                                              \
+    TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName);          \
+    TVM_DEFINE_NODE_REF_COW(NodeName);                                  \
+  };
 
 /*!
  * \brief save the node as well as all the node it depends on as json.
index fce73aa..55a7064 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
 TVM_REGISTER_API("arith.DomainTouched")
 .set_body_typed(DomainTouched);
 
-
 TVM_REGISTER_API("_IntervalSetGetMin")
 .set_body_method(&IntSet::min);
 
@@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
 TVM_REGISTER_API("_IntSetIsEverything")
 .set_body_method(&IntSet::is_everything);
 
+ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
+  return ConstIntBound(min_value, max_value);
+}
+
 TVM_REGISTER_API("arith._make_ConstIntBound")
-.set_body_typed(ConstIntBoundNode::make);
+.set_body_typed(MakeConstIntBound);
+
+ModularSet MakeModularSet(int64_t coeff, int64_t base) {
+  return ModularSet(coeff, base);
+}
 
 TVM_REGISTER_API("arith._make_ModularSet")
-.set_body_typed(ModularSetNode::make);
+.set_body_typed(MakeModularSet);
 
 TVM_REGISTER_API("arith._CreateAnalyzer")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
index bfd06c8..72b8508 100644 (file)
@@ -34,12 +34,12 @@ using namespace ir;
 
 TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
 
-ConstIntBound ConstIntBoundNode::make(
+ConstIntBound::ConstIntBound(
     int64_t min_value, int64_t max_value) {
   auto node = make_node<ConstIntBoundNode>();
   node->min_value = min_value;
   node->max_value = max_value;
-  return ConstIntBound(node);
+  node_ = std::move(node);
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
@@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
   std::vector<BoundInfo> additional_info_;
   // constants: the limit value means umlimited
   // NOTE: kNegInf/kPosInf are used to represent infinity.
-  static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
-  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
+  static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
+  static const constexpr int64_t kPosInf = ConstIntBound::kPosInf;
   static_assert(-kNegInf == kPosInf, "invariant of inf");
   // internal helper functions
   /*!
@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :
 
 ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
   Entry ret = impl_->VisitExpr(expr);
-  return ConstIntBoundNode::make(ret.min_value, ret.max_value);
+  return ConstIntBound(ret.min_value, ret.max_value);
 }
 
 void ConstIntBoundAnalyzer::Update(const Var& var,
index 7701e04..57e8294 100644 (file)
@@ -35,11 +35,12 @@ using namespace ir;
 
 TVM_REGISTER_NODE_TYPE(ModularSetNode);
 
-ModularSet ModularSetNode::make(int64_t coeff, int64_t base) {
+ModularSet::ModularSet(int64_t coeff, int64_t base) {
   auto node = make_node<ModularSetNode>();
   node->coeff = coeff;
   node->base = base;
-  return ModularSet(node);
+  // finish construction.
+  node_ = std::move(node);
 }
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
@@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl :
    * \return Bound that represent everything dtype can represent.
    */
   static Entry Nothing() {
-      return Entry(0, 1);
+    return Entry(0, 1);
   }
 };
 
 ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
   Entry ret = impl_->VisitExpr(expr);
-  return ModularSetNode::make(ret.coeff, ret.base);
+  return ModularSet(ret.coeff, ret.base);
 }
 
 void ModularSetAnalyzer::Update(const Var& var,