[REFACTOR][IR] attrs.h -> ir (#4709)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 15 Jan 2020 17:07:22 +0000 (09:07 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Wed, 15 Jan 2020 17:07:22 +0000 (09:07 -0800)
This PR moves attrs.h into the ir folder as it
can serve as a common infra for building ir dats structures.

We also moved common container(FloatImm) into ir/expr.h

41 files changed:
include/tvm/expr_operator.h
include/tvm/ir.h
include/tvm/ir/attrs.h [moved from include/tvm/attrs.h with 96% similarity]
include/tvm/ir/expr.h
include/tvm/ir/op.h
include/tvm/ir/type_relation.h
include/tvm/relay/adt.h
include/tvm/relay/attrs/algorithm.h
include/tvm/relay/attrs/annotation.h
include/tvm/relay/attrs/bitserial.h
include/tvm/relay/attrs/debug.h
include/tvm/relay/attrs/device_copy.h
include/tvm/relay/attrs/image.h
include/tvm/relay/attrs/memory.h
include/tvm/relay/attrs/nn.h
include/tvm/relay/attrs/reduce.h
include/tvm/relay/attrs/transform.h
include/tvm/relay/attrs/vision.h
include/tvm/relay/expr.h
include/tvm/relay/qnn/attrs.h
include/tvm/relay/type.h
src/api/api_ir.cc
src/api/api_pass.cc
src/api/api_test.cc
src/arithmetic/const_fold.h
src/autotvm/touch_extractor.cc
src/codegen/llvm/codegen_x86_64.cc
src/ir/attr_functor.h [moved from src/lang/attr_functor.h with 99% similarity]
src/ir/attrs.cc [moved from src/lang/attrs.cc with 99% similarity]
src/ir/expr.cc
src/ir/op.cc
src/lang/expr.cc
src/lang/expr_operator.cc
src/lang/ir.cc
src/node/reflection.cc
src/node/serialization.cc
src/pass/storage_access.h
src/relay/ir/alpha_equal.cc
src/relay/ir/hash.cc
src/relay/ir/pretty_printer.cc
tests/cpp/attrs_test.cc

index ff3b340..7d6b752 100644 (file)
@@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
       return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
     }
   }
-  if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
+  if (t.is_float()) return FloatImm(t, static_cast<double>(value));
   // For now, we store const scalar values of custom datatypes within doubles; later, during the
   // datatypes lowering pass, we will lower the value to its true representation in the format
   // specified by the datatype.
   // TODO(gus) when do we need to start worrying about doubles not being precise enough?
   if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
-    return ir::FloatImmNode::make(t, static_cast<double>(value));
+    return FloatImm(t, static_cast<double>(value));
   }
   LOG(FATAL) << "cannot make const for type " << t;
   return PrimExpr();
index 9c14a31..553db4e 100644 (file)
@@ -37,25 +37,9 @@ namespace tvm {
 namespace ir {
 
 using IntImmNode = tvm::IntImmNode;
+using FloatImmNode = tvm::FloatImmNode;
 using VarNode = tvm::VarNode;
 
-/*! \brief Floating point constants. */
-class FloatImmNode : public PrimExprNode {
- public:
-  /*! \brief The constant value content. */
-  double value;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("dtype", &dtype);
-    v->Visit("value", &value);
-  }
-
-  TVM_DLL static PrimExpr make(DataType t, double value);
-
-  static constexpr const char* _type_key = "FloatImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
-};
-
 /*! \brief String constants, only used in asserts. */
 class StringImmNode : public PrimExprNode {
  public:
similarity index 96%
rename from include/tvm/attrs.h
rename to include/tvm/ir/attrs.h
index 9d9f98e..5916a78 100644 (file)
  * specific language governing permissions and limitations
  * under the License.
  */
-
 /*!
- * \file tvm/attrs.h
- * \brief TVM attribute module
+ * \file tvm/ir/attrs.h
+ * \brief Helpers for attribute objects.
  *
  *  This module enables declaration of named attributes
  *  which support default value setup and bound checking.
  *
  * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
  */
-#ifndef TVM_ATTRS_H_
-#define TVM_ATTRS_H_
+#ifndef TVM_IR_ATTRS_H_
+#define TVM_IR_ATTRS_H_
 
 #include <dmlc/common.h>
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/packed_func.h>
+
 #include <unordered_map>
 #include <vector>
 #include <functional>
 #include <type_traits>
 #include <string>
 #include <utility>
-#include "ir.h"
-#include "base.h"
-#include "expr.h"
-#include "packed_func_ext.h"
 
 namespace tvm {
 /*!
@@ -63,10 +61,10 @@ namespace tvm {
  * \param ClassName The name of the class.
  * \param TypeKey The type key to be used by the TVM node system.
  */
-#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                   \
-  static constexpr const char* _type_key = TypeKey;             \
+#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                      \
+  static constexpr const char* _type_key = TypeKey;                \
   TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode)   \
-  template<typename FVisit>                                     \
+  template<typename FVisit>                                        \
   void __VisitAttrs__(FVisit& __fvisit__)  // NOLINT(*)
 
 
@@ -481,45 +479,36 @@ template<typename T>
 inline void SetValue(T* ptr, const TVMArgValue& val) {
   *ptr = val.operator T();
 }
+
 template<typename T>
 inline void SetIntValue(T* ptr, const TVMArgValue& val) {
   if (val.type_code() == kDLInt) {
     *ptr = static_cast<T>(val.value().v_int64);
   } else {
-    PrimExpr expr = val;
-    CHECK(expr.defined());
-    if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
-      *ptr = static_cast<T>(op->value);
-    } else {
-      LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
-    }
+    IntImm expr = val;
+    *ptr = static_cast<T>(expr->value);
   }
 }
+
 template<>
 inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
   if (val.type_code() == kStr) {
     *ptr = val.operator std::string();
   } else {
-    PrimExpr expr = val;
-    const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
-    CHECK(op != nullptr);
-    *ptr = op->value;
+    LOG(FATAL) << "Expect str";
   }
 }
-template<>
-inline void SetValue(DataType* ptr, const TVMArgValue& val) {
-  *ptr = val.operator DataType();
-}
+
 template<>
 inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
   if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
     *ptr = val.operator double();
   } else {
-    PrimExpr expr = val;
+    ObjectRef expr = val;
     CHECK(expr.defined());
-    if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
+    if (const IntImmNode* op = expr.as<IntImmNode>()) {
       *ptr = static_cast<double>(op->value);
-    } else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
+    } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
       *ptr = static_cast<double>(op->value);
     } else {
       LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
@@ -611,7 +600,7 @@ struct TypeName<uint64_t> {
 
 template<>
 struct TypeName<DataType> {
-  static constexpr const char* value = "Type";
+  static constexpr const char* value = "DataType";
 };
 
 template<>
@@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
 }
 
 }  // namespace tvm
-#endif  // TVM_ATTRS_H_
+#endif  // TVM_IR_ATTRS_H_
index 12e505e..ddb5f80 100644 (file)
@@ -182,6 +182,56 @@ class IntImm : public PrimExpr {
 };
 
 /*!
+ * \brief Constant floating point literals in the program.
+ * \sa FloatImm
+ */
+class FloatImmNode : public PrimExprNode {
+ public:
+  /*! \brief The constant value content. */
+  double value;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &dtype);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "FloatImm";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference class to FloatImmNode.
+ *
+ * \sa FloatImmNode
+ */
+class FloatImm : public PrimExpr {
+ public:
+  /*!
+   * \brief Constructor
+   */
+  FloatImm() {}
+  /*!
+   * \brief constructor from node.
+   */
+  explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
+  /*!
+   * \brief Constructor.
+   * \param dtype The data type of the value.
+   * \param value The internal value.
+   */
+  TVM_DLL FloatImm(DataType dtype, double value);
+  /*!
+   * \brief Get pointer to the container.
+   * \return The pointer.
+   */
+  const FloatImmNode* operator->() const {
+    return static_cast<const FloatImmNode*>(get());
+  }
+  /*! \brief type indicate the container type */
+  using ContainerType = FloatImmNode;
+};
+
+/*!
  * \brief Base node of all non-primitive expressions.
  *
  * RelayExpr supports tensor types, functions and ADT as
index f5d0639..f3a8603 100644 (file)
@@ -26,7 +26,7 @@
 #define TVM_IR_OP_H_
 
 #include <dmlc/registry.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/type.h>
@@ -296,7 +296,8 @@ class OpRegistry {
   // return internal pointer to op.
   inline OpNode* get();
   // update the attribute OpMap
-  TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
+  TVM_DLL void UpdateAttr(const std::string& key,
+                          runtime::TVMRetValue value,
                           int plevel);
 };
 
@@ -316,7 +317,7 @@ class GenericOpMap {
    * \param op The key to the map
    * \return the const reference to the content value.
    */
-  inline const TVMRetValue& operator[](const Op& op) const;
+  inline const runtime::TVMRetValue& operator[](const Op& op) const;
   /*!
    * \brief get the corresponding value element at op with default value.
    * \param op The key to the map
@@ -342,7 +343,7 @@ class GenericOpMap {
   // the attribute field.
   std::string attr_name_;
   // internal data
-  std::vector<std::pair<TVMRetValue, int> > data_;
+  std::vector<std::pair<runtime::TVMRetValue, int> > data_;
   // The value
   GenericOpMap() = default;
 };
@@ -532,7 +533,7 @@ template <typename ValueType>
 inline OpRegistry& OpRegistry::set_attr(  // NOLINT(*)
     const std::string& attr_name, const ValueType& value, int plevel) {
   CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
-  TVMRetValue rv;
+  runtime::TVMRetValue rv;
   rv = value;
   UpdateAttr(attr_name, rv, plevel);
   return *this;
@@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const {
   }
 }
 
-inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
+inline const runtime::TVMRetValue&
+GenericOpMap::operator[](const Op& op) const {
   CHECK(op.defined());
   const uint32_t idx = op->index_;
   CHECK(idx < data_.size() && data_[idx].second != 0)
index 333c538..962eea3 100644 (file)
@@ -27,7 +27,7 @@
 #include <tvm/ir/type.h>
 #include <tvm/ir/module.h>
 #include <tvm/ir/env_func.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 
 namespace tvm {
 
index 1807696..6f72072 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ADT_H_
 #define TVM_RELAY_ADT_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir/adt.h>
 #include <string>
 #include <functional>
index ce14a6a..2d1b902 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
 #define TVM_RELAY_ATTRS_ALGORITHM_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
 #include <string>
 
index 4481d2a..cc21e34 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
 #define TVM_RELAY_ATTRS_ANNOTATION_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <string>
 
 namespace tvm {
index 2a7376b..962afc2 100644 (file)
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
 #define TVM_RELAY_ATTRS_BITSERIAL_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
 #include <string>
 
index 82a2046..ed9ed4e 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_DEBUG_H_
 #define TVM_RELAY_ATTRS_DEBUG_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <string>
 
 namespace tvm {
index 2469c4b..3935629 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
 #define TVM_RELAY_ATTRS_DEVICE_COPY_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <string>
 
 namespace tvm {
index 22d657d..4bf40e3 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_IMAGE_H_
 #define TVM_RELAY_ATTRS_IMAGE_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
 #include <string>
 
index c74b648..00204b3 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_MEMORY_H_
 #define TVM_RELAY_ATTRS_MEMORY_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/expr.h>
 #include <string>
 
index 549eb67..5620feb 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_NN_H_
 #define TVM_RELAY_ATTRS_NN_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
 #include <string>
 
index d5fe9b8..443efb5 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_REDUCE_H_
 #define TVM_RELAY_ATTRS_REDUCE_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <string>
 
 namespace tvm {
index 26637d5..11c7886 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
 #define TVM_RELAY_ATTRS_TRANSFORM_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
 #include <tvm/relay/expr.h>
 #include <string>
index b98bbfc..c4a30ce 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_ATTRS_VISION_H_
 #define TVM_RELAY_ATTRS_VISION_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/relay/base.h>
 #include <string>
 
index 47ae552..1062c20 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_EXPR_H_
 #define TVM_RELAY_EXPR_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir/expr.h>
 #include <string>
 #include <functional>
index 7ef0b10..3c1c4a3 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_QNN_ATTRS_H_
 #define TVM_RELAY_QNN_ATTRS_H_
 
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <string>
 
 namespace tvm {
index 17b8b57..d4243d8 100644 (file)
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/type.h>
 #include <tvm/ir/type_relation.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/ir/env_func.h>
@@ -34,7 +35,7 @@
 #include <string>
 
 #include "base.h"
-#include "../attrs.h"
+
 
 namespace tvm {
 namespace relay {
index 30ca515..261b94e 100644 (file)
@@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
 REGISTER_MAKE(Reduce);
 REGISTER_MAKE(AttrStmt);
 
-REGISTER_MAKE(FloatImm);
 REGISTER_MAKE(StringImm);
 
 REGISTER_MAKE(Add);
index ff30f5e..639855c 100644 (file)
@@ -23,7 +23,7 @@
  */
 #include <tvm/expr.h>
 #include <tvm/ir.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_functor_ext.h>
 #include <tvm/runtime/registry.h>
index 7ded78b..0bc83ea 100644 (file)
@@ -23,7 +23,7 @@
  */
 #include <tvm/expr.h>
 #include <tvm/tensor.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/env_func.h>
 #include <tvm/packed_func_ext.h>
index 3b803ec..d82ac89 100644 (file)
@@ -102,7 +102,7 @@ inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
       if (pa && pb) return IntImm(rtype, pa->value + pb->value);
       if (pa && pa->value == 0) return b;
       if (pb && pb->value == 0) return a;
-      if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value);
+      if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
       if (fa && fa->value == 0) return b;
       if (fb && fb->value == 0) return a;
     });
@@ -115,7 +115,7 @@ inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImm(rtype, pa->value - pb->value);
       if (pb && pb->value == 0) return a;
-      if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
+      if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
       if (fb && fb->value == 0) return a;
     });
   return PrimExpr();
@@ -134,7 +134,7 @@ inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
         if (pb->value == 1) return a;
         if (pb->value == 0) return b;
       }
-      if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value);
+      if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
       if (fa) {
         if (fa->value == 1) return b;
         if (fa->value == 0) return a;
@@ -165,7 +165,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
       }
       if (fa && fb && fb->value != 0) {
-        return FloatImmNode::make(rtype, fa->value / fb->value);
+        return FloatImm(rtype, fa->value / fb->value);
       }
       if (fa && fa->value == 0) return a;
       if (fb) {
@@ -210,7 +210,7 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
       }
       if (fa && fb && fb->value != 0) {
-        return FloatImmNode::make(rtype, std::floor(fa->value / fb->value));
+        return FloatImm(rtype, std::floor(fa->value / fb->value));
       }
       if (fa && fa->value == 0) return a;
       if (fb) {
@@ -244,7 +244,7 @@ inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
-      if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
+      if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
   return PrimExpr();
@@ -255,7 +255,7 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
-      if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
+      if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
   return PrimExpr();
index 55ed36c..b5bf2ed 100644 (file)
@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
     feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
 
     Array<PrimExpr> attr{std::string("_attr_"),
-                     FloatImmNode::make(DataType::Float(32), trans(fea.length)),
+                     FloatImm(DataType::Float(32), trans(fea.length)),
                      IntImm(DataType::Int(32), fea.nest_level),
-                     FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
-                     FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
+                     FloatImm(DataType::Float(32), trans(fea.topdown_product)),
+                     FloatImm(DataType::Float(32), trans(fea.bottomup_product)),
     };
     // one hot annotation
     for (int i = 0; i < kNum; i++) {
@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
 
     // arithmetic
     feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
-            FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
-            FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
-            FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
+            FloatImm(DataType::Float(32), trans(fea.add_ct)),
+            FloatImm(DataType::Float(32), trans(fea.mul_ct)),
+            FloatImm(DataType::Float(32), trans(fea.div_ct)),
     });
 
     // touch map
@@ -283,12 +283,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
       TouchPattern &v = fea.touch_feature[k];
       feature_row.push_back(
           Array<PrimExpr>{k,
-                FloatImmNode::make(DataType::Float(32), trans(v.stride)),
-                FloatImmNode::make(DataType::Float(32), trans(v.mod)),
-                FloatImmNode::make(DataType::Float(32), trans(v.count)),
-                FloatImmNode::make(DataType::Float(32), trans(v.reuse)),
-                FloatImmNode::make(DataType::Float(32), trans(v.thread_count)),
-                FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)),
+                FloatImm(DataType::Float(32), trans(v.stride)),
+                FloatImm(DataType::Float(32), trans(v.mod)),
+                FloatImm(DataType::Float(32), trans(v.count)),
+                FloatImm(DataType::Float(32), trans(v.reuse)),
+                FloatImm(DataType::Float(32), trans(v.thread_count)),
+                FloatImm(DataType::Float(32), trans(v.thread_reuse)),
                 });
     }
 
index 11bda70..2e41931 100644 (file)
@@ -95,7 +95,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
                 ir::CallNode::PureIntrinsic)),
                 MakeValue(
                     ir::BroadcastNode::make(
-                      ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
+                      FloatImm(DataType::Float(32), 0), from.lanes())),
                 /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
                 /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
           });
similarity index 99%
rename from src/lang/attr_functor.h
rename to src/ir/attr_functor.h
index 4fffc47..c140123 100644 (file)
@@ -27,8 +27,8 @@
  *  - array of attributes
  *  - map of attributes
  */
-#ifndef TVM_LANG_ATTR_FUNCTOR_H_
-#define TVM_LANG_ATTR_FUNCTOR_H_
+#ifndef TVM_IR_ATTR_FUNCTOR_H_
+#define TVM_IR_ATTR_FUNCTOR_H_
 
 #include <tvm/node/functor.h>
 #include <utility>
@@ -230,4 +230,4 @@ class AttrsHashHandler :
   }
 };
 }  // namespace tvm
-#endif  // TVM_LANG_ATTR_FUNCTOR_H_
+#endif  // TVM_IR_ATTR_FUNCTOR_H_
similarity index 99%
rename from src/lang/attrs.cc
rename to src/ir/attrs.cc
index a590f10..a487995 100644 (file)
@@ -20,7 +20,7 @@
 /*!
  * \file attrs.cc
  */
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/packed_func_ext.h>
 
index 6d89967..0cf91c2 100644 (file)
@@ -45,6 +45,22 @@ TVM_REGISTER_GLOBAL("make.IntImm")
   return IntImm(dtype, value);
 });
 
+
+FloatImm::FloatImm(DataType dtype, double value) {
+  CHECK_EQ(dtype.lanes(), 1)
+      << "ValueError: FloatImm can only take scalar.";
+  ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
+  node->dtype = dtype;
+  node->value = value;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("make.FloatImm")
+.set_body_typed([](DataType dtype, double value) {
+  return FloatImm(dtype, value);
+});
+
+
 GlobalVar::GlobalVar(std::string name_hint) {
   ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
   n->name_hint = std::move(name_hint);
index 0ed2f3d..f1be383 100644 (file)
@@ -36,6 +36,10 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
 
 namespace tvm {
 
+using runtime::TVMRetValue;
+using runtime::TVMArgs;
+using runtime::PackedFunc;
+
 ::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
   return ::dmlc::Registry<OpRegistry>::Get();
 }
index 55dfb89..62cbc37 100644 (file)
@@ -33,7 +33,7 @@ PrimExpr::PrimExpr(int32_t value)
     : PrimExpr(IntImm(DataType::Int(32), value)) {}
 
 PrimExpr::PrimExpr(float value)
-    : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
+    : PrimExpr(FloatImm(DataType::Float(32), value)) {}
 
 PrimExpr::PrimExpr(std::string str)
     : PrimExpr(ir::StringImmNode::make(str)) {}
index bd43d89..3055767 100644 (file)
@@ -108,11 +108,11 @@ PrimExpr max_value(const DataType& dtype) {
     }
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
-      return FloatImmNode::make(dtype, std::numeric_limits<double>::max());
+      return FloatImm(dtype, std::numeric_limits<double>::max());
     } else if (dtype.bits() == 32) {
-      return FloatImmNode::make(dtype, std::numeric_limits<float>::max());
+      return FloatImm(dtype, std::numeric_limits<float>::max());
     } else if (dtype.bits() == 16) {
-      return FloatImmNode::make(dtype, 65504.0);
+      return FloatImm(dtype, 65504.0);
     }
   }
   LOG(FATAL) << "Cannot decide max_value for type" << dtype;
@@ -134,11 +134,11 @@ PrimExpr min_value(const DataType& dtype) {
     return IntImm(dtype, 0);
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
-      return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest());
+      return FloatImm(dtype, std::numeric_limits<double>::lowest());
     } else if (dtype.bits() == 32) {
-      return FloatImmNode::make(dtype, std::numeric_limits<float>::lowest());
+      return FloatImm(dtype, std::numeric_limits<float>::lowest());
     } else if (dtype.bits() == 16) {
-      return FloatImmNode::make(dtype, -65504.0);
+      return FloatImm(dtype, -65504.0);
     }
   }
   LOG(FATAL) << "Cannot decide min_value for type" << dtype;
@@ -219,7 +219,7 @@ PrimExpr operator-(PrimExpr a) {
   const IntImmNode* pa = a.as<IntImmNode>();
   const FloatImmNode* fa = a.as<FloatImmNode>();
   if (pa) return IntImm(a.dtype(), -pa->value);
-  if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
+  if (fa) return FloatImm(a.dtype(), -fa->value);
   return make_zero(a.dtype()) - a;
 }
 
@@ -492,7 +492,7 @@ PrimExpr abs(PrimExpr x) {
     using ir::FloatImmNode;
     const FloatImmNode* fx = x.as<FloatImmNode>();
     if (fx) {
-      return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value));
+      return FloatImm(x.dtype(), std::fabs(fx->value));
     }
     return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic);
   } else if (x.dtype().is_uint()) {
@@ -593,28 +593,28 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
 PrimExpr floor(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
-  if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value));
+  if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
   return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
 }
 
 PrimExpr ceil(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
-  if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value));
+  if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
   return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
 }
 
 PrimExpr round(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
-  if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
+  if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
   return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
 }
 
 PrimExpr nearbyint(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
-  if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
+  if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
   return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
 }
 
@@ -622,7 +622,7 @@ PrimExpr trunc(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) {
-    return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
+    return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
                                      std::floor(fx->value)));
   }
   return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic);
index f06a6be..b7f3c27 100644 (file)
@@ -32,7 +32,7 @@ namespace ir {
 
 // constructors
 
-PrimExpr FloatImmNode::make(DataType t, double value) {
+PrimExpr FloatImm(DataType t, double value) {
   CHECK_EQ(t.lanes(), 1)
       << "ValueError: FloatImm can only take scalar";
   ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
index f535837..df162cf 100644 (file)
 #include <tvm/node/node.h>
 #include <tvm/node/container.h>
 #include <tvm/node/reflection.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 
 namespace tvm {
 
+using runtime::TVMRetValue;
+using runtime::TVMArgs;
+using runtime::PackedFunc;
+
 // Attr getter.
 class AttrGetter : public AttrVisitor {
  public:
index 5e8a0f7..d45112d 100644 (file)
@@ -29,7 +29,7 @@
 #include <tvm/node/container.h>
 #include <tvm/node/reflection.h>
 #include <tvm/node/serialization.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 
 #include <string>
 #include <map>
index 80400ad..aea9f1e 100644 (file)
@@ -25,7 +25,7 @@
 #define TVM_PASS_STORAGE_ACCESS_H_
 
 #include <tvm/ir.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_functor_ext.h>
 #include <vector>
index 4398e44..ae4b83f 100644 (file)
@@ -29,7 +29,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/nn.h>
 #include "type_functor.h"
-#include "../../lang/attr_functor.h"
+#include "../../ir/attr_functor.h"
 namespace tvm {
 namespace relay {
 
index 3bc72fd..0ee9ac5 100644 (file)
@@ -26,9 +26,9 @@
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/relay/analysis.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
 #include "type_functor.h"
-#include "../../lang/attr_functor.h"
+#include "../../ir/attr_functor.h"
 
 namespace tvm {
 namespace relay {
index 400a6be..e2fc57e 100644 (file)
@@ -38,7 +38,7 @@
 #include "doc.h"
 #include "type_functor.h"
 #include "../pass/dependency_graph.h"
-#include "../../lang/attr_functor.h"
+#include "../../ir/attr_functor.h"
 
 namespace tvm {
 namespace relay {
index 9a24257..730d204 100644 (file)
@@ -19,7 +19,9 @@
 
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
-#include <tvm/attrs.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/expr_operator.h>
+#include <tvm/packed_func_ext.h>
 #include <tvm/ir.h>
 
 namespace tvm {