From b0b51f25301046860720f54a6dd4239868143439 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 15 Jan 2020 09:07:22 -0800 Subject: [PATCH] [REFACTOR][IR] attrs.h -> ir (#4709) This PR moves attrs.h into the ir folder as it can serve as a common infra for building ir dats structures. We also moved common container(FloatImm) into ir/expr.h --- include/tvm/expr_operator.h | 4 +-- include/tvm/ir.h | 18 +----------- include/tvm/{ => ir}/attrs.h | 53 ++++++++++++++--------------------- include/tvm/ir/expr.h | 50 +++++++++++++++++++++++++++++++++ include/tvm/ir/op.h | 14 +++++---- include/tvm/ir/type_relation.h | 2 +- include/tvm/relay/adt.h | 2 +- include/tvm/relay/attrs/algorithm.h | 2 +- include/tvm/relay/attrs/annotation.h | 2 +- include/tvm/relay/attrs/bitserial.h | 2 +- include/tvm/relay/attrs/debug.h | 2 +- include/tvm/relay/attrs/device_copy.h | 2 +- include/tvm/relay/attrs/image.h | 2 +- include/tvm/relay/attrs/memory.h | 2 +- include/tvm/relay/attrs/nn.h | 2 +- include/tvm/relay/attrs/reduce.h | 2 +- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/relay/attrs/vision.h | 2 +- include/tvm/relay/expr.h | 2 +- include/tvm/relay/qnn/attrs.h | 2 +- include/tvm/relay/type.h | 3 +- src/api/api_ir.cc | 1 - src/api/api_pass.cc | 2 +- src/api/api_test.cc | 2 +- src/arithmetic/const_fold.h | 14 ++++----- src/autotvm/touch_extractor.cc | 24 ++++++++-------- src/codegen/llvm/codegen_x86_64.cc | 2 +- src/{lang => ir}/attr_functor.h | 6 ++-- src/{lang => ir}/attrs.cc | 2 +- src/ir/expr.cc | 16 +++++++++++ src/ir/op.cc | 4 +++ src/lang/expr.cc | 2 +- src/lang/expr_operator.cc | 26 ++++++++--------- src/lang/ir.cc | 2 +- src/node/reflection.cc | 6 +++- src/node/serialization.cc | 2 +- src/pass/storage_access.h | 2 +- src/relay/ir/alpha_equal.cc | 2 +- src/relay/ir/hash.cc | 4 +-- src/relay/ir/pretty_printer.cc | 2 +- tests/cpp/attrs_test.cc | 4 ++- 41 files changed, 174 insertions(+), 123 deletions(-) rename include/tvm/{ => ir}/attrs.h (96%) rename src/{lang => ir}/attr_functor.h (99%) rename src/{lang => ir}/attrs.cc (99%) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index ff3b340..7d6b752 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { return LargeUIntImm(t, static_cast(low), static_cast(high)); } } - if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); + if (t.is_float()) return FloatImm(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? if (static_cast(t.code()) >= static_cast(kCustomBegin)) { - return ir::FloatImmNode::make(t, static_cast(value)); + return FloatImm(t, static_cast(value)); } LOG(FATAL) << "cannot make const for type " << t; return PrimExpr(); diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 9c14a31..553db4e 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -37,25 +37,9 @@ namespace tvm { namespace ir { using IntImmNode = tvm::IntImmNode; +using FloatImmNode = tvm::FloatImmNode; using VarNode = tvm::VarNode; -/*! \brief Floating point constants. */ -class FloatImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - double value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static PrimExpr make(DataType t, double value); - - static constexpr const char* _type_key = "FloatImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); -}; - /*! \brief String constants, only used in asserts. */ class StringImmNode : public PrimExprNode { public: diff --git a/include/tvm/attrs.h b/include/tvm/ir/attrs.h similarity index 96% rename from include/tvm/attrs.h rename to include/tvm/ir/attrs.h index 9d9f98e..5916a78 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/ir/attrs.h @@ -16,10 +16,9 @@ * specific language governing permissions and limitations * under the License. */ - /*! - * \file tvm/attrs.h - * \brief TVM attribute module + * \file tvm/ir/attrs.h + * \brief Helpers for attribute objects. * * This module enables declaration of named attributes * which support default value setup and bound checking. @@ -42,20 +41,19 @@ * * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD */ -#ifndef TVM_ATTRS_H_ -#define TVM_ATTRS_H_ +#ifndef TVM_IR_ATTRS_H_ +#define TVM_IR_ATTRS_H_ #include +#include +#include + #include #include #include #include #include #include -#include "ir.h" -#include "base.h" -#include "expr.h" -#include "packed_func_ext.h" namespace tvm { /*! @@ -63,10 +61,10 @@ namespace tvm { * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ -#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ - static constexpr const char* _type_key = TypeKey; \ +#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ + static constexpr const char* _type_key = TypeKey; \ TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ - template \ + template \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) @@ -481,45 +479,36 @@ template inline void SetValue(T* ptr, const TVMArgValue& val) { *ptr = val.operator T(); } + template inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast(val.value().v_int64); } else { - PrimExpr expr = val; - CHECK(expr.defined()); - if (const ir::IntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); - } else { - LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); - } + IntImm expr = val; + *ptr = static_cast(expr->value); } } + template<> inline void SetValue(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kStr) { *ptr = val.operator std::string(); } else { - PrimExpr expr = val; - const ir::StringImmNode* op = expr.as(); - CHECK(op != nullptr); - *ptr = op->value; + LOG(FATAL) << "Expect str"; } } -template<> -inline void SetValue(DataType* ptr, const TVMArgValue& val) { - *ptr = val.operator DataType(); -} + template<> inline void SetValue(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); } else { - PrimExpr expr = val; + ObjectRef expr = val; CHECK(expr.defined()); - if (const ir::IntImmNode* op = expr.as()) { + if (const IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::IntImmNode* op = expr.as()) { + } else if (const FloatImmNode* op = expr.as()) { *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); @@ -611,7 +600,7 @@ struct TypeName { template<> struct TypeName { - static constexpr const char* value = "Type"; + static constexpr const char* value = "DataType"; }; template<> @@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) } } // namespace tvm -#endif // TVM_ATTRS_H_ +#endif // TVM_IR_ATTRS_H_ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 12e505e..ddb5f80 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -182,6 +182,56 @@ class IntImm : public PrimExpr { }; /*! + * \brief Constant floating point literals in the program. + * \sa FloatImm + */ +class FloatImmNode : public PrimExprNode { + public: + /*! \brief The constant value content. */ + double value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "FloatImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to FloatImmNode. + * + * \sa FloatImmNode + */ +class FloatImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + FloatImm() {} + /*! + * \brief constructor from node. + */ + explicit FloatImm(ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL FloatImm(DataType dtype, double value); + /*! + * \brief Get pointer to the container. + * \return The pointer. + */ + const FloatImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = FloatImmNode; +}; + +/*! * \brief Base node of all non-primitive expressions. * * RelayExpr supports tensor types, functions and ADT as diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index f5d0639..f3a8603 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -26,7 +26,7 @@ #define TVM_IR_OP_H_ #include -#include +#include #include #include #include @@ -296,7 +296,8 @@ class OpRegistry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value, + TVM_DLL void UpdateAttr(const std::string& key, + runtime::TVMRetValue value, int plevel); }; @@ -316,7 +317,7 @@ class GenericOpMap { * \param op The key to the map * \return the const reference to the content value. */ - inline const TVMRetValue& operator[](const Op& op) const; + inline const runtime::TVMRetValue& operator[](const Op& op) const; /*! * \brief get the corresponding value element at op with default value. * \param op The key to the map @@ -342,7 +343,7 @@ class GenericOpMap { // the attribute field. std::string attr_name_; // internal data - std::vector > data_; + std::vector > data_; // The value GenericOpMap() = default; }; @@ -532,7 +533,7 @@ template inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) const std::string& attr_name, const ValueType& value, int plevel) { CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; - TVMRetValue rv; + runtime::TVMRetValue rv; rv = value; UpdateAttr(attr_name, rv, plevel); return *this; @@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const { } } -inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const { +inline const runtime::TVMRetValue& +GenericOpMap::operator[](const Op& op) const { CHECK(op.defined()); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second != 0) diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 333c538..962eea3 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include namespace tvm { diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 1807696..6f72072 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_ -#include +#include #include #include #include diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index ce14a6a..2d1b902 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_ALGORITHM_H_ #define TVM_RELAY_ATTRS_ALGORITHM_H_ -#include +#include #include #include diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 4481d2a..cc21e34 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_ANNOTATION_H_ #define TVM_RELAY_ATTRS_ANNOTATION_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h index 2a7376b..962afc2 100644 --- a/include/tvm/relay/attrs/bitserial.h +++ b/include/tvm/relay/attrs/bitserial.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAY_ATTRS_BITSERIAL_H_ #define TVM_RELAY_ATTRS_BITSERIAL_H_ -#include +#include #include #include diff --git a/include/tvm/relay/attrs/debug.h b/include/tvm/relay/attrs/debug.h index 82a2046..ed9ed4e 100644 --- a/include/tvm/relay/attrs/debug.h +++ b/include/tvm/relay/attrs/debug.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_DEBUG_H_ #define TVM_RELAY_ATTRS_DEBUG_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 2469c4b..3935629 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 22d657d..4bf40e3 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_IMAGE_H_ #define TVM_RELAY_ATTRS_IMAGE_H_ -#include +#include #include #include diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index c74b648..00204b3 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_MEMORY_H_ #define TVM_RELAY_ATTRS_MEMORY_H_ -#include +#include #include #include diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 549eb67..5620feb 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_NN_H_ #define TVM_RELAY_ATTRS_NN_H_ -#include +#include #include #include diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index d5fe9b8..443efb5 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_REDUCE_H_ #define TVM_RELAY_ATTRS_REDUCE_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 26637d5..11c7886 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_TRANSFORM_H_ #define TVM_RELAY_ATTRS_TRANSFORM_H_ -#include +#include #include #include #include diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index b98bbfc..c4a30ce 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_ATTRS_VISION_H_ #define TVM_RELAY_ATTRS_VISION_H_ -#include +#include #include #include diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 47ae552..1062c20 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_EXPR_H_ #define TVM_RELAY_EXPR_H_ -#include +#include #include #include #include diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 7ef0b10..3c1c4a3 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_QNN_ATTRS_H_ #define TVM_RELAY_QNN_ATTRS_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 17b8b57..d4243d8 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -34,7 +35,7 @@ #include #include "base.h" -#include "../attrs.h" + namespace tvm { namespace relay { diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 30ca515..261b94e 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); REGISTER_MAKE(Add); diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index ff30f5e..639855c 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include diff --git a/src/api/api_test.cc b/src/api/api_test.cc index 7ded78b..0bc83ea 100644 --- a/src/api/api_test.cc +++ b/src/api/api_test.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 3b803ec..d82ac89 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -102,7 +102,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); + if (fa && fb) return FloatImm(rtype, fa->value + fb->value); if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); @@ -115,7 +115,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); + if (fa && fb) return FloatImm(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; }); return PrimExpr(); @@ -134,7 +134,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pb->value == 1) return a; if (pb->value == 0) return b; } - if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value); + if (fa && fb) return FloatImm(rtype, fa->value * fb->value); if (fa) { if (fa->value == 1) return b; if (fa->value == 0) return a; @@ -165,7 +165,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { CHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - return FloatImmNode::make(rtype, fa->value / fb->value); + return FloatImm(rtype, fa->value / fb->value); } if (fa && fa->value == 0) return a; if (fb) { @@ -210,7 +210,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { CHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - return FloatImmNode::make(rtype, std::floor(fa->value / fb->value)); + return FloatImm(rtype, std::floor(fa->value / fb->value)); } if (fa && fa->value == 0) return a; if (fb) { @@ -244,7 +244,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); + if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; return PrimExpr(); @@ -255,7 +255,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); + if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; return PrimExpr(); diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index 55ed36c..b5bf2ed 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > feature_row.push_back(Array{std::string("_itervar_"), var}); Array attr{std::string("_attr_"), - FloatImmNode::make(DataType::Float(32), trans(fea.length)), + FloatImm(DataType::Float(32), trans(fea.length)), IntImm(DataType::Int(32), fea.nest_level), - FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), - FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)), + FloatImm(DataType::Float(32), trans(fea.topdown_product)), + FloatImm(DataType::Float(32), trans(fea.bottomup_product)), }; // one hot annotation for (int i = 0; i < kNum; i++) { @@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > // arithmetic feature_row.push_back(Array{std::string("_arith_"), - FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)), - FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)), - FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)), + FloatImm(DataType::Float(32), trans(fea.add_ct)), + FloatImm(DataType::Float(32), trans(fea.mul_ct)), + FloatImm(DataType::Float(32), trans(fea.div_ct)), }); // touch map @@ -283,12 +283,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > TouchPattern &v = fea.touch_feature[k]; feature_row.push_back( Array{k, - FloatImmNode::make(DataType::Float(32), trans(v.stride)), - FloatImmNode::make(DataType::Float(32), trans(v.mod)), - FloatImmNode::make(DataType::Float(32), trans(v.count)), - FloatImmNode::make(DataType::Float(32), trans(v.reuse)), - FloatImmNode::make(DataType::Float(32), trans(v.thread_count)), - FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)), + FloatImm(DataType::Float(32), trans(v.stride)), + FloatImm(DataType::Float(32), trans(v.mod)), + FloatImm(DataType::Float(32), trans(v.count)), + FloatImm(DataType::Float(32), trans(v.reuse)), + FloatImm(DataType::Float(32), trans(v.thread_count)), + FloatImm(DataType::Float(32), trans(v.thread_reuse)), }); } diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 11bda70..2e41931 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -95,7 +95,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ir::CallNode::PureIntrinsic)), MakeValue( ir::BroadcastNode::make( - ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())), + FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); diff --git a/src/lang/attr_functor.h b/src/ir/attr_functor.h similarity index 99% rename from src/lang/attr_functor.h rename to src/ir/attr_functor.h index 4fffc47..c140123 100644 --- a/src/lang/attr_functor.h +++ b/src/ir/attr_functor.h @@ -27,8 +27,8 @@ * - array of attributes * - map of attributes */ -#ifndef TVM_LANG_ATTR_FUNCTOR_H_ -#define TVM_LANG_ATTR_FUNCTOR_H_ +#ifndef TVM_IR_ATTR_FUNCTOR_H_ +#define TVM_IR_ATTR_FUNCTOR_H_ #include #include @@ -230,4 +230,4 @@ class AttrsHashHandler : } }; } // namespace tvm -#endif // TVM_LANG_ATTR_FUNCTOR_H_ +#endif // TVM_IR_ATTR_FUNCTOR_H_ diff --git a/src/lang/attrs.cc b/src/ir/attrs.cc similarity index 99% rename from src/lang/attrs.cc rename to src/ir/attrs.cc index a590f10..a487995 100644 --- a/src/lang/attrs.cc +++ b/src/ir/attrs.cc @@ -20,7 +20,7 @@ /*! * \file attrs.cc */ -#include +#include #include #include diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 6d89967..0cf91c2 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -45,6 +45,22 @@ TVM_REGISTER_GLOBAL("make.IntImm") return IntImm(dtype, value); }); + +FloatImm::FloatImm(DataType dtype, double value) { + CHECK_EQ(dtype.lanes(), 1) + << "ValueError: FloatImm can only take scalar."; + ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("make.FloatImm") +.set_body_typed([](DataType dtype, double value) { + return FloatImm(dtype, value); +}); + + GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); diff --git a/src/ir/op.cc b/src/ir/op.cc index 0ed2f3d..f1be383 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -36,6 +36,10 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); namespace tvm { +using runtime::TVMRetValue; +using runtime::TVMArgs; +using runtime::PackedFunc; + ::dmlc::Registry* OpRegistry::Registry() { return ::dmlc::Registry::Get(); } diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 55dfb89..62cbc37 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -33,7 +33,7 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} PrimExpr::PrimExpr(float value) - : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {} + : PrimExpr(FloatImm(DataType::Float(32), value)) {} PrimExpr::PrimExpr(std::string str) : PrimExpr(ir::StringImmNode::make(str)) {} diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index bd43d89..3055767 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -108,11 +108,11 @@ PrimExpr max_value(const DataType& dtype) { } } else if (dtype.is_float()) { if (dtype.bits() == 64) { - return FloatImmNode::make(dtype, std::numeric_limits::max()); + return FloatImm(dtype, std::numeric_limits::max()); } else if (dtype.bits() == 32) { - return FloatImmNode::make(dtype, std::numeric_limits::max()); + return FloatImm(dtype, std::numeric_limits::max()); } else if (dtype.bits() == 16) { - return FloatImmNode::make(dtype, 65504.0); + return FloatImm(dtype, 65504.0); } } LOG(FATAL) << "Cannot decide max_value for type" << dtype; @@ -134,11 +134,11 @@ PrimExpr min_value(const DataType& dtype) { return IntImm(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { - return FloatImmNode::make(dtype, std::numeric_limits::lowest()); + return FloatImm(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() == 32) { - return FloatImmNode::make(dtype, std::numeric_limits::lowest()); + return FloatImm(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() == 16) { - return FloatImmNode::make(dtype, -65504.0); + return FloatImm(dtype, -65504.0); } } LOG(FATAL) << "Cannot decide min_value for type" << dtype; @@ -219,7 +219,7 @@ PrimExpr operator-(PrimExpr a) { const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); if (pa) return IntImm(a.dtype(), -pa->value); - if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value); + if (fa) return FloatImm(a.dtype(), -fa->value); return make_zero(a.dtype()) - a; } @@ -492,7 +492,7 @@ PrimExpr abs(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { - return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value)); + return FloatImm(x.dtype(), std::fabs(fx->value)); } return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { @@ -593,28 +593,28 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) { PrimExpr floor(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); - if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value)); + if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic); } PrimExpr ceil(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); - if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value)); + if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic); } PrimExpr round(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); - if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); + if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic); } PrimExpr nearbyint(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); - if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); + if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic); } @@ -622,7 +622,7 @@ PrimExpr trunc(PrimExpr x) { using ir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { - return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : + return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index f06a6be..b7f3c27 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -32,7 +32,7 @@ namespace ir { // constructors -PrimExpr FloatImmNode::make(DataType t, double value) { +PrimExpr FloatImm(DataType t, double value) { CHECK_EQ(t.lanes(), 1) << "ValueError: FloatImm can only take scalar"; ObjectPtr node = make_object(); diff --git a/src/node/reflection.cc b/src/node/reflection.cc index f535837..df162cf 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -25,10 +25,14 @@ #include #include #include -#include +#include namespace tvm { +using runtime::TVMRetValue; +using runtime::TVMArgs; +using runtime::PackedFunc; + // Attr getter. class AttrGetter : public AttrVisitor { public: diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 5e8a0f7..d45112d 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h index 80400ad..aea9f1e 100644 --- a/src/pass/storage_access.h +++ b/src/pass/storage_access.h @@ -25,7 +25,7 @@ #define TVM_PASS_STORAGE_ACCESS_H_ #include -#include +#include #include #include #include diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 4398e44..ae4b83f 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -29,7 +29,7 @@ #include #include #include "type_functor.h" -#include "../../lang/attr_functor.h" +#include "../../ir/attr_functor.h" namespace tvm { namespace relay { diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 3bc72fd..0ee9ac5 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -26,9 +26,9 @@ #include #include #include -#include +#include #include "type_functor.h" -#include "../../lang/attr_functor.h" +#include "../../ir/attr_functor.h" namespace tvm { namespace relay { diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 400a6be..e2fc57e 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -38,7 +38,7 @@ #include "doc.h" #include "type_functor.h" #include "../pass/dependency_graph.h" -#include "../../lang/attr_functor.h" +#include "../../ir/attr_functor.h" namespace tvm { namespace relay { diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index 9a24257..730d204 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -19,7 +19,9 @@ #include #include -#include +#include +#include +#include #include namespace tvm { -- 2.7.4