// Forward declare Analyzer
class Analyzer;
-/*!
- * \brief reference class to ConstIntBoundNode
- * \sa ConstIntBoundNode
- */
-class ConstIntBound;
+
/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
v->Visit("max_value", &max_value);
}
- TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);
-
/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
};
-TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);
+/*!
+ * \brief reference class to ConstIntBoundNode
+ * \sa ConstIntBoundNode
+ */
+class ConstIntBound : public NodeRef {
+ public:
+ /*!
+ * \brief constructor by fields.
+ * \param min_value The mininum value.
+ * \param max_value The maximum value.
+ */
+ TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
+
+ static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
+ static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
+ TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
+};
/*!
* \brief Analyzer to get constant integer bound over expression.
};
/*!
- * \brief reference of ModularSetNode
- * \sa ModularSetNode
- */
-class ModularSet;
-/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
v->Visit("base", &base);
}
- TVM_DLL static ModularSet make(int64_t coeff, int64_t base);
-
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
};
-TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);
+/*!
+ * \brief reference of ModularSetNode
+ * \sa ModularSetNode
+ */
+class ModularSet : public NodeRef {
+ public:
+ TVM_DLL ModularSet(int64_t coeff, int64_t base);
+
+ TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
+};
/*!
* \brief Analyzer to get modular information over expression.
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
-/*! \brief Macro to make it easy to define node ref type given node */
-#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
- class TypeName : public ::tvm::NodeRef { \
- public: \
- TypeName() {} \
- explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \
- const NodeName* operator->() const { \
- return static_cast<const NodeName*>(node_.get()); \
- } \
- using ContainerType = NodeName; \
- }; \
+/*!
+ * \brief Macro to define common node ref methods.
+ * \param TypeName The name of the NodeRef.
+ * \param BaseTypeName The Base type.
+ * \param NodeName The node container type.
+ */
+#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
+ TypeName() {} \
+ explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
+ const NodeName* operator->() const { \
+ return static_cast<const NodeName*>(node_.get()); \
+ } \
+ operator bool() const { return this->defined(); } \
+ using ContainerType = NodeName;
/*!
- * \brief Macro to make it easy to define node ref type that
- * has a CopyOnWrite member function.
+ * \brief Macro to define CopyOnWrite function in a NodeRef.
+ * \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
*
* \endcode
*/
-#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
- class TypeName : public BaseType { \
- public: \
- TypeName() {} \
- explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
- const NodeName* operator->() const { \
- return static_cast<const NodeName*>(node_.get()); \
- } \
- inline NodeName* CopyOnWrite() { \
+#define TVM_DEFINE_NODE_REF_COW(NodeName) \
+ NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
- } \
- using ContainerType = NodeName; \
- };
+ }
+/*! \brief Macro to make it easy to define node ref type given node */
+#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
+ class TypeName : public ::tvm::NodeRef { \
+ public: \
+ TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
+ }; \
+
+/*!
+ * \brief Macro to make it easy to define node ref type that
+ * has a CopyOnWrite member function.
+ */
+#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
+ class TypeName : public BaseType { \
+ public: \
+ TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
+ TVM_DEFINE_NODE_REF_COW(NodeName); \
+ };
/*!
* \brief save the node as well as all the node it depends on as json.
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
TVM_REGISTER_API("arith.DomainTouched")
.set_body_typed(DomainTouched);
-
TVM_REGISTER_API("_IntervalSetGetMin")
.set_body_method(&IntSet::min);
TVM_REGISTER_API("_IntSetIsEverything")
.set_body_method(&IntSet::is_everything);
+ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
+ return ConstIntBound(min_value, max_value);
+}
+
TVM_REGISTER_API("arith._make_ConstIntBound")
-.set_body_typed(ConstIntBoundNode::make);
+.set_body_typed(MakeConstIntBound);
+
+ModularSet MakeModularSet(int64_t coeff, int64_t base) {
+ return ModularSet(coeff, base);
+}
TVM_REGISTER_API("arith._make_ModularSet")
-.set_body_typed(ModularSetNode::make);
+.set_body_typed(MakeModularSet);
TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
-ConstIntBound ConstIntBoundNode::make(
+ConstIntBound::ConstIntBound(
int64_t min_value, int64_t max_value) {
auto node = make_node<ConstIntBoundNode>();
node->min_value = min_value;
node->max_value = max_value;
- return ConstIntBound(node);
+ node_ = std::move(node);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
std::vector<BoundInfo> additional_info_;
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
- static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
- static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
+ static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
+ static const constexpr int64_t kPosInf = ConstIntBound::kPosInf;
static_assert(-kNegInf == kPosInf, "invariant of inf");
// internal helper functions
/*!
ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
Entry ret = impl_->VisitExpr(expr);
- return ConstIntBoundNode::make(ret.min_value, ret.max_value);
+ return ConstIntBound(ret.min_value, ret.max_value);
}
void ConstIntBoundAnalyzer::Update(const Var& var,
TVM_REGISTER_NODE_TYPE(ModularSetNode);
-ModularSet ModularSetNode::make(int64_t coeff, int64_t base) {
+ModularSet::ModularSet(int64_t coeff, int64_t base) {
auto node = make_node<ModularSetNode>();
node->coeff = coeff;
node->base = base;
- return ModularSet(node);
+ // finish construction.
+ node_ = std::move(node);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* \return Bound that represent everything dtype can represent.
*/
static Entry Nothing() {
- return Entry(0, 1);
+ return Entry(0, 1);
}
};
ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
Entry ret = impl_->VisitExpr(expr);
- return ModularSetNode::make(ret.coeff, ret.base);
+ return ModularSet(ret.coeff, ret.base);
}
void ModularSetAnalyzer::Update(const Var& var,