[INFA][IR] Build and Evolve Low-level IR. Remove HalideIR dep. (#3533)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 11 Jul 2019 21:26:43 +0000 (14:26 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Jul 2019 21:26:43 +0000 (14:26 -0700)
* [INFA][IR] Build and Evolve Low-level IR. Remove dep from HalideIR.

* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <roeschinc@gmail.com>
* Update include/tvm/node/ir_functor.h

Co-Authored-By: Jared Roesch <roeschinc@gmail.com>
50 files changed:
CMakeLists.txt
apps/howto_deploy/tvm_runtime_pack.cc
include/tvm/arithmetic.h
include/tvm/attrs.h
include/tvm/data_layout.h
include/tvm/dtype.h [new file with mode: 0644]
include/tvm/expr.h
include/tvm/ir.h
include/tvm/ir_mutator.h
include/tvm/lowered_func.h
include/tvm/node/container.h [new file with mode: 0644]
include/tvm/node/ir_functor.h [new file with mode: 0644]
include/tvm/node/memory.h [new file with mode: 0644]
include/tvm/node/node.h [new file with mode: 0644]
include/tvm/operation.h
include/tvm/packed_func_ext.h
include/tvm/runtime/packed_func.h
include/tvm/schedule.h
include/tvm/schedule_pass.h
include/tvm/tensor.h
src/api/api_ir.cc
src/api/api_lang.cc
src/arithmetic/canonical_simplify.cc
src/arithmetic/const_int_bound.cc
src/codegen/codegen_opengl.cc
src/lang/expr.cc
src/lang/ir.cc
src/lang/tensor.cc
src/node/node.cc [new file with mode: 0644]
src/op/compute_op.cc
src/op/extern_op.cc
src/op/hybrid_op.cc
src/op/op_util.cc
src/op/scan_op.cc
src/pass/inject_prefetch.cc
src/pass/ir_deep_compare.cc
src/pass/ir_mutator.cc
src/pass/storage_access.h
src/pass/storage_flatten.cc
src/relay/backend/build_module.cc
src/relay/backend/graph_runtime_codegen.cc
src/relay/ir/expr.cc
src/relay/op/nn/upsampling.cc
src/schedule/graph.cc
src/schedule/schedule_lang.cc
tests/cpp/ir_mutator_test.cc
tests/cpp/ir_ssa_test.cc
tests/cpp/ir_visitor_test.cc
topi/include/topi/image/resize.h
topi/include/topi/transform.h

index c23d403bcb6a1cdf4cbe44023499412decd6c65c..534a9f80b1ac1ab30446cfab2de6c69aa045de1a 100644 (file)
@@ -76,7 +76,6 @@ if(MSVC)
   add_definitions(-D_CRT_SECURE_NO_WARNINGS)
   add_definitions(-D_SCL_SECURE_NO_WARNINGS)
   add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
-  add_definitions(-DHalide_SHARED)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
@@ -112,8 +111,8 @@ else(MSVC)
 endif(MSVC)
 
 # add source group
-FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc")
-FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h"
+FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc")
+FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h"
                                 "nnvm/src/*.h" "nnvm/include/*.h")
 assign_source_group("Source" ${GROUP_SOURCE})
 assign_source_group("Include" ${GROUP_INCLUDE})
@@ -127,6 +126,7 @@ file(GLOB COMPILER_SRCS
     src/lang/*.cc
     src/pass/*.cc
     src/op/*.cc
+    src/node/*.cc
     src/schedule/*.cc
     )
 
@@ -154,12 +154,7 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
 file(GLOB TOPI_SRCS
     topi/src/*.cc
 )
-file(GLOB_RECURSE HALIDEIR_SRCS
-  3rdparty/HalideIR/src/base/*.cpp
-  3rdparty/HalideIR/src/ir/*.cpp
-  3rdparty/HalideIR/src/tvm/*.cpp
-)
-list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
+
 file(GLOB RUNTIME_SRCS
   src/runtime/*.cc
   src/runtime/vm/*.cc
@@ -245,7 +240,6 @@ target_link_libraries(nnvm_compiler tvm)
 # Related headers
 target_include_directories(
   tvm
-  PUBLIC "3rdparty/HalideIR/src"
   PUBLIC "topi/include")
 target_include_directories(
   tvm_topi
@@ -294,11 +288,6 @@ if (INSTALL_DEV)
     FILES_MATCHING
     PATTERN "*.h"
   )
-  install(
-    DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR"
-    FILES_MATCHING
-    PATTERN "*.h"
-  )
   install(
     DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include"
     FILES_MATCHING
@@ -319,8 +308,6 @@ endif(INSTALL_DEV)
 
 # More target definitions
 if(MSVC)
-  target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
-  target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
   target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
   target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
   target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
index 736add6e9fa36758352136e46d9d3c62c7efcb82..6ebad8177cd55172f659e6719b902ae93763adc4 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index 446c4c0c19a911c4e26f2d3fea8ca08105b0f71e..105cbf7af8e9867deb005e8ee0bb7758cf19e279 100644 (file)
@@ -591,7 +591,7 @@ IntSet EvalSet(Range r,
                const std::unordered_map<const Variable*, IntSet>& dom_map);
 
 /*! \brief Map from Expr to IntSet */
-using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
+using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
 /*!
  * \brief Find the integer set of every sub-expression, given the
  *  domain of each iteration variables.
index ed021beb5d35ad90d46b7cce1641d204e31acc07..10fbe9f2ce4de2c9b300351bb099dc346327be44 100644 (file)
@@ -89,8 +89,8 @@ inline TNodeRef NullValue() {
 }
 
 template<>
-inline Type NullValue<Type>() {
-  return Type(Type::Handle, 0, 0);
+inline DataType NullValue<DataType>() {
+  return DataType(kHandle, 0, 0);
 }
 
 /*! \brief Error thrown during attribute checking. */
index ff5f8e37dbb6507412c24a41ac2adc130357f281..a703d928ba5f886688186ac5ccac728125612d0f 100644 (file)
@@ -221,7 +221,7 @@ class Layout : public NodeRef {
     if (!this->defined()) return -1;
     const auto axes = operator->()->axes;
     for (size_t i = 0; i < axes.size(); ++i) {
-      if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
+      if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
     }
     return -1;
   }
@@ -243,7 +243,7 @@ class Layout : public NodeRef {
   bool Contains(const LayoutAxis& axis) const {
     if (!defined()) return false;
     for (const IterVar var : operator->()->axes) {
-      if (var->var.get()->name_hint == axis.name()) {
+      if (var->var->name_hint == axis.name()) {
         return true;
       }
     }
diff --git a/include/tvm/dtype.h b/include/tvm/dtype.h
new file mode 100644 (file)
index 0000000..60a96a3
--- /dev/null
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file tvm/dtype.h
+ * \brief Data type used in IR.
+ */
+#ifndef TVM_DTYPE_H_
+#define TVM_DTYPE_H_
+
+#include "runtime/packed_func.h"
+
+namespace tvm {
+class Expr;
+
+/*!
+ * \brief Primitive data types in tvm.
+ */
+class DataType {
+ public:
+  /*! \brief default constructor */
+  DataType() {}
+  /*!
+   * \brief Constructor
+   * \param dtype The DLDataType
+   */
+  explicit DataType(DLDataType dtype)
+      : data_(dtype) {}
+  /*!
+   * \brief Constructor
+   * \param code The type code.
+   * \param bits The number of bits in the type.
+   * \param lanes The number of lanes.
+   */
+  DataType(int code, int bits, int lanes) {
+    data_.code = static_cast<uint8_t>(code);
+    data_.bits = static_cast<uint8_t>(bits);
+    data_.lanes = static_cast<uint16_t>(lanes);
+  }
+  /*! \return The type code. */
+  int code() const {
+    return static_cast<int>(data_.code);
+  }
+  /*! \return number of bits in the data. */
+  int bits() const {
+    return static_cast<int>(data_.bits);
+  }
+  /*! \return number of bytes to store each scalar. */
+  int bytes() const {
+    return (bits() + 7) / 8;
+  }
+  /*! \return number of lanes in the data. */
+  int lanes() const {
+    return static_cast<int>(data_.lanes);
+  }
+  /*! \return whether type is a scalar type. */
+  bool is_scalar() const {
+    return lanes() == 1;
+  }
+  /*! \return whether type is a scalar type. */
+  bool is_bool() const {
+    return code() == kDLUInt && bits() == 1;
+  }
+  /*! \return whether type is a float type. */
+  bool is_float() const {
+    return code() == kDLFloat;
+  }
+  /*! \return whether type is an int type. */
+  bool is_int() const {
+    return code() == kDLInt;
+  }
+  /*! \return whether type is an uint type. */
+  bool is_uint() const {
+    return code() == kDLUInt;
+  }
+  /*! \return whether type is a handle type. */
+  bool is_handle() const {
+    return code() == kHandle;
+  }
+  /*! \return whether type is a vector type. */
+  bool is_vector() const {
+    return lanes() > 1;
+  }
+  /*!
+   * \brief Create a new data type by change lanes to a specified value.
+   * \param lanes The target number of lanes.
+   * \return the result type.
+   */
+  DataType with_lanes(int lanes) const {
+    return DataType(data_.code, data_.bits, lanes);
+  }
+  /*!
+   * \brief Create a new data type by change bits to a specified value.
+   * \param bits The target number of bits.
+   * \return the result type.
+   */
+  DataType with_bits(int bits) const {
+    return DataType(data_.code, bits, data_.lanes);
+  }
+  /*!
+   * \brief Get the scalar version of the type.
+   * \return the result type.
+   */
+  DataType element_of() const {
+    return with_lanes(1);
+  }
+  // operator overloadings
+  bool operator==(const DataType& other) const {
+    return
+        data_.code == other.data_.code &&
+        data_.bits == other.data_.bits &&
+        data_.lanes == other.data_.lanes;
+  }
+  bool operator!=(const DataType& other) const {
+    return !operator==(other);
+  }
+  operator DLDataType () const {
+    return data_;
+  }
+  /*! \return the maximum possible value in this format. */
+  TVM_DLL Expr max() const;
+  /*! \return the minimum possible value in this format. */
+  TVM_DLL Expr min() const;
+
+ private:
+  DLDataType data_;
+};
+
+/*!
+ * \brief Construct an int type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ * \return The constructed data type.
+ */
+inline DataType Int(int bits, int lanes = 1) {
+  return DataType(kDLInt, bits, lanes);
+}
+
+/*!
+ * \brief Construct an uint type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType UInt(int bits, int lanes = 1) {
+  return DataType(kDLUInt, bits, lanes);
+}
+
+/*!
+ * \brief Construct a bool type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType Bool(int lanes = 1) {
+  return UInt(1, lanes);
+}
+
+/*!
+ * \brief Construct an uint type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType Float(int bits, int lanes = 1) {
+  return DataType(kDLFloat, bits, lanes);
+}
+
+/*!
+ * \brief Construct a handle type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType Handle(int bits = 64, int lanes = 1) {
+  return DataType(kHandle, bits, lanes);
+}
+
+/*!
+ * \brief Get the corresponding type of TVMShapeIndex.
+ * \return The type of TVM shape index.
+ */
+inline DataType TVMShapeIndexType() {
+  if (std::is_signed<tvm_index_t>::value) {
+    return Int(sizeof(tvm_index_t) * 8);
+  } else {
+    return UInt(sizeof(tvm_index_t) * 8);
+  }
+}
+
+/*!
+ * \brief Convert DLDataType to DataType.
+ * \param t The original type.
+ * \return The conversion result.
+ */
+inline DataType TVMType2Type(DLDataType t) {
+  return DataType(t.code, t.bits, t.lanes);
+}
+
+/*!
+ * \brief Convert DataType to DataType.
+ * \param t The original type.
+ * \return The conversion result.
+ */
+inline DLDataType Type2TVMType(DataType t) {
+  return t.operator DLDataType();
+}
+
+/*!
+ * \brief Get the number of bytes needed in a vector.
+ * \param dtype The data type.
+ * \return Number of bytes needed.
+ */
+inline int GetVectorBytes(DataType dtype) {
+  int data_bits = dtype.bits() * dtype.lanes();
+  // allow bool to exist
+  if (dtype == Bool()) return 1;
+  CHECK_EQ(data_bits % 8, 0U)
+      << "Need to load/store by multiple of bytes";
+  return data_bits / 8;
+}
+
+// Overload print function.
+inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
+  using namespace tvm::runtime;
+  return os << dtype.operator DLDataType();
+}
+
+// Backward compatibility
+using Type = DataType;
+}  // namespace tvm
+#endif  //  TVM_DTYPE_H_
index c7e69d59d6367739e0d5ada02d9b95e454c69e96..07cfbc7791da5f1a09b1996d27e17615ab906663 100644 (file)
@@ -1,4 +1,3 @@
-
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
 #ifndef TVM_EXPR_H_
 #define TVM_EXPR_H_
 
-#include <ir/Expr.h>
-#include <ir/IRPrinter.h>
 #include <string>
 #include <algorithm>
 #include <unordered_map>
 #include "base.h"
+#include "dtype.h"
+#include "node/container.h"
+#include "node/ir_functor.h"
 #include "runtime/c_runtime_api.h"
 
 namespace tvm {
 
-using HalideIR::Type;
-using HalideIR::Float;
-using HalideIR::Bool;
-using HalideIR::Int;
-using HalideIR::UInt;
-using HalideIR::Handle;
-using HalideIR::ExprHash;
-using HalideIR::ExprEqual;
-
-using HalideIR::Expr;
-using HalideIR::VarExpr;
-using HalideIR::IR::RangeNode;
-using HalideIR::IR::FunctionRef;
-using HalideIR::IR::FunctionBaseNode;
-using HalideIR::Internal::IntImm;
-using HalideIR::Internal::Stmt;
-using HalideIR::Internal::IRPrinter;
-using HalideIR::Internal::Variable;
-
-inline Type TVMShapeIndexType() {
-  if (std::is_signed<tvm_index_t>::value) {
-    return Int(sizeof(tvm_index_t) * 8);
-  } else {
-    return UInt(sizeof(tvm_index_t) * 8);
+/*! \brief Base node of all expressions. */
+class ExprNode : public Node {
+ public:
+  /*! \brief The data type of the expression. */
+  DataType type;
+
+  static constexpr const char* _type_key = "Expr";
+  TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
+};
+
+/*! \brief Container of all expressions. */
+class Expr : public NodeRef {
+ public:
+  Expr() {}
+  explicit Expr(NodePtr<Node> ptr) : NodeRef(ptr) {}
+  /*!
+   * \brief construct from integer.
+   * \param value The value to be constructed.
+   */
+  TVM_DLL Expr(int32_t value);  // NOLINT(*)
+  /*!
+   * \brief construct from float.
+   * \param value The value to be constructed.
+   */
+  TVM_DLL Expr(float value);  // NOLINT(*)
+  /*!
+   * \brief construct from string.
+   * \param str The value to be constructed.
+   */
+  TVM_DLL Expr(std::string str);  // NOLINT(*)
+
+  /*! \return the data type of this expression. */
+  DataType type() const {
+    return static_cast<const ExprNode*>(get())->type;
   }
-}
 
-inline Type TVMType2Type(TVMType t) {
-  return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
-}
+  /*! \brief type indicate the container type */
+  using ContainerType = ExprNode;
+};
 
-inline TVMType Type2TVMType(Type t) {
-  TVMType ret;
-  ret.code = static_cast<uint8_t>(t.code());
-  ret.bits = static_cast<uint8_t>(t.bits());
-  ret.lanes = static_cast<uint16_t>(t.lanes());
-  return ret;
-}
+/*! \brief Base node of all statements. */
+class StmtNode : public Node {
+ public:
+  static constexpr const char* _type_key = "Stmt";
+  TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node);
+};
 
-// Get number of bytes considering vector type.
-inline int GetVectorBytes(Type dtype) {
-  int data_bits = dtype.bits() * dtype.lanes();
-  // allow bool to exist
-  if (dtype == Bool()) return 1;
-  CHECK_EQ(data_bits % 8, 0U)
-      << "Need to load/store by multiple of bytes";
-  return data_bits / 8;
-}
+/*! \brief Container of all statements */
+class Stmt : public NodeRef {
+ public:
+  TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
+};
+
+class Var;
+/*!
+ * \brief A variable node in the IR.
+ *
+ * A vraible is uniquely identified by its address.
+ *
+ * Each variable is only binded once in the following nodes:
+ * - Allocate
+ * - For
+ * - Let
+ * - LetStmt
+ */
+class Variable : public ExprNode {
+ public:
+  /*!
+   * \brief The hint to the variable name.
+   * \note Each variable is uniquely identified by its address.
+   */
+  std::string name_hint;
+
+  static Var make(DataType dtype, std::string name_hint);
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("name", &name_hint);
+  }
+
+  static constexpr const char* _type_key = "Variable";
+  TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode);
+};
 
 /*! \brief a named variable in TVM */
-class Var : public HalideIR::VarExpr {
+class Var : public Expr {
  public:
-  EXPORT explicit Var(const std::string& name_hint = "v",
-               Type t = Int(32)) : VarExpr(name_hint, t) {}
-  explicit Var(NodePtr<Node> n) : VarExpr(n) {}
-  explicit Var(VarExpr v) : VarExpr(v) {}
+  explicit Var(NodePtr<Node> n) : Expr(n) {}
+  TVM_DLL explicit Var(std::string name_hint = "v",
+                       Type t = Int(32));
   /*!
    * \brief Make a new copy of var with same type, append suffix
    * \param suffix The suffix to be appended.
@@ -99,10 +133,47 @@ class Var : public HalideIR::VarExpr {
   Var copy_with_suffix(const std::string& suffix) const {
     return Var((*this)->name_hint + suffix, (*this)->type);
   }
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the corresponding Variable.
+   */
+  const Variable* operator->() const {
+    return get();
+  }
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the corresponding Variable.
+   */
+  const Variable* get() const {
+    return static_cast<Variable*>(node_.get());
+  }
   /*! \brief type indicate the container type */
   using ContainerType = Variable;
 };
 
+// Backward compatibility, will be removed later.
+using VarExpr = Var;
+using BaseExprNode = ExprNode;
+using ExprHash = NodeHash;
+using ExprEqual = NodeEqual;
+
+class Integer;
+/*! \brief ExprNode: constant integer. */
+class IntImm : public ExprNode {
+ public:
+  /*! \brief the Internal value. */
+  int64_t value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL static Integer make(DataType t, int64_t value);
+
+  static constexpr const char* _type_key = "IntImm";
+  TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode);
+};
 
 /*!
  * \brief Container of constant integer (IntImm).
@@ -148,34 +219,52 @@ class Integer : public Expr {
   using ContainerType = IntImm;
 };
 
+/*! \brief range over one dimension */
+class RangeNode : public Node {
+ public:
+  /*! \brief beginning of the node */
+  Expr min;
+  /*! \brief the extend of range */
+  Expr extent;
+  /*! \brief constructor */
+  RangeNode() {}
+  RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
 
-/*! \brief container class of iteration variable. */
-class IterVarNode;
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("min", &min);
+    v->Visit("extent", &extent);
+  }
 
-/*!
- * \brief same as HalideIR::IR::Range
- *  except it provide an constructor with (begin, end)
- *
- *  \note Traditional Halide's Range have a constructor with
- *   (begin, extent), which does not match the convention in e.g. python.
- *   We decided to correct it by removing the constructor in HalideIR,
- *   and add it back in TVM's range.
- */
-class Range : public HalideIR::IR::Range {
+  static constexpr const char* _type_key = "Range";
+  TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node);
+};
+
+/*! \brief Range constainer  */
+class Range : public NodeRef {
  public:
-  /*! \brief constructor */
-  Range() {}
-  explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
   /*!
    * \brief constructor by begin and end
    * \param begin The begin of the range.
    * \param end The end of the range.
    */
   TVM_DLL Range(Expr begin, Expr end);
-
-  TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
+  /*!
+   * \brief construct a new range with min and extent
+   *  The corresponding constructor is removed,
+   *  because that is counter convention of tradition meaning
+   *  of range(begin, end)
+   *
+   * \param min The minimum range.
+   * \param extent The extent of the range.
+   */
+  static Range make_by_min_extent(Expr min, Expr extent);
+  // declare range.
+  TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode);
 };
 
+/*! \brief container class of iteration variable. */
+class IterVarNode;
+
 using Region = Array<Range>;
 
 /*!
@@ -289,9 +378,6 @@ TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
 
 using Domain = Array<Range>;
 
-// print functions for expr
-TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n);  // NOLINT(*)
-
 /*!
  * \brief Dump the node to stderr, used for debug purposes.
  * \param node The input node
@@ -364,7 +450,7 @@ inline const char* IterVarType2String(IterVarType t) {
  * \param name_hint The name hint for the expression
  * \param t The type of the expression
  */
-TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
+TVM_DLL Var var(std::string name_hint, Type t = Int(32));
 
 /*
  * \brief Template function to convert Map to unordered_map
@@ -382,6 +468,32 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
   }
   return ret;
 }
+
+// Printer infra.
+/*! \brief A Pretty printer class to print the IR. */
+class IRPrinter {
+ public:
+  /*! \brief The output stream */
+  std::ostream& stream;
+  /*! \brief The indentation level. */
+  int indent{0};
+  explicit IRPrinter(std::ostream& stream)  // NOLINT(*)
+      : stream(stream) {}
+
+  /*! \brief The node to be printed. */
+  TVM_DLL void Print(const NodeRef& node);
+  /*! \brief Print indent to the stream */
+  TVM_DLL void PrintIndent();
+  // Allow registration to be printer.
+  using FType = IRFunctor<void(const NodeRef&, IRPrinter *)>;
+  TVM_DLL static FType& vtable();
+};
+
+// default print function for all nodes
+inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
+  IRPrinter(os).Print(n);
+  return os;
+}
 }  // namespace tvm
 
 namespace std {
index 7524109ec48b127dc3fa0d1c05d6b9c82549f6a3..547386154f76a1f43913456404174a3859545b4d 100644 (file)
  * specific language governing permissions and limitations
  * under the License.
  */
-
 /*!
  * \file tvm/ir.h
  * \brief Additional high level nodes in the IR
  */
+// Acknowledgement: Most low-level IR nodes originate from Halide.
+
 #ifndef TVM_IR_H_
 #define TVM_IR_H_
 
-#include <ir/Expr.h>
-#include <ir/IR.h>
 #include <type_traits>
 #include <string>
+#include <vector>
+#include <utility>
 #include "base.h"
 #include "expr.h"
 #include "runtime/util.h"
 namespace tvm {
 namespace ir {
 
-using HalideIR::Internal::BaseExprNode;
-using HalideIR::Internal::ExprNode;
-using HalideIR::Internal::StmtNode;
-using HalideIR::Internal::IRNodeType;
-using HalideIR::Internal::ForType;
-using HalideIR::DeviceAPI;
+using IntImm = tvm::IntImm;
+using Variable = tvm::Variable;
+
+/*! \brief constant unsigned integer. */
+class UIntImm : public ExprNode {
+ public:
+  /*! \brief The constant value content. */
+  uint64_t value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL static Expr make(Type t, uint64_t value);
+
+  static constexpr const char* _type_key = "UIntImm";
+  TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode);
+};
+
+/*! \brief Floating point constants. */
+class FloatImm : public ExprNode {
+ public:
+  /*! \brief The constant value content. */
+  double value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL static Expr make(Type t, double value);
+
+  static constexpr const char* _type_key = "FloatImm";
+  TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode);
+};
+
+/*! \brief String constants, only used in asserts. */
+class StringImm : public ExprNode {
+ public:
+  /*! \brief The constant value content. */
+  std::string value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL Expr static make(std::string value);
+
+  static constexpr const char* _type_key = "StringImm";
+  TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode);
+};
+
+/*!
+ * \brief Cast value from one data type to another.
+ * \note The lanes of value should keep fixed.
+ */
+class Cast : public ExprNode {
+ public:
+  /*! \brief Original data type. */
+  Expr value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL static Expr make(Type t, Expr v);
+
+  static constexpr const char* _type_key = "Cast";
+  TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode);
+};
+
+/*!
+ * \brief Base template to implement binary ops.
+ * \tparam T The type of the child class.
+ */
+template<typename T>
+class BinaryOpNode : public ExprNode {
+ public:
+  /*! \brief The left operand. */
+  Expr a;
+  /*! \brief The right operand. */
+  Expr b;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &(this->type));
+    v->Visit("a", &a);
+    v->Visit("b", &b);
+  }
+
+  static Expr make(Expr a, Expr b) {
+    CHECK(a.defined()) << "ValueError: a is undefined\n";
+    CHECK(b.defined()) << "ValueError: b is undefined\n";
+    CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
+    NodePtr<T> node = make_node<T>();
+    node->type = a.type();
+    node->a = std::move(a);
+    node->b = std::move(b);
+    return Expr(node);
+  }
+
+  TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
+};
+
+/*! \brief a + b */
+class Add : public BinaryOpNode<Add> {
+ public:
+  static constexpr const char* _type_key = "Add";
+};
+
+/*! \brief a - b */
+class Sub : public BinaryOpNode<Sub> {
+ public:
+  static constexpr const char* _type_key = "Sub";
+};
+
+/*! \brief a * b */
+class Mul : public BinaryOpNode<Mul> {
+ public:
+  static constexpr const char* _type_key = "Mul";
+};
+
+/*!
+ * \brief a / b in the C semnatics.
+ * \note For integer division, C standard uses trunc div.
+ */
+class Div : public BinaryOpNode<Div> {
+ public:
+  static constexpr const char* _type_key = "Div";
+};
+
+/*!
+ * \brief a % b in the C semnatics.
+ * \note For integer division, C standard uses trunc div.
+ */
+class Mod : public BinaryOpNode<Mod> {
+ public:
+  static constexpr const char* _type_key = "Mod";
+};
+
+/*! \brief min(a, b) */
+class Min : public BinaryOpNode<Min> {
+ public:
+  static constexpr const char* _type_key = "Min";
+};
+
+/*! \brief max(a, b) */
+class Max : public BinaryOpNode<Max> {
+ public:
+  static constexpr const char* _type_key = "Max";
+};
+
+/*!
+ * \brief Base template to implement comparison ops.
+ * \tparam T The type of the child class.
+ */
+template<typename T>
+class CmpOpNode : public ExprNode {
+ public:
+  /*! \brief The left operand. */
+  Expr a;
+  /*! \brief The right operand. */
+  Expr b;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &(this->type));
+    v->Visit("a", &a);
+    v->Visit("b", &b);
+  }
+
+  static Expr make(Expr a, Expr b) {
+    CHECK(a.defined()) << "ValueError: a is undefined\n";
+    CHECK(b.defined()) << "ValueError: b is undefined\n";
+    CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
+    NodePtr<T> node = make_node<T>();
+    node->type = Bool(a.type().lanes());
+    node->a = std::move(a);
+    node->b = std::move(b);
+    return Expr(node);
+  }
+
+  TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
+};
+
+/*! \brief a == b */
+class EQ : public CmpOpNode<EQ> {
+ public:
+  static constexpr const char* _type_key = "EQ";
+};
+
+/*! \brief a != b */
+class NE : public CmpOpNode<NE> {
+ public:
+  static constexpr const char* _type_key = "NE";
+};
+
+/*! \brief a < b */
+class LT : public CmpOpNode<LT> {
+ public:
+  static constexpr const char* _type_key = "LT";
+};
+
+/*! \brief a <= b */
+struct LE : public CmpOpNode<LE> {
+ public:
+  static constexpr const char* _type_key = "LE";
+};
+
+/*! \brief a > b */
+class GT : public CmpOpNode<GT> {
+ public:
+  static constexpr const char* _type_key = "GT";
+};
+
+/*! \brief a >= b */
+class GE : public CmpOpNode<GE> {
+ public:
+  static constexpr const char* _type_key = "GE";
+};
+
+/*! \brief a && b */
+class And : public ExprNode {
+ public:
+  /*! \brief The left operand. */
+  Expr a;
+  /*! \brief The right operand. */
+  Expr b;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &(this->type));
+    v->Visit("a", &a);
+    v->Visit("b", &b);
+  }
+
+  TVM_DLL static Expr make(Expr a, Expr b);
+
+  static constexpr const char* _type_key = "And";
+  TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode);
+};
+
+/*! \brief a || b */
+class Or : public ExprNode {
+ public:
+  /*! \brief The left operand. */
+  Expr a;
+  /*! \brief The right operand. */
+  Expr b;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("a", &a);
+    v->Visit("b", &b);
+  }
+
+  TVM_DLL static Expr make(Expr a, Expr b);
+
+  static constexpr const char* _type_key = "Or";
+  TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode);
+};
+
+/*! \brief !a */
+class Not : public ExprNode {
+ public:
+  /*! \brief The input operand. */
+  Expr a;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("a", &a);
+  }
+
+  TVM_DLL static Expr make(Expr a);
+
+  static constexpr const char* _type_key = "Not";
+  TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode);
+};
+
+/*!
+ * \brief return true_value if condition is true, otherwise return false_value.
+ * \note Both true_value and false_value could be evaluated
+ *       regardless of the condition value.
+ *       Do not use it to guard against out of bound access,
+ *       please use if_then_else instead.
+ */
+class Select : public ExprNode {
+ public:
+  /*! \brief The condition */
+  Expr condition;
+  /*! \brief value to be returned when condition is true. */
+  Expr true_value;
+  /*! \brief value to be returned when condition is false. */
+  Expr false_value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("condition", &condition);
+    v->Visit("true_value", &true_value);
+    v->Visit("false_value", &false_value);
+  }
+
+  TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
+
+  static constexpr const char* _type_key = "Select";
+  TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode);
+};
+
+/*!
+ * \brief Load the value from buffer_var.
+ *
+ *  Equivalent to ((DType*)buffer_var)[index]
+ *  where DType is the type specified by type().element_of().
+ *
+ *  For example, if type = float32x3, then the load will corresponds to
+ *
+ * \code
+ *
+ *  auto buffer = static_cast<float*>(buffer_var);
+ *  auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
+ *
+ * \endcode
+ */
+class Load : public ExprNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The index locations to be loaded. */
+  Expr index;
+  /*! \brief The predicate to mask which lanes would be loaded. */
+  Expr predicate;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("index", &index);
+    v->Visit("predicate", &predicate);
+  }
+
+  TVM_DLL static Expr make(Type type, Var buffer_var, Expr index, Expr predicate);
+
+  static constexpr const char* _type_key = "Load";
+  TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode);
+};
+
+/*!
+ * \brief Construct a vector with lanes elements
+ *        where its i-th element equals base + i * stride.
+ *  This is useful to construct a index for a continuous vector load.
+ *
+ *  Examples:
+ *  - ramp(0, 1, 3) = [0, 1, 2]
+ *  - ramp(1, 2, 4) = [1, 3, 5, 7]
+ */
+class Ramp : public ExprNode {
+ public:
+  /*! \brief The base value. */
+  Expr base;
+  /*! \brief The stride of each step. */
+  Expr stride;
+  /*! \brief Total number of lanes. */
+  int lanes;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("base", &base);
+    v->Visit("stride", &stride);
+    v->Visit("lanes", &lanes);
+  }
+
+  TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
+
+  static constexpr const char* _type_key = "Ramp";
+  TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode);
+};
+
+/*! \brief Create a vector where all the elements are value. */
+class Broadcast : public ExprNode {
+ public:
+  /*! \brief The base value. */
+  Expr value;
+  /*! \brief The numerb of lanes. */
+  int lanes;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("value", &value);
+    v->Visit("lanes", &lanes);
+  }
+
+  TVM_DLL static Expr make(Expr value, int lanes);
+
+  static constexpr const char* _type_key = "Broadcast";
+  TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode);
+};
+
+/*!
+ * \brief Let binding. Bind var to value then evaluate body.
+ */
+class Let : public ExprNode {
+ public:
+  /*! \brief The variable. */
+  Var var;
+  /*! \brief The value to be binded. */
+  Expr value;
+  /*! \brief The result expression. */
+  Expr body;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Expr make(Var var, Expr value, Expr body);
+
+  static constexpr const char* _type_key = "Let";
+  TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode);
+};
+
+// Call node, represent a function call or a multi-dimensional array load.
+//
+// TODO(tvm-team):
+// Refactor call with more explicit property registrations.
+// rather than calling a string symbol.
+// We should move most information into function itself and remove name.
+
+/*! \brief Base node of internal functions. */
+class FunctionBaseNode : public Node {
+ public:
+  /*! \return the name of the function */
+  virtual const std::string& func_name() const = 0;
+  /*! \return the number of outputs of this function */
+  virtual int num_outputs() const = 0;
+};
+
+/*! \brief reference to a function */
+class FunctionRef : public NodeRef {
+ public:
+  TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode);
+};
+
+/*!
+ * \brief Call node.
+ */
+class Call : public ExprNode {
+ public:
+  /*! \brief Possible types of calls. */
+  enum CallType : int {
+    /*! \brief Extern "C" function. */
+    Extern = 0,
+    /*! \brief Extern CXX function. */
+    ExternCPlusPlus = 1,
+    /*! \brief Extern "C" without side-effect. */
+    PureExtern = 2,
+    /*! \brief Halide-style call, evaluates func(args). */
+    Halide = 3,
+    /*! \brief Intrinsic functions. */
+    Intrinsic = 4,
+    /*! \brief Intrinsic functions that are pure. */
+    PureIntrinsic = 5
+  };
+  /*! \brief The name of the function/intrinsic. */
+  std::string name;
+  /*! \brief The arguments. */
+  Array<Expr> args;
+  /*! \brief Type of calls. */
+  CallType call_type;
+  /*! \brief The function to be called. */
+  FunctionRef func;
+  /*! \brief The output value index if func's value is a tuple. */
+  int value_index{0};
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("dtype", &type);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("call_type", &call_type);
+    v->Visit("func", &func);
+    v->Visit("value_index", &value_index);
+  }
+
+  TVM_DLL static Expr make(Type type,
+                           std::string name,
+                           Array<Expr> args,
+                           CallType call_type,
+                           FunctionRef func = FunctionRef(),
+                           int value_index = 0);
+
+  /*! \return Whether call node is pure. */
+  bool is_pure() const {
+    return (call_type == PureExtern ||
+            call_type == PureIntrinsic ||
+            call_type == Halide);
+  }
+
+  /*!
+   * \return Whether call node corresponds to a defined intrinsic.
+   * \param intrin_name The name of the intrinsic.
+   */
+  bool is_intrinsic(const char* intrin_name) const {
+    return
+        ((call_type == Intrinsic ||
+          call_type == PureIntrinsic) &&
+         name == intrin_name);
+  }
+
+  static constexpr const char* _type_key = "Call";
+  TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);
+
+  // Build-in intrinsics
+  static constexpr const char* reinterpret = "reinterpret";
+  static constexpr const char* bitwise_and = "bitwise_and";
+  static constexpr const char* bitwise_not = "bitwise_not";
+  static constexpr const char* bitwise_xor = "bitwise_xor";
+  static constexpr const char* bitwise_or = "bitwise_or";
+  static constexpr const char* shift_left = "shift_left";
+  static constexpr const char* shift_right = "shift_right";
+  static constexpr const char* popcount = "popcount";
+  static constexpr const char* likely = "likely";
+  static constexpr const char* glsl_texture_store = "glsl_texture_store";
+  static constexpr const char* prefetch = "prefetch";
+};
+
+/*!
+ * \brief Shuffle instruction.
+ *  vec = concat(vectors)
+ *  result = (vec[indices[0]], vec[indices[1]] ...)
+ */
+class Shuffle : public ExprNode {
+ public:
+  /*! \brief the input vectors. */
+  Array<Expr> vectors;
+  /*! \brief The indices of each element. */
+  Array<Expr> indices;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("vectors", &vectors);
+    v->Visit("indices", &indices);
+  }
+
+  TVM_DLL static Expr make(Array<Expr> vectors, Array<Expr> indices);
+  TVM_DLL static Expr make_concat(Array<Expr> vectors);
+  TVM_DLL static Expr make_extract_element(Expr vector, int index);
 
-// Node container for CommReducer
-struct CommReducerNode;
+  static constexpr const char* _type_key = "Shuffle";
+  TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode);
+};
 
-struct CommReducer : public NodeRef {
+// Reduce operator
+class CommReducerNode;
+
+class CommReducer : public NodeRef {
+ public:
   CommReducer() {}
   explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
   /*!
@@ -66,7 +611,8 @@ struct CommReducer : public NodeRef {
  * \brief A commutative reducer node to represent a commutative
  *  binary operator with identity element
  */
-struct CommReducerNode : public Node {
+class CommReducerNode : public Node {
+ public:
   /*! \brief The left argument of reducer */
   Array<Var> lhs;
   /*! \brief The right argument of reducer */
@@ -82,8 +628,10 @@ struct CommReducerNode : public Node {
   /*! \brief Function call operator to combine a and b */
   Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
   /*! \brief construct CommReducer from args, result and identity_element */
-  TVM_DLL static CommReducer make(Array<Var> lhs, Array<Var> rhs,
-                                 Array<Expr> result, Array<Expr> identity_element);
+  TVM_DLL static CommReducer make(Array<Var> lhs,
+                                  Array<Var> rhs,
+                                  Array<Expr> result,
+                                  Array<Expr> identity_element);
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("lhs", &lhs);
@@ -104,7 +652,8 @@ inline const CommReducerNode* CommReducer::operator->() const {
 }
 
 /*! \brief Reduction operator operator */
-struct Reduce : public ExprNode<Reduce> {
+class Reduce : public ExprNode {
+ public:
   /*! \brief The commutative combiner */
   CommReducer combiner;
   /*! \brief The source operand */
@@ -134,17 +683,483 @@ struct Reduce : public ExprNode<Reduce> {
     v->Visit("condition", &condition);
     v->Visit("value_index", &value_index);
   }
-  static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
+
   static constexpr const char* _type_key = "Reduce";
+  TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode);
 };
 
 /*! \brief Any shape. */
-struct Any : public ExprNode<Any> {
+class Any : public ExprNode {
+ public:
+  void VisitAttrs(AttrVisitor* v) final {}
+
   TVM_DLL static Expr make();
 
-  void VisitAttrs(AttrVisitor* v) final {}
-  static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
   static constexpr const char* _type_key = "Any";
+  TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode);
+};
+
+// Statements
+/*!
+ * \brief Let binding, bind var to value, then run body.
+ */
+class LetStmt : public StmtNode {
+ public:
+  /*! \brief The variable. */
+  Var var;
+  /*! \brief The value to be binded. */
+  Expr value;
+  /*! \brief The body block. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
+
+  static constexpr const char* _type_key = "LetStmt";
+  TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode);
+};
+
+/*!
+ * \brief Define certain auxiliary attribute for the body to be a symbolic value.
+ *  This provide auxiliary information for IR passes that transforms body.
+ *
+ *  In terms of effect, this is equivalent to Block(Evaluate(value), body).
+ *
+ *  Examples of possible usage:
+ *    - Bound of function, variables.
+ *    - Hint which block corresponds to a parallel region.
+ */
+class AttrStmt : public StmtNode {
+ public:
+  /*! \brief this is attribute about certain node */
+  NodeRef node;
+  /*! \brief the type key of the attribute */
+  std::string attr_key;
+  /*! \brief The attribute value, value is well defined at current scope. */
+  Expr value;
+  /*! \brief The body statement to be executed */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("node", &node);
+    v->Visit("attr_key", &attr_key);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(NodeRef node,
+                           std::string type_key,
+                           Expr value,
+                           Stmt body);
+
+  static constexpr const char* _type_key = "AttrStmt";
+  TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode);
+};
+
+/*!
+ * \brief Assert condition, if an error occurs, return the error message.
+ */
+class AssertStmt : public StmtNode {
+ public:
+  /*! \brief Condition to be checked. */
+  Expr condition;
+  /*! \brief Error message when assertion failed. */
+  Expr message;
+  /*!
+   * \brief Body which this assertion holds true.
+   *  Will be executed after the assertion.
+   */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("condition", &condition);
+    v->Visit("message", &message);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
+
+  static constexpr const char* _type_key = "AssertStmt";
+  TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode);
+};
+
+// TODO(tvm-team): consider consolidate with AttrStmt.
+/*! \brief annotation node of producer/consumer relation. */
+class ProducerConsumer : public StmtNode {
+ public:
+  /*! \brief The corresponding tensor. */
+  FunctionRef func;
+  /*! \brief Whether the relation is producer. */
+  bool is_producer;
+  /*! \brief Body to be executed. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("func", &func);
+    v->Visit("is_producer", &is_producer);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
+
+  static constexpr const char* _type_key = "ProducerConsumer";
+  TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode);
+};
+
+/*!
+ * \brief Store value to the buffer.
+ *
+ *  Equivalent to ((DType*)buffer_var)[index] = value.
+ *  where DType is the type specified by type().element_of().
+ *
+ *  For example, if type = float32x3, then the load will corresponds to
+ *
+ * \code
+ *
+ *  auto buffer = static_cast<float*>(buffer_var);
+ *  buffer[index.v0] = value.v0;
+ *  buffer[index.v1] = value.v1;
+ *  buffer[index.v2] = value.v2;
+ *
+ * \endcode
+ * \sa Load
+ */
+class Store : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The value to be stored. */
+  Expr value;
+  /*! \brief The index locations to be stored. */
+  Expr index;
+  /*! \brief The predicate to mask which lanes would be stored. */
+  Expr predicate;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("value", &value);
+    v->Visit("index", &index);
+    v->Visit("predicate", &predicate);
+  }
+
+  TVM_DLL static Stmt make(Var buffer_var,
+                           Expr value,
+                           Expr index,
+                           Expr predicate);
+
+  static constexpr const char* _type_key = "Store";
+  TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode);
+};
+
+/*!
+ * \brief Store value into mult-dimensional array defined by func.
+ */
+class Provide : public StmtNode {
+ public:
+  /*! \brief The function to be updated. */
+  FunctionRef func;
+  /*! \brief The output value index if func's value is a tuple. */
+  int value_index{0};
+  /*! \brief The value to be stored. */
+  Expr value;
+  /*! \brief The index arguments of the function. */
+  Array<Expr> args;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("func", &func);
+    v->Visit("value_index", &value_index);
+    v->Visit("value", &value);
+    v->Visit("args", &args);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func,
+                           int value_index,
+                           Expr value,
+                           Array<Expr> args);
+
+  static constexpr const char* _type_key = "Provide";
+  TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class Allocate : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The type of the buffer. */
+  DataType type;
+  /*! \brief The extents of the buffer. */
+  Array<Expr> extents;
+  /*! \brief Only allocate buffer when condition is satisfied. */
+  Expr condition;
+  /*! \brief The body to be executed. */
+  Stmt body;
+  // The following two fields are deprecated
+  // kept for backward compatibility and will be refactored later.
+  Expr new_expr;
+  std::string free_function;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("dtype", &type);
+    v->Visit("extents", &extents);
+    v->Visit("condition", &condition);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(Var buffer_var,
+                           DataType type,
+                           Array<Expr> extents,
+                           Expr condition,
+                           Stmt body,
+                           Expr new_expr = Expr(),
+                           std::string free_function = std::string());
+
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \return The result.
+   */
+  int32_t constant_allocation_size() const {
+    return constant_allocation_size(extents);
+  }
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \param extents The extents of the buffer.
+   * \return The result.
+   */
+  TVM_DLL static int32_t constant_allocation_size(
+      const Array<Expr>& extents);
+
+  static constexpr const char* _type_key = "Allocate";
+  TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode);
+};
+
+/*! \brief Free the resources in the buffer before the scope ends. */
+class Free : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("buffer_var", &buffer_var);
+  }
+
+  TVM_DLL static Stmt make(Var buffer_var);
+
+  static constexpr const char* _type_key = "Free";
+  TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode);
+};
+
+/*!
+ * \brief Annotate the bounds where func need to be written and read in body.
+ *  We will need to allocate space for the corresponding regions.
+ */
+class Realize : public StmtNode {
+ public:
+  /*! \brief The function to be realized. */
+  FunctionRef func;
+  /*! \brief The output value index if func's value is a tuple. */
+  int value_index;
+  /*! \brief The data type of the array. */
+  DataType type;
+  /*! \brief Bounds to be realized. */
+  Region bounds;
+  /*! \brief Only realize if condition holds. */
+  Expr condition;
+  /*! \brief The body of realization. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("func", &func);
+    v->Visit("value_index", &value_index);
+    v->Visit("dtype", &type);
+    v->Visit("bounds", &bounds);
+    v->Visit("condition", &condition);
+    v->Visit("body", &body);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func,
+                           int value_index,
+                           DataType type,
+                           Region bounds,
+                           Expr condition,
+                           Stmt body);
+
+  static constexpr const char* _type_key = "Realize";
+  TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode);
+};
+
+/*!
+ * \brief A sequence of statements.
+ */
+class Block : public StmtNode {
+ public:
+  /*! \brief The first statement. */
+  Stmt first;
+  /*! \brief The restof statments. */
+  Stmt rest;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("first", &first);
+    v->Visit("rest", &rest);
+  }
+
+  TVM_DLL static Stmt make(Stmt first, Stmt rest);
+  TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);
+
+  static constexpr const char* _type_key = "Block";
+  TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode);
+};
+
+/*!
+ * \brief IfThenElse statment.
+ */
+class IfThenElse : public StmtNode {
+ public:
+  /*! \brief The condition. */
+  Expr condition;
+  /*! \brief The branch to be executed when condition is true. */
+  Stmt then_case;
+  /*! \brief The branch to be executed when condition is false, can be null. */
+  Stmt else_case;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("condition", &condition);
+    v->Visit("then_case", &then_case);
+    v->Visit("else_case", &else_case);
+  }
+
+  TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
+
+  static constexpr const char* _type_key = "IfThenElse";
+  TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode);
+};
+
+/*!
+ * \brief Evaluates an expression.
+ *  This is mostly used for putting a Call node into Stmt.
+ *
+ *  If value do not have side-effect, this node can be safely removed.
+ */
+class Evaluate : public StmtNode {
+ public:
+  /*! \brief The expression to be evaluated. */
+  Expr value;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("value", &value);
+  }
+
+  TVM_DLL static Stmt make(Expr v);
+
+  static constexpr const char* _type_key = "Evaluate";
+  TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode);
+};
+
+/*! \brief Additional annotation of for loop. */
+enum class ForType : int {
+  /*! \brief serial execution. */
+  Serial = 0,
+  /*! \brief parallel execution on CPU. */
+  Parallel = 1,
+  /*! \brief Vector SIMD loop annotaion. */
+  Vectorized = 2,
+  /*! \brief Unroll annotation. */
+  Unrolled = 3
+};
+
+// Kevice api of for loop
+// kept for backward compatibility
+// consider refactor and remove later.
+enum class DeviceAPI: int {
+  None = 0
+};
+
+/*!
+ * \brief A for loop, with poissible type annotations.
+ *
+ * \code
+ *
+ *  for (loop_var = min; loop_var < min + extent; ++loop_var) {
+ *    // body
+ *  }
+ * \endcode
+ */
+class For : public StmtNode {
+ public:
+  /*! \brief The loop variable. */
+  Var loop_var;
+  /*! \brief The minimum value of iteration. */
+  Expr min;
+  /*! \brief The extent of the iteration. */
+  Expr extent;
+  /*! \brief The type of the for loop. */
+  ForType for_type;
+  /*!
+   * \brief Deprecated, reserved for backward compatibility.
+   *  Consider refactor and remove later.
+   */
+  DeviceAPI device_api;
+  /*! \brief The body of the for loop. */
+  Stmt body;
+
+  TVM_DLL static Stmt make(Var loop_var,
+                           Expr min,
+                           Expr extent,
+                           ForType for_type,
+                           DeviceAPI device_api,
+                           Stmt body);
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("loop_var", &loop_var);
+    v->Visit("min", &min);
+    v->Visit("extent", &extent);
+    v->Visit("for_type", &for_type);
+    v->Visit("device_api", &device_api);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "For";
+  TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode);
+};
+
+/*!
+ * \brief A prefetch hint of func.
+ */
+class Prefetch : public StmtNode {
+ public:
+  /*! \brief The function to be prefetched. */
+  FunctionRef func;
+  /*! \brief The output value index if func's value is a tuple. */
+  int value_index;
+  /*! \brief The data type of the array. */
+  DataType type;
+  /*! \brief Bounds to be prefetched. */
+  Region bounds;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("func", &func);
+    v->Visit("value_index", &value_index);
+    v->Visit("type", &type);
+    v->Visit("bounds", &bounds);
+  }
+
+  TVM_DLL static Stmt make(FunctionRef func,
+                           int value_index,
+                           DataType type,
+                           Region bounds);
+
+  static constexpr const char* _type_key = "Prefetch";
+  TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode);
 };
 
 /*!
@@ -517,50 +1532,6 @@ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
 
 }   // namespace intrinsic
 
-// Reuse IR node defintiion from HalideIR
-using HalideIR::Internal::IntImm;
-using HalideIR::Internal::UIntImm;
-using HalideIR::Internal::FloatImm;
-using HalideIR::Internal::StringImm;
-using HalideIR::Internal::Cast;
-using HalideIR::Internal::Add;
-using HalideIR::Internal::Sub;
-using HalideIR::Internal::Mul;
-using HalideIR::Internal::Div;
-using HalideIR::Internal::Mod;
-using HalideIR::Internal::Min;
-using HalideIR::Internal::Max;
-using HalideIR::Internal::EQ;
-using HalideIR::Internal::NE;
-using HalideIR::Internal::LT;
-using HalideIR::Internal::LE;
-using HalideIR::Internal::GT;
-using HalideIR::Internal::GE;
-using HalideIR::Internal::And;
-using HalideIR::Internal::Or;
-using HalideIR::Internal::Not;
-using HalideIR::Internal::Select;
-using HalideIR::Internal::Load;
-using HalideIR::Internal::Ramp;
-using HalideIR::Internal::Broadcast;
-using HalideIR::Internal::Call;
-using HalideIR::Internal::Let;
-using HalideIR::Internal::LetStmt;
-using HalideIR::Internal::AttrStmt;
-using HalideIR::Internal::AssertStmt;
-using HalideIR::Internal::ProducerConsumer;
-using HalideIR::Internal::For;
-using HalideIR::Internal::Store;
-using HalideIR::Internal::Provide;
-using HalideIR::Internal::Allocate;
-using HalideIR::Internal::Free;
-using HalideIR::Internal::Realize;
-using HalideIR::Internal::Prefetch;
-using HalideIR::Internal::Block;
-using HalideIR::Internal::IfThenElse;
-using HalideIR::Internal::Evaluate;
-using HalideIR::Internal::Shuffle;
-
 /*!
  * \brief Create a type annotation expression
  * \param dtype The data type
@@ -571,6 +1542,10 @@ inline Expr TypeAnnotation(Type dtype) {
                         "type_annotation", {},
                         ir::Call::PureIntrinsic);
 }
+
+// overload printing of for type.
+TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
+
 }  // namespace ir
 }  // namespace tvm
 
index 61080078c1763282db1cf0583ea7cab323e7bd41..2837a660113604319f5867077a7412079f765123 100644 (file)
@@ -25,6 +25,7 @@
 #define TVM_IR_MUTATOR_H_
 
 #include <unordered_map>
+#include <utility>
 #include "expr.h"
 #include "ir.h"
 #include "tvm/node/ir_functor.h"
index cb03f6c9dae70ef731ab89daf2af0e34d239f22c..4da93b80c2ab6ad9bbbae9ddcac7260afd325622 100644 (file)
@@ -25,7 +25,6 @@
 #ifndef TVM_LOWERED_FUNC_H_
 #define TVM_LOWERED_FUNC_H_
 
-#include <ir/FunctionBase.h>
 #include <string>
 
 #include "base.h"
@@ -42,7 +41,7 @@ class LoweredFuncNode;
  * \brief LoweredFunc represents function after lowering.
  *  This is the final IR representation before codegen.
  */
-class LoweredFunc : public FunctionRef {
+class LoweredFunc : public ir::FunctionRef {
  public:
   LoweredFunc() {}
   explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
@@ -66,7 +65,7 @@ enum LoweredFuncType : int {
 };
 
 /*! \brief Node container of LoweredFunc */
-class LoweredFuncNode : public FunctionBaseNode {
+class LoweredFuncNode : public ir::FunctionBaseNode {
  public:
   /*! \brief The name of the function */
   std::string name;
diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
new file mode 100644 (file)
index 0000000..7180890
--- /dev/null
@@ -0,0 +1,612 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/container.h
+ * \brief Array/Map container in the DSL graph.
+ */
+#ifndef TVM_NODE_CONTAINER_H_
+#define TVM_NODE_CONTAINER_H_
+
+#include <type_traits>
+#include <vector>
+#include <initializer_list>
+#include <unordered_map>
+#include <utility>
+#include <string>
+#include "node.h"
+#include "memory.h"
+
+namespace tvm {
+
+/*! \brief array node content in array */
+class ArrayNode : public Node {
+ public:
+  /*! \brief the data content */
+  std::vector<NodePtr<Node> > data;
+
+  void VisitAttrs(AttrVisitor* visitor) final {
+     // Visitor to array have no effect.
+  }
+
+  static constexpr const char* _type_key = "Array";
+  TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node);
+};
+
+/*! \brief map node content */
+class MapNode : public Node {
+ public:
+  void VisitAttrs(AttrVisitor* visitor) final {
+     // Visitor to map have no effect.
+  }
+  // hash function
+  struct Hash {
+    size_t operator()(const NodePtr<Node>& n) const {
+      return std::hash<Node*>()(n.get());
+    }
+  };
+  // comparator
+  struct Equal {
+    bool operator()(
+        const NodePtr<Node>& a,
+        const NodePtr<Node>& b) const {
+      return a.get() == b.get();
+    }
+  };
+
+  /*! \brief The corresponding conatiner type */
+  using ContainerType = std::unordered_map<
+   NodePtr<Node>,
+   NodePtr<Node>,
+   Hash, Equal>;
+
+  /*! \brief the data content */
+  ContainerType data;
+
+  static constexpr const char* _type_key = "Map";
+  TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node);
+};
+
+
+/*! \brief specialized map node with string as key */
+class StrMapNode : public Node {
+ public:
+  void VisitAttrs(AttrVisitor* visitor) final {
+     // Visitor to map have no effect.
+  }
+  /*! \brief The corresponding conatiner type */
+  using ContainerType = std::unordered_map<
+    std::string,
+    NodePtr<Node> >;
+
+  /*! \brief the data content */
+  ContainerType data;
+
+  static constexpr const char* _type_key = "StrMap";
+  TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node);
+};
+
+/*!
+ * \brief iterator adapter that adapts TIter to return another type.
+ * \tparam Converter a struct that contains converting function
+ * \tparam TIter the content iterator type.
+ */
+template<typename Converter,
+         typename TIter>
+class IterAdapter {
+ public:
+  explicit IterAdapter(TIter iter) : iter_(iter) {}
+  inline IterAdapter& operator++() {  // NOLINT(*)
+    ++iter_;
+    return *this;
+  }
+  inline IterAdapter& operator++(int) {  // NOLINT(*)
+    ++iter_;
+    return *this;
+  }
+  inline IterAdapter operator+(int offset) const {  // NOLINT(*)
+    return IterAdapter(iter_ + offset);
+  }
+  inline bool operator==(IterAdapter other) const {
+    return iter_ == other.iter_;
+  }
+  inline bool operator!=(IterAdapter other) const {
+    return !(*this == other);
+  }
+  inline const typename Converter::ResultType operator*() const {
+    return Converter::convert(*iter_);
+  }
+
+ private:
+  TIter iter_;
+};
+
+/*!
+ * \brief Array container of NodeRef in DSL graph.
+ *  Array implements copy on write semantics, which means array is mutable
+ *  but copy will happen when array is referenced in more than two places.
+ *
+ * operator[] only provide const acces, use Set to mutate the content.
+ * \tparam T The content NodeRef type.
+ */
+template<typename T,
+         typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
+class Array : public NodeRef {
+ public:
+  /*!
+   * \brief default constructor
+   */
+  Array() {
+    node_ = make_node<ArrayNode>();
+  }
+  /*!
+   * \brief move constructor
+   * \param other source
+   */
+  Array(Array<T> && other) {  // NOLINT(*)
+    node_ = std::move(other.node_);
+  }
+  /*!
+   * \brief copy constructor
+   * \param other source
+   */
+  Array(const Array<T> &other) : NodeRef(other.node_) { // NOLINT(*)
+  }
+  /*!
+   * \brief constructor from pointer
+   * \param n the container pointer
+   */
+  explicit Array(NodePtr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief constructor from iterator
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template<typename IterType>
+  Array(IterType begin, IterType end) {
+    assign(begin, end);
+  }
+  /*!
+   * \brief constructor from initializer list
+   * \param init The initalizer list
+   */
+  Array(std::initializer_list<T> init) { // NOLINT(*)
+    assign(init.begin(), init.end());
+  }
+  /*!
+   * \brief constructor from vector
+   * \param init The vector
+   */
+  Array(const std::vector<T>& init) { // NOLINT(*)
+    assign(init.begin(), init.end());
+  }
+  /*!
+   * \brief Constructs a container with n elements. Each element is a copy of val
+   * \param n The size of the container
+   * \param val The init value
+   */
+  explicit Array(size_t n, const T& val) {
+    auto tmp_node = make_node<ArrayNode>();
+    for (size_t i = 0; i < n; ++i) {
+      tmp_node->data.push_back(val.node_);
+    }
+    node_ = std::move(tmp_node);
+  }
+  /*!
+   * \brief move assign operator
+   * \param other The source of assignment
+   * \return reference to self.
+   */
+  Array<T>& operator=(Array<T> && other) {
+    node_ = std::move(other.node_);
+    return *this;
+  }
+  /*!
+   * \brief copy assign operator
+   * \param other The source of assignment
+   * \return reference to self.
+   */
+  Array<T>& operator=(const Array<T> & other) {
+    node_ = other.node_;
+    return *this;
+  }
+  /*!
+   * \brief reset the array to content from iterator.
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template<typename IterType>
+  void assign(IterType begin, IterType end) {
+    auto n = make_node<ArrayNode>();
+    for (IterType it = begin; it != end; ++it) {
+      n->data.push_back((*it).node_);
+    }
+    node_ = std::move(n);
+  }
+  /*!
+   * \brief Read i-th element from array.
+   * \param i The index
+   * \return the i-th element.
+   */
+  inline const T operator[](size_t i) const {
+    return T(static_cast<const ArrayNode*>(node_.get())->data[i]);
+  }
+  /*! \return The size of the array */
+  inline size_t size() const {
+    if (node_.get() == nullptr) return 0;
+    return static_cast<const ArrayNode*>(node_.get())->data.size();
+  }
+  /*!
+   * \brief copy on write semantics
+   *  Do nothing if current handle is the unique copy of the array.
+   *  Otherwise make a new copy of the array to ensure the current handle
+   *  hold a unique copy.
+   *
+   * \return Handle to the internal node container(which ganrantees to be unique)
+   */
+  inline ArrayNode* CopyOnWrite() {
+    if (node_.get() == nullptr || !node_.unique())  {
+      NodePtr<ArrayNode> n = make_node<ArrayNode>();
+      n->data = static_cast<ArrayNode*>(node_.get())->data;
+      NodePtr<Node>(std::move(n)).swap(node_);
+    }
+    return static_cast<ArrayNode*>(node_.get());
+  }
+  /*!
+   * \brief push a new item to the back of the list
+   * \param item The item to be pushed.
+   */
+  inline void push_back(const T& item) {
+    ArrayNode* n = this->CopyOnWrite();
+    n->data.push_back(item.node_);
+  }
+  /*!
+   * \brief set i-th element of the array.
+   * \param i The index
+   * \param value The value to be setted.
+   */
+  inline void Set(size_t i, const T& value) {
+    ArrayNode* n = this->CopyOnWrite();
+    n->data[i] = value.node_;
+  }
+  /*! \return whether array is empty */
+  inline bool empty() const {
+    return size() == 0;
+  }
+  /*! \brief specify container node */
+  using ContainerType = ArrayNode;
+
+  struct Ptr2NodeRef {
+    using ResultType = T;
+    static inline T convert(const NodePtr<Node>& n) {
+      return T(n);
+    }
+  };
+  using iterator = IterAdapter<Ptr2NodeRef,
+                               std::vector<NodePtr<Node> >::const_iterator>;
+
+  using reverse_iterator = IterAdapter<
+    Ptr2NodeRef,
+    std::vector<NodePtr<Node> >::const_reverse_iterator>;
+
+  /*! \return begin iterator */
+  inline iterator begin() const {
+    return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin());
+  }
+  /*! \return end iterator */
+  inline iterator end() const {
+    return iterator(static_cast<const ArrayNode*>(node_.get())->data.end());
+  }
+  /*! \return rbegin iterator */
+  inline reverse_iterator rbegin() const {
+    return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin());
+  }
+  /*! \return rend iterator */
+  inline reverse_iterator rend() const {
+    return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend());
+  }
+};
+
+/*!
+ * \brief Map container of NodeRef->NodeRef in DSL graph.
+ *  Map implements copy on write semantics, which means map is mutable
+ *  but copy will happen when array is referenced in more than two places.
+ *
+ * operator[] only provide const acces, use Set to mutate the content.
+ * \tparam K The key NodeRef type.
+ * \tparam V The value NodeRef type.
+ */
+template<typename K,
+         typename V,
+         typename = typename std::enable_if<
+           std::is_base_of<NodeRef, K>::value ||
+           std::is_base_of<std::string, K>::value >::type,
+         typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type>
+class Map : public NodeRef {
+ public:
+  /*!
+   * \brief default constructor
+   */
+  Map() {
+    node_ = make_node<MapNode>();
+  }
+  /*!
+   * \brief move constructor
+   * \param other source
+   */
+  Map(Map<K, V> && other) {  // NOLINT(*)
+    node_ = std::move(other.node_);
+  }
+  /*!
+   * \brief copy constructor
+   * \param other source
+   */
+  Map(const Map<K, V> &other) : NodeRef(other.node_) { // NOLINT(*)
+  }
+  /*!
+   * \brief constructor from pointer
+   * \param n the container pointer
+   */
+  explicit Map(NodePtr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief constructor from iterator
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template<typename IterType>
+  Map(IterType begin, IterType end) {
+    assign(begin, end);
+  }
+  /*!
+   * \brief constructor from initializer list
+   * \param init The initalizer list
+   */
+  Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
+    assign(init.begin(), init.end());
+  }
+  /*!
+   * \brief constructor from vector
+   * \param init The vector
+   */
+  template<typename Hash, typename Equal>
+  Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
+    assign(init.begin(), init.end());
+  }
+  /*!
+   * \brief move assign operator
+   * \param other The source of assignment
+   * \return reference to self.
+   */
+  Map<K, V>& operator=(Map<K, V> && other) {
+    node_ = std::move(other.node_);
+    return *this;
+  }
+  /*!
+   * \brief copy assign operator
+   * \param other The source of assignment
+   * \return reference to self.
+   */
+  Map<K, V>& operator=(const Map<K, V> & other) {
+    node_ = other.node_;
+    return *this;
+  }
+  /*!
+   * \brief reset the array to content from iterator.
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template<typename IterType>
+  void assign(IterType begin, IterType end) {
+    NodePtr<MapNode> n = make_node<MapNode>();
+    for (IterType i = begin; i != end; ++i) {
+      n->data.emplace(std::make_pair(i->first.node_,
+                                     i->second.node_));
+    }
+    node_ = std::move(n);
+  }
+  /*!
+   * \brief Read element from map.
+   * \param key The key
+   * \return the corresonding element.
+   */
+  inline const V operator[](const K& key) const {
+    return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
+  }
+  /*!
+   * \brief Read element from map.
+   * \param key The key
+   * \return the corresonding element.
+   */
+  inline const V at(const K& key) const {
+    return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
+  }
+  /*! \return The size of the array */
+  inline size_t size() const {
+    if (node_.get() == nullptr) return 0;
+    return static_cast<const MapNode*>(node_.get())->data.size();
+  }
+  /*! \return The number of elements of the key */
+  inline size_t count(const K& key) const {
+    if (node_.get() == nullptr) return 0;
+    return static_cast<const MapNode*>(node_.get())->data.count(key.node_);
+  }
+  /*!
+   * \brief copy on write semantics
+   *  Do nothing if current handle is the unique copy of the array.
+   *  Otherwise make a new copy of the array to ensure the current handle
+   *  hold a unique copy.
+   *
+   * \return Handle to the internal node container(which ganrantees to be unique)
+   */
+  inline MapNode* CopyOnWrite() {
+    if (node_.get() == nullptr || !node_.unique())  {
+      NodePtr<MapNode> n = make_node<MapNode>();
+      n->data = static_cast<const MapNode*>(node_.get())->data;
+      NodePtr<Node>(std::move(n)).swap(node_);
+    }
+    return static_cast<MapNode*>(node_.get());
+  }
+  /*!
+   * \brief set the Map.
+   * \param key The index key.
+   * \param value The value to be setted.
+   */
+  inline void Set(const K& key, const V& value) {
+    MapNode* n = this->CopyOnWrite();
+    n->data[key.node_] = value.node_;
+  }
+
+  /*! \return whether array is empty */
+  inline bool empty() const {
+    return size() == 0;
+  }
+  /*! \brief specify container node */
+  using ContainerType = MapNode;
+
+  struct Ptr2NodeRef {
+    using ResultType = std::pair<K, V>;
+    static inline ResultType convert(const std::pair<
+                            NodePtr<Node>,
+                            NodePtr<Node> >& n) {
+      return std::make_pair(K(n.first), V(n.second));
+    }
+  };
+
+  using iterator = IterAdapter<
+    Ptr2NodeRef, MapNode::ContainerType::const_iterator>;
+
+  /*! \return begin iterator */
+  inline iterator begin() const {
+    return iterator(static_cast<const MapNode*>(node_.get())->data.begin());
+  }
+  /*! \return end iterator */
+  inline iterator end() const {
+    return iterator(static_cast<const MapNode*>(node_.get())->data.end());
+  }
+  /*! \return begin iterator */
+  inline iterator find(const K& key) const {
+    return iterator(static_cast<const MapNode*>(node_.get())->data.find(key.node_));
+  }
+};
+
+// specialize of string map
+template<typename V, typename T1, typename T2>
+class Map<std::string, V, T1, T2> : public NodeRef {
+ public:
+  // for code reuse
+  Map() {
+    node_ = make_node<StrMapNode>();
+  }
+  Map(Map<std::string, V> && other) {  // NOLINT(*)
+    node_ = std::move(other.node_);
+  }
+  Map(const Map<std::string, V> &other) : NodeRef(other.node_) { // NOLINT(*)
+  }
+  explicit Map(NodePtr<Node> n) : NodeRef(n) {}
+  template<typename IterType>
+  Map(IterType begin, IterType end) {
+    assign(begin, end);
+  }
+  Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
+    assign(init.begin(), init.end());
+  }
+
+  template<typename Hash, typename Equal>
+  Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
+    assign(init.begin(), init.end());
+  }
+  Map<std::string, V>& operator=(Map<std::string, V> && other) {
+    node_ = std::move(other.node_);
+    return *this;
+  }
+  Map<std::string, V>& operator=(const Map<std::string, V> & other) {
+    node_ = other.node_;
+    return *this;
+  }
+  template<typename IterType>
+  void assign(IterType begin, IterType end) {
+    auto n = make_node<StrMapNode>();
+    for (IterType i = begin; i != end; ++i) {
+      n->data.emplace(std::make_pair(i->first,
+                                     i->second.node_));
+    }
+    node_ = std::move(n);
+  }
+  inline const V operator[](const std::string& key) const {
+    return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
+  }
+  inline const V at(const std::string& key) const {
+    return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
+  }
+  inline size_t size() const {
+    if (node_.get() == nullptr) return 0;
+    return static_cast<const StrMapNode*>(node_.get())->data.size();
+  }
+  inline size_t count(const std::string& key) const {
+    if (node_.get() == nullptr) return 0;
+    return static_cast<const StrMapNode*>(node_.get())->data.count(key);
+  }
+  inline StrMapNode* CopyOnWrite() {
+    if (node_.get() == nullptr || !node_.unique())  {
+      NodePtr<StrMapNode> n = make_node<StrMapNode>();
+      n->data = static_cast<const StrMapNode*>(node_.get())->data;
+      NodePtr<Node>(std::move(n)).swap(node_);
+    }
+    return static_cast<StrMapNode*>(node_.get());
+  }
+  inline void Set(const std::string& key, const V& value) {
+    StrMapNode* n = this->CopyOnWrite();
+    n->data[key] = value.node_;
+  }
+  inline bool empty() const {
+    return size() == 0;
+  }
+  using ContainerType = StrMapNode;
+
+  struct Ptr2NodeRef {
+    using ResultType = std::pair<std::string, V>;
+    static inline ResultType convert(const std::pair<
+                            std::string,
+                            NodePtr<Node> >& n) {
+      return std::make_pair(n.first, V(n.second));
+    }
+  };
+
+  using iterator = IterAdapter<
+    Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>;
+
+  /*! \return begin iterator */
+  inline iterator begin() const {
+    return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin());
+  }
+  /*! \return end iterator */
+  inline iterator end() const {
+    return iterator(static_cast<const StrMapNode*>(node_.get())->data.end());
+  }
+  /*! \return begin iterator */
+  inline iterator find(const std::string& key) const {
+    return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key));
+  }
+};
+
+}  // namespace tvm
+#endif  // TVM_NODE_CONTAINER_H_
diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h
new file mode 100644 (file)
index 0000000..23c5a3f
--- /dev/null
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/ir_functor.h
+ * \brief Defines the IRFunctor data structures.
+ */
+#ifndef TVM_NODE_IR_FUNCTOR_H_
+#define TVM_NODE_IR_FUNCTOR_H_
+
+#include <dmlc/logging.h>
+#include <string>
+#include <vector>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <functional>
+#include "node.h"
+
+namespace tvm {
+/*!
+ * \brief A dynamically dispatched functor on NodeRef in the first argument.
+ *
+ * \code
+ *   IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
+ *   tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
+ *     return prefix + "Add";
+ *   });
+ *   tostr.set_dispatch<IntImm>([](const IntImm* op) {
+ *     return prefix + "IntImm"
+ *   });
+ *
+ *   Expr x = make_const(1);
+ *   Expr y = x + x;
+ *   // dispatch to IntImm, outputs "MyIntImm"
+ *   LOG(INFO) << tostr(x, "My");
+ *   // dispatch to IntImm, outputs "MyAdd"
+ *   LOG(INFO) << tostr(y, "My");
+ * \endcode
+ *
+ * \tparam FType function signiture
+ *  This type if only defined for FType with function signature
+ */
+template<typename FType>
+class IRFunctor;
+
+template<typename R, typename ...Args>
+class IRFunctor<R(const NodeRef& n, Args...)> {
+ private:
+  using Function = std::function<R (const NodeRef&n, Args...)>;
+  using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
+  /*! \brief internal function table */
+  std::vector<Function> func_;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*!
+   * \brief Whether the functor can dispatch the corresponding Node
+   * \param n The node to be dispatched
+   * \return Whether dispatching function is registered for n's type.
+   */
+  inline bool can_dispatch(const NodeRef& n) const {
+    uint32_t type_index = n.type_index();
+    return type_index < func_.size() && func_[type_index] != nullptr;
+  }
+  /*!
+   * \brief invoke the functor , dispatch on type of n
+   * \param n The Node argument
+   * \param args The additional arguments
+   * \return The result.
+   */
+  inline R operator()(const NodeRef& n, Args... args) const {
+    uint32_t type_index = n.type_index();
+    CHECK(type_index < func_.size() &&
+          func_[type_index] != nullptr)
+        << "IRFunctor calls un-registered function on type "
+        << Node::TypeIndex2Key(type_index);
+    return func_[type_index](n, std::forward<Args>(args)...);
+  }
+  /*!
+   * \brief set the dispacher for type TNode
+   * \param f The function to be set.
+   * \tparam TNode the type of Node to be dispatched.
+   * \return reference to self.
+   */
+  template<typename TNode>
+  inline TSelf& set_dispatch(Function f) {  // NOLINT(*)
+    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
+    if (func_.size() <= tindex) {
+      func_.resize(tindex + 1, nullptr);
+    }
+    CHECK(func_[tindex] == nullptr)
+        << "Dispatch for " << Node::TypeIndex2Key(tindex)
+        << " is already set";
+    func_[tindex] = f;
+    return *this;
+  }
+  /*!
+   * \brief set the dispacher for type TNode
+   *  This allows f to used detailed const Node pointer to replace NodeRef
+   *
+   * \param f The function to be set.
+   * \tparam TNode the type of Node to be dispatched.
+   * \return reference to self.
+   */
+  template<typename TNode>
+  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
+    Function fun = [f](const NodeRef& n, Args... args) {
+      return f(static_cast<const TNode*>(n.node_.get()),
+               std::forward<Args>(args)...);
+    };
+    return this->set_dispatch<TNode>(fun);
+  }
+  /*!
+  * \brief unset the dispacher for type TNode
+  *
+  * \tparam TNode the type of Node to be dispatched.
+  * \return reference to self.
+  */
+  template<typename TNode>
+  inline TSelf& clear_dispatch() {  // NOLINT(*)
+    uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
+    CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
+    func_[tindex] = nullptr;
+    return *this;
+  }
+};
+
+#if defined(__GNUC__)
+#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
+#else
+#define TVM_ATTRIBUTE_UNUSED
+#endif
+
+/*! \brief helper macro to generate string concat */
+#define TVM_STR_CONCAT_(__x, __y) __x##__y
+#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
+
+#define TVM_REGISTER_VAR_DEF(ClsName)                                 \
+  static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
+
+/*!
+ * \brief Useful macro to set IRFunctor dispatch in a global static field.
+ *
+ * \code
+ *  // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
+ *  // vtable allows easy patch in of new Node types, without changing
+ *  // interface of IRPrinter.
+ *
+ *  class IRPrinter {
+ *   public:
+ *    std::ostream& stream;
+ *    // the dispatch function.
+ *    void print(Expr e) {
+ *      const static FType& f = *vtable();
+ *      f(e, this);
+ *    }
+ *
+ *    using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
+ *    // function to return global function table
+ *    static FType& vtable();
+ *  };
+ *
+ *  // in cpp/cc file
+ *  IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
+ *    static FType inst; return inst;
+ *  }
+ *
+ *  TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+ *  .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
+ *    p->print(n->a);
+ *    p->stream << '+'
+ *    p->print(n->b);
+ *  });
+ *
+ *
+ * \endcode
+ *
+ * \param ClsName The name of the class
+ * \param FField The static function that returns a singleton of IRFunctor.
+ */
+#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                       \
+  TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__)  =      \
+                              ClsName::FField()
+
+ /*!
+ * \brief A container for a list of callbacks. All callbacks are invoked when
+ * the object is destructed.
+ */
+class IRFunctorCleanList {
+ public:
+  ~IRFunctorCleanList() {
+    for (auto &f : clean_items) {
+      f();
+    }
+  }
+
+  void append(std::function<void()> func) {
+    clean_items.push_back(func);
+  }
+
+ private:
+  std::vector< std::function<void()> > clean_items;
+};
+
+/*!
+* \brief A wrapper around IRFunctor that will record calls to set_dispatch
+* and make a corresponding call to clear_dispatch when the last copy of
+* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
+* this can be used by NNVM and other libraries to unregister callbacks when
+* the library is unloaded. This prevents crashes when the underlying IRFunctor
+* is destructed as it will no longer contain std::function instances allocated
+* by a library that has been unloaded.
+*/
+template<typename FType>
+class IRFunctorStaticRegistry;
+
+template<typename R, typename ...Args>
+class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
+ private:
+  IRFunctor<R(const NodeRef& n, Args...)> *irf_;
+  std::shared_ptr<IRFunctorCleanList> free_list;
+
+  using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
+
+ public:
+  IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
+    irf_ = irf;
+    free_list = std::make_shared<IRFunctorCleanList>();
+  }
+
+  template<typename TNode>
+  inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) {  // NOLINT(*)
+    irf_->template set_dispatch<TNode>(f);
+    auto irf_copy = irf_;
+    free_list.get()->append([irf_copy] {
+      irf_copy->template clear_dispatch<TNode>();
+      });
+    return *this;
+  }
+};
+
+/*!
+* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
+* the compiler to deduce the template types.
+*/
+template<typename R, typename ...Args>
+IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
+  IRFunctor<R(const NodeRef& n, Args...)> *irf) {
+  return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
+}
+
+#define TVM_AUTO_REGISTER_VAR_DEF(ClsName)                           \
+  static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
+
+/*!
+* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
+* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
+* TVM_STATIC_IR_FUNCTOR.
+*/
+#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField)                  \
+  TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__)  = \
+                        MakeIRFunctorStaticRegistry(&ClsName::FField())
+
+}  // namespace tvm
+#endif  // TVM_NODE_IR_FUNCTOR_H_
diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h
new file mode 100644 (file)
index 0000000..1bba571
--- /dev/null
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/memory.h
+ * \brief Node memory management.
+ */
+#ifndef TVM_NODE_MEMORY_H_
+#define TVM_NODE_MEMORY_H_
+
+#include <utility>
+#include "node.h"
+
+namespace tvm {
+/*!
+ * \brief Allocate a node object.
+ * \param args arguments to the constructor.
+ * \tparam T the node type.
+ * \return The NodePtr to the allocated object.
+ */
+template<typename T, typename... Args>
+inline NodePtr<T> make_node(Args&&... args);
+
+// Detail implementations after this
+//
+// The current design allows swapping the
+// allocator pattern when necessary.
+//
+// Possible future allocator optimizations:
+// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
+// - Thread-local object pools: one pool per size and alignment requirement.
+// - Can specialize by type of object to give the specific allocator to each object.
+//
+template<typename T>
+class SimpleNodeAllocator {
+ public:
+  template<typename... Args>
+  static T* New(Args&&... args) {
+    return new T(std::forward<Args>(args)...);
+  }
+  static NodeBase::FDeleter Deleter() {
+    return Deleter_;
+  }
+
+ private:
+  static void Deleter_(NodeBase* ptr) {
+    delete static_cast<T*>(ptr);
+  }
+};
+
+template<typename T, typename... Args>
+inline NodePtr<T> make_node(Args&&... args) {
+  using Allocator = SimpleNodeAllocator<T>;
+  static_assert(std::is_base_of<NodeBase, T>::value,
+                "make_node can only be used to create NodeBase");
+  T* node = Allocator::New(std::forward<Args>(args)...);
+  node->deleter_ = Allocator::Deleter();
+  return NodePtr<T>(node);
+}
+
+}  // namespace tvm
+#endif  // TVM_NODE_MEMORY_H_
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
new file mode 100644 (file)
index 0000000..79187b8
--- /dev/null
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/node.h
+ * \brief Node system data structure.
+ */
+#ifndef TVM_NODE_NODE_H_
+#define TVM_NODE_NODE_H_
+
+#include <dmlc/logging.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/node_base.h>
+#include <string>
+#include <vector>
+#include <utility>
+#include <type_traits>
+
+
+namespace tvm {
+// forward declaration
+class DataType;
+class Node;
+class NodeRef;
+
+namespace runtime {
+// forward declaration
+class NDArray;
+// forward declaration
+class Object;
+}  // namespace runtime
+
+/*!
+ * \brief Visitor class to each node content.
+ *  The content is going to be called for each field.
+ */
+class TVM_DLL AttrVisitor {
+ public:
+//! \cond Doxygen_Suppress
+  virtual ~AttrVisitor() = default;
+  virtual void Visit(const char* key, double* value) = 0;
+  virtual void Visit(const char* key, int64_t* value) = 0;
+  virtual void Visit(const char* key, uint64_t* value) = 0;
+  virtual void Visit(const char* key, int* value) = 0;
+  virtual void Visit(const char* key, bool* value) = 0;
+  virtual void Visit(const char* key, std::string* value) = 0;
+  virtual void Visit(const char* key, void** value) = 0;
+  virtual void Visit(const char* key, DataType* value) = 0;
+  virtual void Visit(const char* key, NodeRef* value) = 0;
+  virtual void Visit(const char* key, runtime::NDArray* value) = 0;
+  virtual void Visit(const char* key, runtime::Object* value) = 0;
+  template<typename ENum,
+           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+  void Visit(const char* key, ENum* ptr) {
+    static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
+                  "declare enum to be enum int to use visitor");
+    this->Visit(key, reinterpret_cast<int*>(ptr));
+  }
+//! \endcond
+};
+
+/*!
+ * \brief base class of node container in DSL AST.
+ */
+class TVM_DLL Node : public NodeBase {
+ public:
+  /*! \brief virtual destructor */
+  virtual ~Node() {}
+  /*! \return The unique type key of the node */
+  virtual const char* type_key() const = 0;
+  /*!
+   * \brief Apply visitor to each field of the Node
+   *  Visitor could mutate the content of the node.
+   *  override if Node contains attribute fields.
+   * \param visitor The visitor
+   */
+  virtual void VisitAttrs(AttrVisitor* visitor) {}
+  /*! \return the type index of the node */
+  virtual const uint32_t type_index() const = 0;
+  /*!
+   * \brief Whether this node derives from node with type_index=tid.
+   *  Implemented by TVM_DECLARE_NODE_TYPE_INFO
+   *
+   * \param tid The type index.
+   * \return the check result.
+   */
+  virtual const bool _DerivedFrom(uint32_t tid) const;
+  /*!
+   * \brief get a runtime unique type index given a type key
+   * \param type_key Type key of a type.
+   * \return the corresponding type index.
+   */
+  static uint32_t TypeKey2Index(const char* type_key);
+  /*!
+   * \brief get type key from type index.
+   * \param index The type index
+   * \return the corresponding type key.
+   */
+  static const char* TypeIndex2Key(uint32_t index);
+  /*!
+   * \return whether the type is derived from
+   */
+  template<typename T>
+  inline bool derived_from() const;
+  /*!
+   * \return whether the node is of type T
+   * \tparam The type to be checked.
+   */
+  template<typename T>
+  inline bool is_type() const;
+  /*!
+   * \brief Get a NodePtr that holds reference to this Node.
+   * \return the NodePtr
+   */
+  inline NodePtr<Node> GetNodePtr() const;
+  // node ref can see this
+  friend class NodeRef;
+  static constexpr const char* _type_key = "Node";
+};
+
+/*! \brief Base class of all node reference object */
+class NodeRef {
+ public:
+  /*! \brief type indicate the container type */
+  using ContainerType = Node;
+  /*!
+   * \brief Comparator
+   * \param other Another node ref.
+   * \return the compare result.
+   */
+  inline bool operator==(const NodeRef& other) const;
+  /*!
+   * \brief Comparator
+   * \param other Another node ref.
+   * \return the compare result.
+   */
+  inline bool same_as(const NodeRef& other) const;
+  /*!
+   * \brief Comparator
+   * \param other Another node ref.
+   * \return the compare result.
+   */
+  inline bool operator<(const NodeRef& other) const;
+  /*!
+   * \brief Comparator
+   * \param other Another node ref.
+   * \return the compare result.
+   */
+  inline bool operator!=(const NodeRef& other) const;
+  /*! \return the hash function for NodeRef */
+  inline size_t hash() const;
+  /*! \return whether the expression is null */
+  inline bool defined() const;
+  /*! \return the internal type index of IRNode */
+  inline uint32_t type_index() const;
+  /*! \return the internal node pointer */
+  inline const Node* get() const;
+  /*! \return the internal node pointer */
+  inline const Node* operator->() const;
+  /*!
+   * \brief Downcast this ir node to its actual type (e.g. Add, or
+   * Select). This returns nullptr if the node is not of the requested
+   * type. Example usage:
+   *
+   * if (const Add *add = node->as<Add>()) {
+   *   // This is an add node
+   * }
+   * \tparam T the target type, must be subtype of IRNode
+   */
+  template<typename T>
+  inline const T *as() const;
+  /*!
+   * \brief A more powerful version of as that also works with
+   *  intermediate base types.
+   * \tparam T the target type, must be subtype of IRNode
+   */
+  template<typename T>
+  inline const T *as_derived() const;
+  /*! \brief default constructor */
+  NodeRef() = default;
+  explicit NodeRef(NodePtr<Node> node) : node_(node) {}
+  /*! \brief the internal node object, do not touch  */
+  NodePtr<Node> node_;
+};
+
+/*!
+ * \brief Get a reference type from a Node ptr type
+ *
+ *  It is always important to get a reference type
+ *  if we want to return a value as reference or keep
+ *  the node alive beyond the scope of the function.
+ *
+ * \param ptr The node pointer
+ * \tparam RefType The reference type
+ * \tparam NodeType The node type
+ * \return The corresponding RefType
+ */
+template <typename RefType, typename NodeType>
+inline RefType GetRef(const NodeType* ptr);
+
+/*!
+ * \brief Downcast a base reference type to a more specific type.
+ *
+ * \param ref The inptut reference
+ * \return The corresponding SubRef.
+ * \tparam SubRef The target specific reference type.
+ * \tparam BaseRef the current reference type.
+ */
+template <typename SubRef, typename BaseRef>
+inline SubRef Downcast(BaseRef ref);
+
+/*!
+ * \brief helper macro to declare type information in a base node.
+ */
+#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent)                    \
+  const bool _DerivedFrom(uint32_t tid) const override {                \
+    static uint32_t tidx = TypeKey2Index(TypeName::_type_key);          \
+    if (tidx == tid) return true;                                       \
+    return Parent::_DerivedFrom(tid);                                   \
+  }
+
+/*!
+ * \brief helper macro to declare type information in a terminal node
+ */
+#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent)                    \
+  const char* type_key() const final {                                  \
+    return TypeName::_type_key;                                         \
+  }                                                                     \
+  const uint32_t type_index() const final {                             \
+    static uint32_t tidx = TypeKey2Index(TypeName::_type_key);          \
+    return tidx;                                                        \
+  }                                                                     \
+  const bool _DerivedFrom(uint32_t tid) const final {                   \
+    static uint32_t tidx = TypeKey2Index(TypeName::_type_key);          \
+    if (tidx == tid) return true;                                       \
+    return Parent::_DerivedFrom(tid);                                   \
+  }
+
+// implementations of inline functions after this
+template<typename T>
+inline bool Node::derived_from() const {
+  // use static field so query only happens once.
+  static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
+  return this->_DerivedFrom(type_id);
+}
+
+
+template<typename T>
+inline bool Node::is_type() const {
+  // use static field so query only happens once.
+  static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
+  return type_id == this->type_index();
+}
+
+
+inline NodePtr<Node> Node::GetNodePtr() const {
+  return NodePtr<Node>(const_cast<Node*>(this));
+}
+
+template <typename RefType, typename NodeType>
+inline RefType GetRef(const NodeType* ptr) {
+  static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
+                "Can only cast to the ref of same container type");
+  return RefType(ptr->GetNodePtr());
+}
+
+template <typename SubRef, typename BaseRef>
+inline SubRef Downcast(BaseRef ref) {
+  CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
+        ref->template derived_from<typename SubRef::ContainerType>())
+      << "Downcast from " << ref->type_key() << " to "
+      << SubRef::ContainerType::_type_key << " failed.";
+  return SubRef(std::move(ref.node_));
+}
+
+inline const Node* NodeRef::get() const {
+  return node_.get();
+}
+
+inline const Node* NodeRef::operator->() const {
+  return node_.get();
+}
+
+inline bool NodeRef::defined() const {
+  return node_.get() != nullptr;
+}
+
+inline bool NodeRef::operator==(const NodeRef& other) const {
+  return node_.get() == other.node_.get();
+}
+
+inline bool NodeRef::same_as(const NodeRef& other) const {
+  return node_.get() == other.node_.get();
+}
+
+inline bool NodeRef::operator<(const NodeRef& other) const {
+  return node_.get() < other.node_.get();
+}
+
+inline bool NodeRef::operator!=(const NodeRef& other) const {
+  return node_.get() != other.node_.get();
+}
+
+inline size_t NodeRef::hash() const {
+  return std::hash<Node*>()(node_.get());
+}
+
+inline uint32_t NodeRef::type_index() const {
+  CHECK(node_.get() != nullptr)
+      << "null type";
+  return get()->type_index();
+}
+
+template<typename T>
+inline const T* NodeRef::as() const {
+  const Node* ptr = static_cast<const Node*>(get());
+  if (ptr && ptr->is_type<T>()) {
+    return static_cast<const T*>(ptr);
+  }
+  return nullptr;
+}
+
+template<typename T>
+inline const T* NodeRef::as_derived() const {
+  const Node* ptr = static_cast<const Node*>(get());
+  if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
+    return static_cast<const T*>(ptr);
+  }
+  return nullptr;
+}
+
+/*! \brief The hash function for nodes */
+struct NodeHash {
+  size_t operator()(const NodeRef& a) const {
+    return a.hash();
+  }
+};
+
+/*! \brief The equal comparator for nodes */
+struct NodeEqual {
+  bool operator()(const NodeRef& a, const NodeRef& b) const {
+    return a.get() == b.get();
+  }
+};
+}  // namespace tvm
+#endif  // TVM_NODE_NODE_H_
index 38dc39bbe7a74cb42728c584405d0839d89c65a8..2602b383aab1332a77dff26c38a6f27e954e8953 100644 (file)
@@ -53,7 +53,7 @@ struct TensorDom {
 /*!
  * \brief Base class of all operation nodes
  */
-class OperationNode : public FunctionBaseNode {
+class OperationNode : public ir::FunctionBaseNode {
  public:
   /*! \brief optional name of the operation */
   std::string name;
@@ -463,7 +463,7 @@ class ExternOpNode : public OperationNode {
     v->Visit("output_placeholders", &output_placeholders);
     v->Visit("body", &body);
   }
-  EXPORT static Operation make(std::string name,
+  TVM_DLL static Operation make(std::string name,
                                std::string tag,
                                Map<std::string, NodeRef> attrs,
                                Array<Tensor> inputs,
@@ -530,12 +530,12 @@ class HybridOpNode : public OperationNode {
     v->Visit("axis", &axis);
     v->Visit("body", &body);
   }
-  EXPORT static Operation make(std::string name,
-                               std::string tag,
-                               Map<std::string, NodeRef> attrs,
-                               Array<Tensor> inputs,
-                               Array<Tensor> outputs,
-                               Stmt body);
+  TVM_DLL static Operation make(std::string name,
+                                std::string tag,
+                                Map<std::string, NodeRef> attrs,
+                                Array<Tensor> inputs,
+                                Array<Tensor> outputs,
+                                Stmt body);
 
   static constexpr const char* _type_key = "HybridOp";
   TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
index 8bbde878741dbb28c08437f35c15c055c56cc12d..5951594b873c70de3ed97563fa4050a2c74685d4 100644 (file)
@@ -70,7 +70,9 @@ struct NodeTypeChecker<Array<T> > {
     if (!sptr->is_type<ArrayNode>()) return false;
     ArrayNode* n = static_cast<ArrayNode*>(sptr);
     for (const auto& p : n->data) {
-      if (!NodeTypeChecker<T>::Check(p.get())) return false;
+      if (!NodeTypeChecker<T>::Check(p.get())) {
+        return false;
+      }
     }
     return true;
   }
@@ -144,7 +146,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
   return TNodeRef(sptr);
 }
 
-inline TVMArgValue::operator HalideIR::Expr() const {
+inline TVMArgValue::operator tvm::Expr() const {
   if (type_code_ == kNull) return Expr();
   if (type_code_ == kDLInt) {
     CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
@@ -240,21 +242,21 @@ inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const {  /
 }
 
 // type related stuffs
-inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
-  return this->operator=(Type2TVMType(t));
+inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
+  return this->operator=(t.operator DLDataType());
 }
 
-inline TVMRetValue::operator HalideIR::Type() const {
-  return TVMType2Type(operator TVMType());
+inline TVMRetValue::operator tvm::DataType() const {
+  return DataType(operator DLDataType());
 }
 
-inline TVMArgValue::operator HalideIR::Type() const {
-  return TVMType2Type(operator TVMType());
+inline TVMArgValue::operator tvm::DataType() const {
+  return DataType(operator DLDataType());
 }
 
 inline void TVMArgsSetter::operator()(
-    size_t i, const HalideIR::Type& t) const {
-  this->operator()(i, Type2TVMType(t));
+    size_t i, const DataType& t) const {
+  this->operator()(i, t.operator DLDataType());
 }
 }  // namespace runtime
 }  // namespace tvm
index 17fd626ee51d0ff8990bf8f3507375325e253c35..f06b2583127af45cf0c0a844dda8e578ef8d5c2e 100644 (file)
 #include "object.h"
 #include "node_base.h"
 
-namespace HalideIR {
-// Forward declare type for extensions
-// The header works fine without depending on this.
-struct Type;
-struct Expr;
-}
-
-
 // Whether use TVM runtime in header only mode.
 #ifndef TVM_RUNTIME_HEADER_ONLY
 #define TVM_RUNTIME_HEADER_ONLY 0
@@ -58,6 +50,8 @@ struct Expr;
 namespace tvm {
 // forward declarations
 class Integer;
+class DataType;
+class Expr;
 
 namespace runtime {
 
@@ -626,8 +620,8 @@ class TVMArgValue : public TVMPODValue_ {
            typename = typename std::enable_if<
              std::is_class<TNodeRef>::value>::type>
   inline bool IsNodeType() const;
-  inline operator HalideIR::Type() const;
-  inline operator HalideIR::Expr() const;
+  inline operator tvm::DataType() const;
+  inline operator tvm::Expr() const;
   inline operator tvm::Integer() const;
   // get internal node ptr, if it is node
   inline NodePtr<Node>& node_sptr();
@@ -835,8 +829,8 @@ class TVMRetValue : public TVMPODValue_ {
   inline TVMRetValue& operator=(const NodeRef& other);
   inline TVMRetValue& operator=(const NodePtr<Node>& other);
   // type related
-  inline operator HalideIR::Type() const;
-  inline TVMRetValue& operator=(const HalideIR::Type& other);
+  inline operator tvm::DataType() const;
+  inline TVMRetValue& operator=(const tvm::DataType& other);
 
  private:
   template<typename T>
@@ -1184,7 +1178,7 @@ class TVMArgsSetter {
   inline void operator()(size_t i, const T& value) const;
   // NodeRef related extenstions: in tvm/packed_func_ext.h
   inline void operator()(size_t i, const NodeRef& other) const;  // NOLINT(*)
-  inline void operator()(size_t i, const HalideIR::Type& t) const;
+  inline void operator()(size_t i, const tvm::DataType& t) const;
 
  private:
   /*! \brief The values fields */
index 659b42aa1afaf5253843197694ba08b8848e67ed..ac37f017436e71a181363e1c82b58680c658eff0 100644 (file)
@@ -75,24 +75,24 @@ class Stage : public NodeRef {
    * \brief set the memory scope of the stage
    * \param scope The memory scope.
    */
-  EXPORT Stage& set_scope(std::string scope);  // NOLINT(*)
+  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
   /*!
    * \brief specify the schedule to be computed at the parent schedule's scope.
    * \param parent The parent schedule.
    * \param scope The iteration point to carry the schedule.
    * \return reference to self.
    */
-  EXPORT Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
+  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
   /*!
    * \brief Compute the function inline.
    * \return reference to self.
    */
-  EXPORT Stage& compute_inline();   // NOLINT(*)
+  TVM_DLL Stage& compute_inline();   // NOLINT(*)
   /*!
    * \brief Compute the function at group root.
    * \return reference to self.
    */
-  EXPORT Stage& compute_root();  // NOLINT(*)
+  TVM_DLL Stage& compute_root();  // NOLINT(*)
   /*!
    * \brief Bind the IterVar to thread index.
    *
@@ -100,7 +100,7 @@ class Stage : public NodeRef {
    * \param thread_ivar The thread axis to be bound.
    * \return reference to self.
    */
-  EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar);
+  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
   /*!
    * \brief Set the predicate to determine whether a store to the array should be performed.
    *  Use this when there are multiple threads performing the same store and we only
@@ -111,7 +111,7 @@ class Stage : public NodeRef {
    * \param predicate The condition to be checked.
    * \return reference to self.
    */
-  EXPORT Stage& set_store_predicate(Expr predicate);
+  TVM_DLL Stage& set_store_predicate(Expr predicate);
   /*!
    * \brief Specify environment threads that launched around the group's scope.
    *  This can only be used in group stage.
@@ -120,7 +120,7 @@ class Stage : public NodeRef {
    *    This is a beta feature.
    * \return reference to self.
    */
-  EXPORT Stage& env_threads(Array<IterVar> threads);
+  TVM_DLL Stage& env_threads(Array<IterVar> threads);
   /*!
    * \brief Split the parent by factor, generate
    * \param parent The parent iteration domain.
@@ -129,7 +129,7 @@ class Stage : public NodeRef {
    * \param p_inner The result inner domain.
    * \return reference to self.
    */
-  EXPORT Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
+  TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
   /*!
    * \brief Split the iteration with given number of parts.
    *
@@ -139,7 +139,7 @@ class Stage : public NodeRef {
    * \param p_inner The result inner domain.
    * \return reference to self.
    */
-  EXPORT Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
+  TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
   /*!
    * \brief Fuse the inner outer domain to the target
    * \param outer The outer domain to be fused.
@@ -147,7 +147,7 @@ class Stage : public NodeRef {
    * \param p_target The result target domain.
    * \return reference to self.
    */
-  EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
+  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
   /*!
    * \brief Fuse all the axes together into a single axis.
    *
@@ -161,13 +161,13 @@ class Stage : public NodeRef {
    *
    * \return reference to self.
    */
-  EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
+  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
   /*!
    * \brief Reorder the iteration
    * \param order The order of iteration variable.
    * \return reference to self.
    */
-  EXPORT Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
+  TVM_DLL Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
   /*!
    * \brief Perform tiling on two dimensions
    *  The final loop order from outmost to inner most are
@@ -183,7 +183,7 @@ class Stage : public NodeRef {
    * \param p_y_inner Inner axis of y dimension
    * \return reference to self.
    */
-  EXPORT Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
+  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
                      Expr x_factor, Expr y_factor,
                      IterVar* p_x_outer, IterVar* p_y_outer,
                      IterVar* p_x_inner, IterVar* p_y_inner);
@@ -192,7 +192,7 @@ class Stage : public NodeRef {
    * \param var The axis to be vectorized.
    * \return reference to self.
    */
-  EXPORT Stage& vectorize(IterVar var);   // NOLINT(*)
+  TVM_DLL Stage& vectorize(IterVar var);   // NOLINT(*)
   /*!
    * \brief Replace computation of the current stage by tensor intrinsic f.
    * \param var The axis marks beginning of tensorization.
@@ -200,19 +200,19 @@ class Stage : public NodeRef {
    * \param f The Tensor compute intrinsics.
    * \return reference to self.
    */
-  EXPORT Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*)
+  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*)
   /*!
    * \brief Unroll iteration.
    * \param var The axis to be unrolled.
    * \return reference to self.
    */
-  EXPORT Stage& unroll(IterVar var);   // NOLINT(*)
+  TVM_DLL Stage& unroll(IterVar var);   // NOLINT(*)
   /*!
    * \brief Parallelize iteration.
    * \param var The axis to be parallelized.
    * \return reference to self.
    */
-  EXPORT Stage& parallel(IterVar var);   // NOLINT(*)
+  TVM_DLL Stage& parallel(IterVar var);   // NOLINT(*)
   /*!
    * \brief Annotate the iteration with pragma
    *
@@ -222,7 +222,7 @@ class Stage : public NodeRef {
    *
    * \return reference to self.
    */
-  EXPORT Stage& pragma(IterVar var,
+  TVM_DLL Stage& pragma(IterVar var,
                        const std::string& pragma_type,
                        const Expr& pragma_value = Expr());   // NOLINT(*)
   /*!
@@ -232,7 +232,7 @@ class Stage : public NodeRef {
    * \param offset the number of iterations be to fetched in advance
    * \return reference to self
    */
-  EXPORT Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
+  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
   /*!
    * \brief Set alignment requirement for specific dimension.
    *
@@ -243,12 +243,12 @@ class Stage : public NodeRef {
    * \param offset The required offset factor.
    * \return reference to self
    */
-  EXPORT Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
+  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
   /*!
    * \brief Compute current stage with double buffering.
    * \return reference to self.
    */
-  EXPORT Stage& double_buffer();   // NOLINT(*)
+  TVM_DLL Stage& double_buffer();   // NOLINT(*)
   /*!
    * \brief Schedule for OpenGL fragment shader.
    * \return reference to self.
@@ -289,13 +289,13 @@ class Schedule : public NodeRef {
    * \brief Get the stage corresponds to the op
    * \param op The operation.
    */
-  EXPORT Stage operator[](const Operation& op);
+  TVM_DLL Stage operator[](const Operation& op);
   /*!
    * \brief Short hand for getting the stage of tensor's operation.
    * \param tensor The tensor
    * \return The stage corresponding to the tensor's op
    */
-  EXPORT Stage operator[](const Tensor& tensor) {
+  TVM_DLL Stage operator[](const Tensor& tensor) {
     return this->operator[](tensor->op);
   }
   /*!
@@ -307,7 +307,7 @@ class Schedule : public NodeRef {
    * \param include_inputs Whether include inputs if they are reachable from outputs.
    * \return The new grouped stage.
    */
-  EXPORT Stage create_group(const Array<Tensor>& outputs,
+  TVM_DLL Stage create_group(const Array<Tensor>& outputs,
                      const Array<Tensor>& inputs,
                      bool include_inputs = false);
   /*!
@@ -319,7 +319,7 @@ class Schedule : public NodeRef {
    * \param readers The readers to redirect to the tensor.
    * \return The created tensor.
    */
-  EXPORT Tensor cache_read(const Tensor& tensor,
+  TVM_DLL Tensor cache_read(const Tensor& tensor,
                     const std::string& scope,
                     const Array<Operation>& readers);
   /*!
@@ -338,7 +338,7 @@ class Schedule : public NodeRef {
    * \param scope The scope of the storage.
    * \return The created tensor.
    */
-  EXPORT Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
   /*!
    * \brief Create a cache write tensor for producing tensor.
    *  The the tensor will take over body of original tensor op.
@@ -355,7 +355,7 @@ class Schedule : public NodeRef {
    * \param scope The scope of the storage.
    * \return The created tensor.
    */
-  EXPORT Tensor cache_write(const Tensor& tensor, const std::string& scope);
+  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
   /*!
    * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
    * This will create a new stage that generated the new tensor with axis
@@ -369,7 +369,7 @@ class Schedule : public NodeRef {
    * \param factor_axis The position where the new axis is placed.
    * \return The created factored tensors.
    */
-  EXPORT Array<Tensor> rfactor(const Tensor& tensor,
+  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
                         const IterVar& axis,
                         int factor_axis = 0);
   /*!
@@ -556,14 +556,14 @@ class ScheduleNode : public Node {
    * \param op The candidate Operation.
    * \return true if the schedule has the Operation. Otherwise, false.
    */
-  EXPORT bool Contain(const Operation& op) const;
+  TVM_DLL bool Contain(const Operation& op) const;
 
   /*!
    * \brief Check if the schedule contains a Tensor.
    * \param tensor The candidate tensor.
    * \return true if the schedule has the tensor. Otherwise, false.
    */
-  EXPORT bool Contain(const Tensor& tensor) const {
+  TVM_DLL bool Contain(const Tensor& tensor) const {
     return Contain(tensor->op);
   }
 
@@ -572,7 +572,7 @@ class ScheduleNode : public Node {
    * \param ops The ops to be scheduled.
    * \return sch The created Schedule.
    */
-  EXPORT static Schedule make(Array<Operation> ops);
+  TVM_DLL static Schedule make(Array<Operation> ops);
 
   static constexpr const char* _type_key = "Schedule";
   TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
index 27444ab693cd4f25663910d5877b238b39e75ded..3187e2b1772724484a38f031773a4a82b00d44a5 100644 (file)
@@ -70,7 +70,7 @@ void AutoInlineElemWise(Schedule sch);
  *
  * \param sch The schedule to be inlined.
  */
-EXPORT void AutoInlineInjective(Schedule sch);
+TVM_DLL void AutoInlineInjective(Schedule sch);
 
 }  // namespace schedule
 }  // namespace tvm
index c6be52181f6ca4cf531c7a3bb9dafaf3751848b1..2b33eea3c9c43019b944378fdcdd003ac762e280 100644 (file)
@@ -24,7 +24,6 @@
 #ifndef TVM_TENSOR_H_
 #define TVM_TENSOR_H_
 
-#include <ir/FunctionBase.h>
 #include <tvm/node/container.h>
 #include <string>
 #include <vector>
@@ -43,8 +42,6 @@ class TensorNode;
 // internal node container for Operation
 class OperationNode;
 
-using HalideIR::IR::FunctionRef;
-
 /*!
  * \brief Tensor structure representing a possible input,
  *  or intermediate computation result.
@@ -140,7 +137,7 @@ class Tensor : public NodeRef {
 };
 
 /*! \brief Operation that produces tensors */
-class Operation : public FunctionRef {
+class Operation : public ir::FunctionRef {
  public:
   /*! \brief default constructor  */
   Operation() {}
index 2525059b47ba600de118e010ec0f69c8ea09b6b6..e8e43786d3fd9ad1dfae12651e21a0bfe73f8642 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -59,14 +59,13 @@ TVM_REGISTER_API("make._range_by_min_extent")
 TVM_REGISTER_API("make.For")
 .set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
   VarExpr loop_var, Expr min, Expr extent,
-  int for_type, int device_api, Stmt body
-) {
+  int for_type, int device_api, Stmt body) {
   return For::make(loop_var,
-                    min,
-                    extent,
-                    static_cast<ForType>(for_type),
-                    static_cast<HalideIR::DeviceAPI>(device_api),
-                    body);
+                   min,
+                   extent,
+                   static_cast<ForType>(for_type),
+                   static_cast<DeviceAPI>(device_api),
+                   body);
 });
 
 TVM_REGISTER_API("make.Load")
index 00ac715e8c0754eaff5cdb61fd8e6c44d2e8bd09..aa0ce47b4a37b885b7d60585dc1c0d7099da3ac7 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2016 by Contributors
  *  Implementation of API functions related to Higher DSL build.
  * \file api_lang.cc
  */
 namespace tvm {
 
 TVM_REGISTER_API("_min_value")
-.set_body_method(&Type::min);
+.set_body_method(&DataType::min);
 
 TVM_REGISTER_API("_max_value")
-.set_body_method(&Type::max);
+.set_body_method(&DataType::max);
 
 TVM_REGISTER_API("_const")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
index 1a93c8b3f7d8236452d0c7d0c9f46a333903f7d5..3cc64278891115b52e5b2bcd719ccfe8ef6767d9 100644 (file)
@@ -52,12 +52,6 @@ class CanonicalExprNode : public BaseExprNode {
   // overrides
   void VisitAttrs(tvm::AttrVisitor* v) final {
   }
-  void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final {
-    LOG(FATAL) << "not supported";
-  }
-  IRNodeType type_info() const final {
-    return IRNodeType::ExtensionExpr;
-  }
 
   static constexpr const char* _type_key = "arith.CanonicalExpr";
   TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
index b36d8f5625a1a29633cd74c56fb5475d06ab59b2..84b452cd70430355396401c721adf6dda7048ac9 100644 (file)
@@ -125,7 +125,7 @@ class ConstIntBoundAnalyzer::Impl :
   // Override visitor behaviors
   Entry VisitExprDefault_(const Node* op) final {
     return Everything(
-        static_cast<const ir::BaseExprNode*>(op)->type);
+        static_cast<const ExprNode*>(op)->type);
   }
 
   Entry VisitExpr(const Expr& expr) final {
index 797a7d1c406ee64dec2538e7b54f61cf2b4f0913..5b14abbcf7dc983286caa8e12540f0e01af2c3b4 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -224,15 +224,15 @@ std::string CodeGenOpenGL::GetBufferRef(
 
 void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
   switch (t.code()) {
-    case halideir_type_int:
+    case kDLInt:
       CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
       os << "int";
       break;
-    case halideir_type_uint:
+    case kDLUInt:
       CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
       os << "uint";
       break;
-    case halideir_type_float:
+    case kDLFloat:
       CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
       os << "float";
       break;
index fed604922ee393af98f9766b0b7d1c2dafe81f70..11b72c71fda7425916698a14597471d63e560862 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  */
 
 /*!
- *  Copyright (c) 2016 by Contributors
  * \file expr.cc
  */
 #include <tvm/base.h>
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/expr_operator.h>
-#include <ir/IRPrinter.h>
 #include <memory>
+#include <limits>
 
 namespace tvm {
 
-using HalideIR::IR::RangeNode;
+// maximum and min values
+Expr DataType::max() const {
+  using namespace ir;
+  CHECK_EQ(lanes(), 1);
+  if (is_int()) {
+    if (bits() == 64) {
+      return IntImm::make(*this, std::numeric_limits<int64_t>::max());
+    } else if (bits() < 64) {
+      int64_t val = 1;
+      val = (val << (bits() - 1)) - 1;
+      return IntImm::make(*this, val);
+    }
+  } else if (is_uint()) {
+    if (bits() == 64) {
+      return UIntImm::make(*this, std::numeric_limits<uint64_t>::max());
+    } else if (bits() < 64) {
+      uint64_t val = 1;
+      val = (val << static_cast<uint64_t>(bits())) - 1;
+      return UIntImm::make(*this, val);
+    }
+  } else if (is_float()) {
+    if (bits() == 64) {
+      return FloatImm::make(*this, std::numeric_limits<double>::max());
+    } else if (bits() == 32) {
+      return FloatImm::make(*this, std::numeric_limits<float>::max());
+    } else if (bits() == 16) {
+      return FloatImm::make(*this, 65504.0);
+    }
+  }
+  LOG(FATAL) << "Cannot decide max_value for type" << *this;
+  return Expr();
+}
+
+Expr DataType::min() const {
+  using namespace ir;
+  CHECK_EQ(lanes(), 1);
+  if (is_int()) {
+    if (bits() == 64) {
+      return IntImm::make(*this, std::numeric_limits<int64_t>::lowest());
+    } else if (bits() < 64) {
+      int64_t val = 1;
+      val = -(val << (bits() - 1));
+      return IntImm::make(*this, val);
+    }
+  } else if (is_uint()) {
+    return UIntImm::make(*this, 0);
+  } else if (is_float()) {
+    if (bits() == 64) {
+      return FloatImm::make(*this, std::numeric_limits<double>::lowest());
+    } else if (bits() == 32) {
+      return FloatImm::make(*this, std::numeric_limits<float>::lowest());
+    } else if (bits() == 16) {
+      return FloatImm::make(*this, -65504.0);
+    }
+  }
+  LOG(FATAL) << "Cannot decide min_value for type" << *this;
+  return Expr();
+}
+
+Expr::Expr(int32_t value)
+    : Expr(IntImm::make(Int(32), value)) {}
+
+Expr::Expr(float value)
+    : Expr(ir::FloatImm::make(Float(32), value)) {}
+
+Expr::Expr(std::string str)
+    : Expr(ir::StringImm::make(str)) {}
+
+Var::Var(std::string name_hint, DataType t)
+    : Var(Variable::make(t, name_hint)) {}
+
+Var Variable::make(DataType t, std::string name_hint) {
+  NodePtr<Variable> node = make_node<Variable>();
+  node->type = t;
+  node->name_hint = std::move(name_hint);
+  return Var(node);
+}
 
 Range::Range(Expr begin, Expr end)
     : Range(make_node<RangeNode>(
@@ -38,12 +113,23 @@ Range::Range(Expr begin, Expr end)
           is_zero(begin) ? end : (end - begin))) {
 }
 
+Integer IntImm::make(Type t, int64_t value) {
+  CHECK(t.is_int() && t.is_scalar())
+      << "ValueError: IntImm can only take scalar.";
+  NodePtr<IntImm> node = make_node<IntImm>();
+  node->type = t;
+  node->value = value;
+  return Integer(node);
+}
+
 Range Range::make_by_min_extent(Expr min, Expr extent) {
-  return Range(make_node<HalideIR::IR::RangeNode>(min, extent));
+  return Range(make_node<RangeNode>(min, extent));
 }
 
-IterVar IterVarNode::make(Range dom, Var var,
-                          IterVarType t, std::string thread_tag) {
+IterVar IterVarNode::make(Range dom,
+                          Var var,
+                          IterVarType t,
+                          std::string thread_tag) {
   NodePtr<IterVarNode> n = make_node<IterVarNode>();
   n->dom = dom;
   n->var = var;
@@ -62,19 +148,48 @@ IterVar reduce_axis(Range dom, std::string name) {
       dom, Var(name), kCommReduce);
 }
 
-std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
-  IRPrinter(os).print(n);
-  return os;
-}
-
 void Dump(const NodeRef& n) {
   std::cerr << n << "\n";
 }
 
-Var var(const std::string& name_hint, Type t) {
+Var var(std::string name_hint, Type t) {
   return Var(name_hint, t);
 }
 
+void IRPrinter::Print(const NodeRef& ir) {
+  static const FType& f = vtable();
+  if (!ir.defined()) {
+    stream << "(nullptr)";
+  } else {
+    if (f.can_dispatch(ir)) {
+      f(ir, this);
+    } else {
+      // default value, output type key and addr.
+      stream << ir->type_key() << "(" << ir.get() << ")";
+    }
+  }
+}
+
+void IRPrinter::PrintIndent() {
+  for (int i = 0; i < indent; ++i) {
+    stream << ' ';
+  }
+}
+
+IRPrinter::FType& IRPrinter::vtable() {
+  static FType inst;
+  return inst;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<IntImm>([](const IntImm *op, IRPrinter *p) {
+    if (op->type == Int(32)) {
+      p->stream << op->value;
+    } else {
+      p->stream << "(" << op->type << ")" << op->value;
+    }
+  });
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
     p->stream << "iter_var(";
@@ -91,11 +206,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<RangeNode>([](const HalideIR::IR::RangeNode *op, IRPrinter *p) {
+.set_dispatch<RangeNode>([](const RangeNode* op, IRPrinter* p) {
     p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
   });
 
-
 TVM_REGISTER_NODE_TYPE(ArrayNode);
 TVM_REGISTER_NODE_TYPE(MapNode);
 TVM_REGISTER_NODE_TYPE(StrMapNode);
index 4eeddd91d80c46c2c16d00e7bd915127a533c74a..0557e287986f0fc3735b25083a55ad0cdbff9cae 100644 (file)
  */
 
 /*!
- *  Copyright (c) 2016 by Contributors
  * \file ir.cc
  */
 #include <tvm/base.h>
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
-#include <ir/IR.h>
-#include <ir/IRPrinter.h>
 #include <memory>
 #include "../pass/ir_util.h"
 
-namespace HalideIR {
-namespace Internal {
+namespace tvm {
+namespace ir {
 
-using tvm::ir::CommReducerNode;
-using tvm::ir::Reduce;
-using tvm::ir::Any;
-using tvm::ir::AttrStmt;
+// constructors
+Expr UIntImm::make(DataType t, uint64_t value) {
+  CHECK(t.is_uint() && t.lanes() == 1)
+      << "ValueError: UIntImm can only take scalar";
+  NodePtr<UIntImm> node = make_node<UIntImm>();
+  node->type = t;
+  node->value = value;
+  return Expr(node);
+}
 
-template<>
-void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
-  LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
+Expr FloatImm::make(DataType t, double value) {
+  CHECK_EQ(t.lanes(), 1)
+      << "ValueError: FloatImm can only take scalar";
+  NodePtr<FloatImm> node = make_node<FloatImm>();
+  node->type = t;
+  node->value = value;
+  return Expr(node);
 }
 
-template<>
-void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
-  LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
+Expr StringImm::make(std::string value) {
+  NodePtr<StringImm> node = make_node<StringImm>();
+  node->type = Handle();
+  node->value = std::move(value);
+  return Expr(node);
 }
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
-  p->stream << "?";
-});
+Expr Cast::make(DataType t, Expr value) {
+  CHECK(value.defined());
+  CHECK_EQ(t.lanes(), value.type().lanes());
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
-  p->stream << "reduce(combiner="
-            << op->combiner;
-  p->stream << ", source=" << op->source;
-  p->stream << ", axis=" << op->axis;
-  p->stream << ", where=" << op->condition;
-  p->stream << ", value_index=" << op->value_index;
-  p->stream << ")";
-});
+  NodePtr<Cast> node = make_node<Cast>();
+  node->type = t;
+  node->value = std::move(value);
+  return Expr(node);
+}
 
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
-  p->stream << "comm_reducer(result=" << op->result
-            << ", lhs=" << op->lhs
-            << ", rhs=" << op->rhs
-            << ", identity_element=" << op->identity_element
-            << ")";
-});
-}  // namespace Internal
-}  // namespace HalideIR
 
-namespace tvm {
-namespace ir {
+Expr And::make(Expr a, Expr b) {
+  CHECK(a.defined()) << "ValueError: a is undefined";
+  CHECK(b.defined()) << "ValueError: b is undefined";
+  CHECK(a.type().is_bool());
+  CHECK(b.type().is_bool());
+  CHECK(a.type() == b.type()) << "TypeError: mismatched types";
+
+  NodePtr<And> node = make_node<And>();
+  node->type = Bool(a.type().lanes());
+  node->a = std::move(a);
+  node->b = std::move(b);
+  return Expr(node);
+}
+
+Expr Or::make(Expr a, Expr b) {
+  CHECK(a.defined()) << "ValueError: a is undefined";
+  CHECK(b.defined()) << "ValueError: b is undefined";
+  CHECK(a.type().is_bool());
+  CHECK(b.type().is_bool());
+  CHECK(a.type() == b.type()) << "TypeError: mismatched types";
+
+  NodePtr<Or> node = make_node<Or>();
+  node->type = Bool(a.type().lanes());
+  node->a = std::move(a);
+  node->b = std::move(b);
+  return Expr(node);
+}
+
+Expr Not::make(Expr a) {
+  CHECK(a.defined()) << "ValueError: a is undefined";
+  CHECK(a.type().is_bool());
+
+  NodePtr<Not> node = make_node<Not>();
+  node->type = Bool(a.type().lanes());
+  node->a = std::move(a);
+  return Expr(node);
+}
+
+Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
+  CHECK(condition.defined()) << "ValueError: condition is undefined";
+  CHECK(true_value.defined()) << "ValueError: true_value is undefined";
+  CHECK(false_value.defined()) << "ValueError: true_value is undefined";
+  CHECK(condition.type().is_bool());
+  CHECK_EQ(condition.type().lanes(), true_value.type().lanes());
+  CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types";
+
+  NodePtr<Select> node = make_node<Select>();
+  node->type = true_value.type();
+  node->condition = std::move(condition);
+  node->true_value = std::move(true_value);
+  node->false_value = std::move(false_value);
+  return Expr(node);
+}
+
+Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) {
+  CHECK(buffer_var.defined());
+  CHECK(predicate.defined());
+  CHECK(index.defined());
+  CHECK_EQ(type.lanes(), index.type().lanes());
+  CHECK_EQ(type.lanes(), predicate.type().lanes());
+
+  NodePtr<Load> node = make_node<Load>();
+  node->type = type;
+  node->buffer_var = std::move(buffer_var);
+  node->index = std::move(index);
+  node->predicate = std::move(predicate);
+
+  return Expr(node);
+}
+
+Expr Ramp::make(Expr base, Expr stride, int lanes) {
+  CHECK(base.defined());
+  CHECK(stride.defined());
+  CHECK(base.type().is_scalar());
+  CHECK(stride.type().is_scalar());
+  CHECK_GT(lanes, 1);
+  CHECK_EQ(stride.type(), base.type());
+
+  NodePtr<Ramp> node = make_node<Ramp>();
+  node->type = base.type().with_lanes(lanes);
+  node->base = base;
+  node->stride = stride;
+  node->lanes = lanes;
+  return Expr(node);
+}
+
+Expr Broadcast::make(Expr value, int lanes) {
+  CHECK(value.defined());
+  CHECK(value.type().is_scalar());
+  CHECK_GT(lanes, 1);
+
+  NodePtr<Broadcast> node = make_node<Broadcast>();
+  node->type = value.type().with_lanes(lanes);
+  node->value = std::move(value);
+  node->lanes = lanes;
+  return Expr(node);
+}
+
+Expr Let::make(Var var, Expr value, Expr body) {
+  CHECK(value.defined());
+  CHECK(body.defined());
+  CHECK_EQ(value.type(), var.type());
+
+  NodePtr<Let> node = make_node<Let>();
+  node->type = body.type();
+  node->var = std::move(var);
+  node->value = std::move(value);
+  node->body = std::move(body);
+  return Expr(node);
+}
+
+Expr Call::make(DataType type,
+                std::string name,
+                Array<Expr> args,
+                CallType call_type,
+                FunctionRef func,
+                int value_index) {
+  for (size_t i = 0; i < args.size(); ++i) {
+    CHECK(args[i].defined());
+  }
+
+  if (call_type == Halide) {
+    for (size_t i = 0; i < args.size(); ++i) {
+      CHECK(args[i].type().is_int());
+    }
+  }
+
+  NodePtr<Call> node = make_node<Call>();
+  node->type = type;
+  node->name = std::move(name);
+  node->args = std::move(args);
+  node->call_type = call_type;
+  node->func = std::move(func);
+  node->value_index = value_index;
+  return Expr(node);
+}
+
+Expr Shuffle::make(Array<Expr> vectors,
+                   Array<Expr> indices) {
+  CHECK_NE(vectors.size(), 0U);
+  CHECK_NE(indices.size(), 0U);
+
+  Type base_type = vectors[0].type().element_of();
+  int total_lanes = 0;
+
+  for (Expr val : vectors) {
+    CHECK(val.type().element_of() == base_type);
+    total_lanes += val.type().lanes();
+  }
+  CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
+
+  NodePtr<Shuffle> node = make_node<Shuffle>();
+  node->type = base_type.with_lanes(static_cast<int>(indices.size()));
+  node->vectors = std::move(vectors);
+  node->indices = std::move(indices);
+  return Expr(node);
+}
+
+Expr Shuffle::make_concat(Array<Expr> vectors) {
+  CHECK_NE(vectors.size(), 0);
+  if (vectors.size() == 1) {
+    return vectors[0];
+  }
+  Array<Expr> indices;
+  int index = 0;
+  for (const Expr& e : vectors) {
+    for (int i = 0; i < e.type().lanes(); ++i) {
+      indices.push_back(IntImm::make(Int(32), index++));
+    }
+  }
+  return make(vectors, indices);
+}
+
+Expr Shuffle::make_extract_element(Expr vector, int index) {
+  return make({vector}, {Integer(index)});
+}
 
 CommReducer CommReducerNode::make(Array<Var> lhs,
                                   Array<Var> rhs,
@@ -132,6 +298,802 @@ Expr Any::make() {
   return Expr(n);
 }
 
+Stmt LetStmt::make(Var var, Expr value, Stmt body) {
+  CHECK(value.defined());
+  CHECK(body.defined());
+  CHECK_EQ(value.type(), var.type());
+
+  NodePtr<LetStmt> node = make_node<LetStmt>();
+  node->var = std::move(var);
+  node->value = std::move(value);
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt AttrStmt::make(NodeRef node,
+                    std::string attr_key,
+                    Expr value,
+                    Stmt body) {
+  auto n = make_node<AttrStmt>();
+  n->node = node;
+  n->attr_key = std::move(attr_key);
+  n->value = std::move(value);
+  n->body = std::move(body);
+  return Stmt(n);
+}
+
+Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
+  CHECK(condition.defined());
+  CHECK(message.type() == Int(32) ||
+        message.as<StringImm>())
+      << "TypeError: AssertStmt message must be an int or string:"
+      << message << "\n";
+
+  NodePtr<AssertStmt> node = make_node<AssertStmt>();
+  node->condition = std::move(condition);
+  node->message = std::move(message);
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
+  CHECK(body.defined());
+
+  NodePtr<ProducerConsumer> node = make_node<ProducerConsumer>();
+  node->func = std::move(func);
+  node->is_producer = is_producer;
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt For::make(Var loop_var,
+               Expr min,
+               Expr extent,
+               ForType for_type,
+               DeviceAPI device_api,
+               Stmt body) {
+  CHECK(min.defined());
+  CHECK(extent.defined());
+  CHECK(min.type().is_scalar());
+  CHECK(extent.type().is_scalar());
+  CHECK(loop_var.type().is_scalar());
+  CHECK(body.defined());
+
+  NodePtr<For> node = make_node<For>();
+  node->loop_var = std::move(loop_var);
+  node->min = std::move(min);
+  node->extent = std::move(extent);
+  node->for_type = for_type;
+  node->device_api = device_api;
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
+  CHECK(value.defined());
+  CHECK(index.defined());
+  CHECK(predicate.defined());
+  CHECK_EQ(value.type().lanes(), index.type().lanes());
+  CHECK_EQ(value.type().lanes(), predicate.type().lanes());
+
+  NodePtr<Store> node = make_node<Store>();
+  node->buffer_var = std::move(buffer_var);
+  node->value = std::move(value);
+  node->index = std::move(index);
+  node->predicate = std::move(predicate);
+  return Stmt(node);
+}
+
+Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
+  CHECK(value_index >=0 && value_index < func->num_outputs())
+      << "value index output function return value bound";
+  CHECK(value.defined()) << "Provide of undefined value\n";
+
+  for (size_t i = 0; i < args.size(); ++i) {
+    CHECK(args[i].defined()) << "Provide to undefined location\n";
+  }
+
+  NodePtr<Provide> node = make_node<Provide>();
+  node->func = std::move(func);
+  node->value_index = value_index;
+  node->value = std::move(value);
+  node->args = std::move(args);
+  return Stmt(node);
+}
+
+Stmt Allocate::make(Var buffer_var,
+                    DataType type,
+                    Array<Expr> extents,
+                    Expr condition,
+                    Stmt body,
+                    Expr new_expr,
+                    std::string free_function) {
+    for (size_t i = 0; i < extents.size(); ++i) {
+      CHECK(extents[i].defined());
+      CHECK(extents[i].type().is_scalar());
+    }
+    CHECK(body.defined());
+    CHECK(condition.defined());
+    CHECK(condition.type().is_bool());
+
+    NodePtr<Allocate> node = make_node<Allocate>();
+    node->buffer_var = std::move(buffer_var);
+    node->type = type;
+    node->extents = std::move(extents);
+    node->condition = std::move(condition);
+    node->body = std::move(body);
+    node->new_expr = std::move(new_expr);
+    node->free_function = std::move(free_function);
+    return Stmt(node);
+}
+
+int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
+  int64_t result = 1;
+  for (size_t i = 0; i < extents.size(); ++i) {
+    if (const IntImm *int_size = extents[i].as<IntImm>()) {
+      result *= int_size->value;
+      if (result > std::numeric_limits<int32_t>::max()) {
+        return 0;
+      }
+    } else {
+      return 0;
+    }
+  }
+  return static_cast<int32_t>(result);
+}
+
+Stmt Free::make(Var buffer_var) {
+  NodePtr<Free> node = make_node<Free>();
+  node->buffer_var = buffer_var;
+  return Stmt(node);
+}
+
+Stmt Realize::make(FunctionRef func,
+                   int value_index,
+                   DataType type,
+                   Region bounds,
+                   Expr condition,
+                   Stmt body) {
+  for (size_t i = 0; i < bounds.size(); ++i) {
+    CHECK(bounds[i]->min.defined());
+    CHECK(bounds[i]->extent.defined());
+    CHECK(bounds[i]->min.type().is_scalar());
+    CHECK(bounds[i]->extent.type().is_scalar());
+  }
+  CHECK(body.defined());
+  CHECK(condition.defined());
+  CHECK(condition.type().is_bool());
+
+  NodePtr<Realize> node = make_node<Realize>();
+  node->func = std::move(func);
+  node->value_index = value_index;
+  node->type = type;
+  node->bounds = std::move(bounds);
+  node->condition = std::move(condition);
+  node->body = std::move(body);
+  return Stmt(node);
+}
+
+Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) {
+  for (size_t i = 0; i < bounds.size(); ++i) {
+    CHECK(bounds[i]->min.defined());
+    CHECK(bounds[i]->extent.defined());
+    CHECK(bounds[i]->min.type().is_scalar());
+    CHECK(bounds[i]->extent.type().is_scalar());
+  }
+
+  NodePtr<Prefetch> node = make_node<Prefetch>();
+  node->func = std::move(func);
+  node->value_index = value_index;
+  node->type = type;
+  node->bounds = std::move(bounds);
+  return Stmt(node);
+}
+
+Stmt Block::make(Stmt first, Stmt rest) {
+  CHECK(first.defined());
+  CHECK(rest.defined());
+  NodePtr<Block> node = make_node<Block>();
+
+  // canonicalize.
+  if (const Block* b = first.as<Block>()) {
+    node->first = b->first;
+    node->rest  = Block::make(b->rest, rest);
+  } else {
+    node->first = std::move(first);
+    node->rest = std::move(rest);
+  }
+  return Stmt(node);
+}
+
+Stmt Block::make(const std::vector<Stmt>& stmts) {
+  if (stmts.empty()) {
+    return Stmt();
+  }
+  Stmt result = stmts.back();
+  for (size_t i = stmts.size() - 1; i != 0; --i) {
+    result = Block::make(stmts[i - 1], result);
+  }
+  return result;
+}
+
+Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
+  CHECK(condition.defined());
+  CHECK(then_case.defined());
+  // else_case may be null.
+
+  NodePtr<IfThenElse> node = make_node<IfThenElse>();
+  node->condition = std::move(condition);
+  node->then_case = std::move(then_case);
+  node->else_case = std::move(else_case);
+  return Stmt(node);
+}
+
+Stmt Evaluate::make(Expr value) {
+  CHECK(value.defined());
+
+  NodePtr<Evaluate> node = make_node<Evaluate>();
+  node->value = std::move(value);
+  return Stmt(node);
+}
+
+// Printers
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<UIntImm>([](const UIntImm* op, IRPrinter* p) {
+    p->stream << "(" << op->type << ")" << op->value;
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<FloatImm>([](const FloatImm* op, IRPrinter* p) {
+    auto& stream = p->stream;
+    switch (op->type.bits()) {
+      case 64:
+        stream << op->value;
+        break;
+      case 32:
+        stream << op->value << 'f';
+        break;
+      case 16:
+        stream << op->value << 'h';
+        break;
+      default:
+        LOG(FATAL) << "Unknown float type bits=" << op->type.bits();
+    }
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<StringImm>([](const StringImm* op, IRPrinter* p) {
+    auto& stream = p->stream;
+    stream << '"';
+    for (size_t i = 0; i < op->value.size(); ++i) {
+      unsigned char c = op->value[i];
+      if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
+        stream << c;
+      } else {
+        stream << '\\';
+        switch (c) {
+          case '"':
+            stream << '"';
+            break;
+          case '\\':
+            stream << '\\';
+            break;
+          case '\t':
+            stream << 't';
+            break;
+          case '\r':
+            stream << 'r';
+            break;
+          case '\n':
+            stream << 'n';
+            break;
+          default:
+            const char* hex_digits = "0123456789ABCDEF";
+            stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
+        }
+      }
+    }
+    stream << '"';
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Cast>([](const Cast* op, IRPrinter* p) {
+    p->stream << op->type << '(';
+    p->Print(op->value);
+    p->stream << ')';
+  })
+.set_dispatch<Variable>([](const Variable* op, IRPrinter* p) {
+    // omit the type
+    // stream << op->name << "." << op->type;
+    p->stream << op->name_hint;
+  })
+.set_dispatch<Add>([](const Add* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " + ";
+    p->Print(op->b);
+    p->stream << ')';
+  })
+.set_dispatch<Sub>([](const Sub* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " - ";
+    p->Print(op->b);
+    p->stream << ')';
+  })
+.set_dispatch<Mul>([](const Mul* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << "*";
+    p->Print(op->b);
+    p->stream << ')';
+  })
+.set_dispatch<Div>([](const Div* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << "/";
+    p->Print(op->b);
+    p->stream << ')';
+  })
+.set_dispatch<Mod>([](const Mod* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " % ";
+    p->Print(op->b);
+    p->stream << ')';
+})
+.set_dispatch<Min>([](const Min* op, IRPrinter* p) {
+    p->stream << "min(";
+    p->Print(op->a);
+    p->stream << ", ";
+    p->Print(op->b);
+    p->stream << ")";
+})
+.set_dispatch<Max>([](const Max* op, IRPrinter* p) {
+    p->stream << "max(";
+    p->Print(op->a);
+    p->stream << ", ";
+    p->Print(op->b);
+    p->stream << ")";
+})
+.set_dispatch<EQ>([](const EQ* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " == ";
+    p->Print(op->b);
+    p->stream << ')';
+})
+.set_dispatch<NE>([](const NE* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " != ";
+    p->Print(op->b);
+    p->stream << ')';
+})
+.set_dispatch<LT>([](const LT* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " < ";
+    p->Print(op->b);
+    p->stream << ')';
+})
+.set_dispatch<LE>([](const LE* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " <= ";
+    p->Print(op->b);
+    p->stream << ')';
+})
+.set_dispatch<GT>([](const GT* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " > ";
+    p->Print(op->b);
+    p->stream << ')';
+})
+.set_dispatch<GE>([](const GE* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " >= ";
+    p->Print(op->b);
+    p->stream << ')';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<And>([](const And* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " && ";
+    p->Print(op->b);
+    p->stream << ')';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Or>([](const Or* op, IRPrinter* p) {
+    p->stream << '(';
+    p->Print(op->a);
+    p->stream << " || ";
+    p->Print(op->b);
+    p->stream << ')';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Not>([](const Not* op, IRPrinter* p) {
+    p->stream << '!';
+    p->Print(op->a);
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Select>([](const Select* op, IRPrinter* p) {
+    p->stream << "select(";
+    p->Print(op->condition);
+    p->stream << ", ";
+    p->Print(op->true_value);
+    p->stream << ", ";
+    p->Print(op->false_value);
+    p->stream << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Load>([](const Load* op, IRPrinter* p) {
+    p->stream << op->buffer_var << "[";
+    p->Print(op->index);
+    p->stream << "]";
+    if (!is_one(op->predicate)) {
+        p->stream << " if ";
+        p->Print(op->predicate);
+    }
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Ramp>([](const Ramp* op, IRPrinter* p) {
+    p->stream << "ramp(";
+    p->Print(op->base);
+    p->stream << ", ";
+    p->Print(op->stride);
+    p->stream << ", " << op->lanes << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Broadcast>([](const Broadcast* op, IRPrinter* p) {
+    p->stream << "x" << op->lanes << "(";
+    p->Print(op->value);
+    p->stream << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Call>([](const Call* op, IRPrinter* p) {
+    p->stream << op->name << "(";
+    for (size_t i = 0; i < op->args.size(); ++i) {
+      p->Print(op->args[i]);
+      if (i < op->args.size() - 1) {
+        p->stream << ", ";
+      }
+    }
+    p->stream << ")";
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Let>([](const Let* op, IRPrinter* p) {
+    p->stream << "(let " << op->var << " = ";
+    p->Print(op->value);
+    p->stream << " in ";
+    p->Print(op->body);
+    p->stream << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<LetStmt>([](const LetStmt* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "let " << op->var << " = ";
+    p->Print(op->value);
+    p->stream << '\n';
+    p->Print(op->body);
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<AttrStmt>([](const AttrStmt* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "// attr [";
+    p->Print(op->node);
+    p->stream << "] "
+              << op->attr_key << " = ";
+    p->Print(op->value);
+    p->stream << '\n';
+    p->Print(op->body);
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<AssertStmt>([](const AssertStmt* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "assert(";
+    p->Print(op->condition);
+    p->stream << ", ";
+    p->Print(op->message);
+    p->stream << ")\n";
+    p->Print(op->body);
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ProducerConsumer>([](const ProducerConsumer* op, IRPrinter* p) {
+    if (op->is_producer) {
+      p->PrintIndent();
+      p->stream << "produce " << op->func->func_name() << " {\n";
+      p->indent += 2;
+      p->Print(op->body);
+      p->indent -= 2;
+      p->PrintIndent();
+      p->stream << "}\n";
+    } else {
+      p->Print(op->body);
+    }
+  });
+
+std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
+  switch (type) {
+    case ForType::Serial:
+      out << "for";
+      break;
+    case ForType::Parallel:
+      out << "parallel";
+      break;
+    case ForType::Unrolled:
+      out << "unrolled";
+      break;
+    case ForType::Vectorized:
+      out << "vectorized";
+      break;
+  }
+  return out;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<For>([](const For* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << op->for_type << " (" << op->loop_var << ", ";
+    p->Print(op->min);
+    p->stream << ", ";
+    p->Print(op->extent);
+    p->stream << ") {\n";
+
+    p->indent += 2;
+    p->Print(op->body);
+    p->indent -= 2;
+
+    p->PrintIndent();
+    p->stream << "}\n";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Store>([](const Store* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << op->buffer_var << "[";
+    p->Print(op->index);
+    p->stream << "] = ";
+    p->Print(op->value);
+    if (!is_one(op->predicate)) {
+      p->stream << " if ";
+      p->Print(op->predicate);
+    }
+    p->stream << '\n';
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Provide>([](const Provide* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << op->func->func_name() << "(";
+    for (size_t i = 0; i < op->args.size(); ++i) {
+      p->Print(op->args[i]);
+      if (i < op->args.size() - 1) p->stream << ", ";
+    }
+    p->stream << ")";
+    if (op->func->num_outputs() != 1) {
+      p->stream << ".value[" << op->value_index << "]";
+    }
+    p->stream << " =";
+    p->Print(op->value);
+    p->stream << '\n';
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Allocate>([](const Allocate* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "allocate " << op->buffer_var << "[" << op->type;
+    for (size_t i = 0; i < op->extents.size(); ++i) {
+      p->stream << " * ";
+      p->Print(op->extents[i]);
+    }
+    p->stream << "]";
+    if (!is_one(op->condition)) {
+      p->stream << " if ";
+      p->Print(op->condition);
+    }
+    p->stream << "\n";
+    p->Print(op->body);
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Free>([](const Free* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "free " << op->buffer_var;
+    p->stream << '\n';
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Realize>([](const Realize* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "realize " << op->func->func_name() << "(";
+    for (size_t i = 0; i < op->bounds.size(); ++i) {
+      p->stream << "[";
+      p->Print(op->bounds[i]->min);
+      p->stream << ", ";
+      p->Print(op->bounds[i]->extent);
+      p->stream << "]";
+      if (i < op->bounds.size() - 1) p->stream << ", ";
+    }
+    p->stream << ")";
+    if (op->func->num_outputs() != 1) {
+      p->stream << ".value[" << op->value_index << "]";
+    }
+    if (!is_one(op->condition)) {
+      p->stream << " if ";
+      p->Print(op->condition);
+    }
+    p->stream << " {\n";
+
+    p->indent += 2;
+    p->Print(op->body);
+    p->indent -= 2;
+
+    p->PrintIndent();
+    p->stream << "}\n";
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Prefetch>([](const Prefetch* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->stream << "prefetch " << op->func->func_name() << "(";
+    for (size_t i = 0; i < op->bounds.size(); ++i) {
+      p->stream << "[";
+      p->Print(op->bounds[i]->min);
+      p->stream << ", ";
+      p->Print(op->bounds[i]->extent);
+      p->stream << "]";
+      if (i < op->bounds.size() - 1) p->stream << ", ";
+    }
+    p->stream << ")";
+    if (op->func->num_outputs() != 1) {
+      p->stream << ".value[" << op->value_index << "]";
+    }
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Block>([](const Block* op, IRPrinter* p) {
+    p->Print(op->first);
+    if (op->rest.defined()) p->Print(op->rest);
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<IfThenElse>([](const IfThenElse* op, IRPrinter* p) {
+    p->PrintIndent();
+    while (true) {
+      p->stream << "if (" << op->condition << ") {\n";
+      p->indent += 2;
+      p->Print(op->then_case);
+      p->indent -= 2;
+
+      if (!op->else_case.defined()) {
+        break;
+      }
+
+      if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
+        p->PrintIndent();
+        p->stream << "} else ";
+        op = nested_if;
+      } else {
+        p->PrintIndent();
+        p->stream << "} else {\n";
+        p->indent += 2;
+        p->Print(op->else_case);
+        p->indent -= 2;
+        break;
+      }
+    }
+    p->PrintIndent();
+    p->stream << "}\n";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Evaluate>([](const Evaluate* op, IRPrinter* p) {
+    p->PrintIndent();
+    p->Print(op->value);
+    p->stream << "\n";
+  });
+
+template<typename T>
+void PrintList(const Array<T> &exprs, IRPrinter* p) {
+  for (size_t i = 0; i < exprs.size(); ++i) {
+    p->Print(exprs[i]);
+    if (i < exprs.size() - 1) {
+      p->stream << ", ";
+    }
+  }
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Shuffle>([](const Shuffle* op, IRPrinter* p) {
+    p->stream << "shuffle(";
+    PrintList(op->vectors, p);
+    p->stream << ", ";
+    PrintList(op->indices, p);
+    p->stream << ")";
+  });
+
+// Container printer
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ArrayNode>([](const ArrayNode* op, IRPrinter* p) {
+    p->stream << '[';
+    for (size_t i = 0 ; i < op->data.size(); ++i) {
+      if (i != 0) {
+        p->stream << ", ";
+      }
+      p->Print(NodeRef(op->data[i]));
+    }
+    p->stream << ']';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<MapNode>([](const MapNode* op, IRPrinter* p) {
+    p->stream << '{';
+    for (auto it = op->data.begin(); it != op->data.end(); ++it) {
+      if (it != op->data.begin()) {
+        p->stream << ", ";
+      }
+      p->Print(NodeRef(it->first));
+      p->stream << ": ";
+      p->Print(NodeRef(it->second));
+    }
+    p->stream << '}';
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<StrMapNode>([](const StrMapNode* op, IRPrinter* p) {
+    p->stream << '{';
+    for (auto it = op->data.begin(); it != op->data.end(); ++it) {
+      if (it != op->data.begin()) {
+        p->stream << ", ";
+      }
+      p->stream << '\"' << it->first << "\": ";
+      p->Print(NodeRef(it->second));
+    }
+    p->stream << '}';
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Reduce>([](const Reduce* op, IRPrinter* p) {
+    p->stream << "reduce(combiner="
+              << op->combiner;
+    p->stream << ", source=" << op->source;
+    p->stream << ", axis=" << op->axis;
+    p->stream << ", where=" << op->condition;
+    p->stream << ", value_index=" << op->value_index;
+    p->stream << ")";
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<CommReducerNode>([](const CommReducerNode* op, IRPrinter* p) {
+    p->stream << "comm_reducer(result=" << op->result
+              << ", lhs=" << op->lhs
+              << ", rhs=" << op->rhs
+              << ", identity_element=" << op->identity_element
+              << ")";
+  });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
+  p->stream << "?";
+});
+
 TVM_REGISTER_NODE_TYPE(CommReducerNode);
 TVM_REGISTER_NODE_TYPE(Reduce);
 TVM_REGISTER_NODE_TYPE(Any);
index c2f80d10f790f93e893e553ff03e2c84c792c2ac..1ac564293c283d18c3c89620d8806d7ed3e9327d 100644 (file)
 #include <tvm/tensor.h>
 #include <tvm/operation.h>
 #include <tvm/tensor_intrin.h>
-#include <ir/IR.h>
 #include <memory>
 
 namespace tvm {
 
 // Tensor
-
 Expr Tensor::operator()(Array<Var> indices) const {
   Array<Expr> arr(indices.begin(), indices.end());
   return operator()(arr);
 }
 
 Expr Tensor::operator()(Array<Expr> indices) const {
-  using HalideIR::Internal::Call;
+  using ir::Call;
   if (ndim() != 0) {
     CHECK_EQ(ndim(), indices.size())
         << "Tensor dimension mismatch in read"
         << "ndim = " << ndim() << ", indices.size=" << indices.size();
   }
-
   auto n = Call::make(
       (*this)->dtype, (*this)->op->name, indices, Call::Halide,
       (*this)->op, (*this)->value_index);
diff --git a/src/node/node.cc b/src/node/node.cc
new file mode 100644 (file)
index 0000000..393f226
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ *  Implementation of Node API
+ * \file node.cc
+ */
+#include <tvm/node/node.h>
+#include <memory>
+#include <atomic>
+#include <mutex>
+#include <unordered_map>
+
+// TODO(tqchen):
+// Think of re-organize and consolidate with object.
+namespace tvm {
+
+namespace {
+// single manager of operator information.
+struct TypeManager {
+  // mutex to avoid registration from multiple threads.
+  // recursive is needed for trigger(which calls UpdateAttrMap)
+  std::mutex mutex;
+  std::atomic<uint32_t> type_counter{0};
+  std::unordered_map<std::string, uint32_t> key2index;
+  std::vector<std::string> index2key;
+  // get singleton of the
+  static TypeManager* Global() {
+    static TypeManager inst;
+    return &inst;
+  }
+};
+}  // namespace
+
+TVM_DLL const bool Node::_DerivedFrom(uint32_t tid) const {
+  static uint32_t tindex = TypeKey2Index(Node::_type_key);
+  return tid == tindex;
+}
+
+// this is slow, usually caller always hold the result in a static variable.
+TVM_DLL uint32_t Node::TypeKey2Index(const char* key) {
+  TypeManager *t = TypeManager::Global();
+  std::lock_guard<std::mutex>(t->mutex);
+  std::string skey = key;
+  auto it = t->key2index.find(skey);
+  if (it != t->key2index.end()) {
+    return it->second;
+  }
+  uint32_t tid = ++(t->type_counter);
+  t->key2index[skey] = tid;
+  t->index2key.push_back(skey);
+  return tid;
+}
+
+TVM_DLL const char* Node::TypeIndex2Key(uint32_t index) {
+  TypeManager *t = TypeManager::Global();
+  std::lock_guard<std::mutex>(t->mutex);
+  CHECK_NE(index, 0);
+  return t->index2key.at(index - 1).c_str();
+}
+}  // namespace tvm
index bb91ed8d4a9f60b6c64c51283214647cfb12f4de..6467ccf00fceadacc49192d0ac6f18a1f5488949 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \brief Compute Op.
  * \file compute_op.cc
  */
@@ -250,7 +249,7 @@ Stmt BaseComputeOpNode::BuildRealize(
     const std::unordered_map<IterVar, Range>& realize_map,
     const Stmt& body) const {
   CHECK_EQ(stage->op.get(), this);
-  HalideIR::Internal::Region bounds;
+  Region bounds;
   for (IterVar iv : this->axis) {
     bounds.push_back(realize_map.at(iv));
   }
index 7023aebe17ada76e0ee4176b1cb37353a428def5..b5ef80ee2312887c9650dfd6c0cf5e04ddb885d4 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \brief External computation rule.
  * \file extern_op.cc
  */
@@ -140,7 +139,7 @@ Stmt ExternOpNode::BuildRealize(
   Stmt realize_body = body;
   for (int k = 0; k < num_outputs(); ++k) {
     Tensor t = stage->op.output(k);
-    HalideIR::Internal::Region bounds;
+    Region bounds;
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(
           Range::make_by_min_extent(
index 48773c6447493ced6d1ec586d2598ab620904233..e8f3d5bbed6a765117fc4a6d32c99661d09b63ee 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \brief Hybrid computation rule.
  * \file hybrid_op.cc
  */
@@ -28,7 +27,6 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/expr_operator.h>
-#include <ir/Expr.h>
 #include <unordered_set>
 #include <string>
 #include <utility>
@@ -143,7 +141,7 @@ Stmt HybridOpNode::BuildRealize(
   Stmt realize_body = body;
   for (int k = 0; k < num_outputs(); ++k) {
     Tensor t = stage->op.output(k);
-    HalideIR::Internal::Region bounds;
+    Region bounds;
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(
           Range::make_by_min_extent(
@@ -442,7 +440,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
       }
       const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
       return For::make(target->var, range->min, range->extent,
-                       for_type, HalideIR::DeviceAPI::None, body);
+                       for_type, DeviceAPI::None, body);
     }
   };
 
index e354102f0954e063a5f47876ac9d7da7bfdc1d2b..668408a36e757547be63d24899dd18568070cd8a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \brief Utility to make loop nest.
  * \file op_util.cc
  */
index 78f8c82d97dbf7b98bbbed8b0af7cd3cb2f6ef95..1c17d1c31f40789f6794026a89202b4c73e2f217 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \brief Scan Operator.
  * \file scan_op.cc
  */
@@ -264,7 +263,7 @@ Stmt ScanOpNode::BuildRealize(
   for (size_t i = 0; i < update.size(); ++i) {
     Tensor t = stage->op.output(i);
     CHECK_EQ(static_cast<size_t>(t->value_index), i);
-    HalideIR::Internal::Region bounds;
+    Region bounds;
     bounds.push_back(tdom);
     for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
       IterVar sp_ax = this->spatial_axis_[sp_idx];
index 009748e99e341238c860a3f247b98c733174856b..3ad7f8a22124b5473b9a83724cb7e336d66fe27b 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file inject_prefetch.cc
  */
 // Inject prefetch op in HalideIR
@@ -34,7 +33,6 @@ namespace ir {
 
 using arith::IntSet;
 using arith::DomainTouched;
-using HalideIR::Internal::Region;
 
 class PrefetchInjector : public IRMutator {
  public:
index d35ea474d5fcfb97acbf8f4d6861b5cdacb9af3f..aa8340e87ac04754873e98a110f922a4e7dce9c8 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -347,8 +347,7 @@ class IRDeepCompare :
     return order_;
   }
 
-  int CompareRegion(const HalideIR::Internal::Region& lhs,
-                    const HalideIR::Internal::Region& rhs) {
+  int CompareRegion(const Region& lhs, const Region& rhs) {
     if (order_ != 0) return order_;
     if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
     for (size_t i = 0; i < lhs.size(); ++i) {
index a11847a5265a5e78c7f9e9750eb0ba6c4076402f..ca9af5a996bce4d554680d141a8d08836871937d 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -225,7 +225,7 @@ Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
 
 Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
   IRMutator* m = this;
-  HalideIR::Internal::Region new_bounds;
+  Region new_bounds;
   bool bounds_changed = false;
 
   // Mutate the bounds
@@ -255,7 +255,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
 
 Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
   IRMutator* m = this;
-  HalideIR::Internal::Region new_bounds;
+  Region new_bounds;
   bool bounds_changed = false;
 
   // Mutate the bounds
index b84fa59c9d1f47a3cdac75b2a1a14d26d0f11f51..4c31463eb0865c4cc186c484b4273d2c1f0121ad 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -26,6 +26,7 @@
 #define TVM_PASS_STORAGE_ACCESS_H_
 
 #include <tvm/ir.h>
+#include <tvm/attrs.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_visitor.h>
 #include <vector>
@@ -56,7 +57,7 @@ class StorageAccessVisitor : public IRVisitor {
     /*! \brief The thread index that access this entry */
     Array<IterVar> threads;
     /*! \brief The buffer variable, if any */
-    VarExpr buffer;
+    Var buffer = NullValue<Var>();
     /*! \brief The access data type */
     Type dtype;
     /*! \brief The touched access range */
@@ -66,7 +67,7 @@ class StorageAccessVisitor : public IRVisitor {
     /*! \brief The storage scope */
     StorageScope scope;
     /*! \brief Whether the access is double buffer write */
-    bool double_buffer_write{false};
+    bool double_buffer_write = false;
   };
   /*! \brief Access pattern about a single statement */
   struct StmtEntry {
index 19e7a32e4acf553b390cd1d9d9fc83a75da7f8e6..02d2313f2c64fdc6986a9759fb3f5a1e5565caf4 100644 (file)
@@ -41,7 +41,6 @@
 namespace tvm {
 namespace ir {
 
-using HalideIR::Internal::Region;
 using runtime::StorageRank;
 using runtime::StorageScope;
 using runtime::ThreadScope;
@@ -186,7 +185,7 @@ class StorageFlattener : public IRMutator {
       }
 
       // use small alignment for small arrays
-      int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
+      int32_t const_size = Allocate::constant_allocation_size(shape);
       int align = GetTempAllocaAlignment(op->type, const_size);
       if (skey.tag.length() != 0) {
         MemoryInfo info = GetMemoryInfo(skey.to_string());
@@ -348,14 +347,14 @@ class StorageFlattener : public IRMutator {
     for (int i = starts; i >= 0; --i) {
       if (i < starts) {
         stmt = For::make(
-            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
+            vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
       } else {
         Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
         Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
         Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
         stmt = Evaluate::make(prefetch);
         Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
-        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
+        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
       }
     }
     return stmt;
index f7e106756818ef5def5c303cf08d6fb4d2d96c02..b85fbf964a1e6e8ec4c4ab37c24e08b65a9de0f9 100644 (file)
@@ -77,7 +77,7 @@ struct GraphCodegen {
 
   std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
     std::unordered_map<std::string, tvm::runtime::NDArray> ret;
-    auto names = CallFunc<Array<HalideIR::Expr> >("list_params_name", nullptr);
+    auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
     for (auto expr : names) {
       auto key = expr.as<ir::StringImm>()->value;
       ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
@@ -289,7 +289,7 @@ class RelayBuildModule : public runtime::ModuleNode {
         auto op_node = call_node->op.as<OpNode>();
         if (op_node->name == "cast") {
           auto attrs = call_node->attrs.as<CastAttrs>();
-          if (attrs->dtype == HalideIR::Int(32)) {
+          if (attrs->dtype == Int(32)) {
             *rv = true;
           }
         }
index b14448c59166499bed58295a945679607381c5c8..382ae6954a8097f2f58ed0069f10f9e3fa74373c 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file relay/backend/graph_codegen.cc
  * \brief Graph runtime codegen
  */
@@ -238,7 +237,7 @@ class GraphRuntimeCodegen
    * \param shape
    * \return std::vector<int64_t>
    */
-  std::vector<int64_t> _ShapeToJSON(tvm::Array<HalideIR::Expr> shape) {
+  std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
     std::vector<int64_t> ret;
     for (IndexExpr dim : shape) {
       const int64_t* pval = as_const_int(dim);
@@ -623,9 +622,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
       });
     } else if (name == "list_params_name") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        Array<HalideIR::Expr> ret;
+        Array<tvm::Expr> ret;
         for (const auto &kv : this->output_.params) {
-          HalideIR::Expr name = ir::StringImm::make(kv.first);
+          tvm::Expr name = ir::StringImm::make(kv.first);
           ret.push_back(name);
         }
         *rv = ret;
index 9589a0a9f5b88be37d13fcda423520bfcddeefbd..c1fadb82abed06e4e3acf2b0503961af6df6d3d7 100644 (file)
@@ -102,7 +102,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
     p->stream << "Var(" << node->name_hint();
     if (node->type_annotation.defined()) {
       p->stream << ", ty=";
-      p->print(node->type_annotation);
+      p->Print(node->type_annotation);
     }
     p->stream << ")";
   });
index acefaf3e7e5d6956244c2ac7af89ab3b1f874afa..1ee668a25cafbd350a5c28627372fcf2dcb29963 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -150,7 +150,7 @@ RELAY_REGISTER_OP("nn.upsampling")
     CHECK(base_layout == "NCHW" || layout == "NHWC")
       << "unknown layout: " << uattrs->layout;
 
-    Array<HalideIR::Expr> oshape;
+    Array<IndexExpr> oshape;
     if (base_layout == "NCHW") {
       oshape.push_back(out_tt->shape[2]);
       oshape.push_back(out_tt->shape[3]);
index 09aaaf7ad14d30bf0d6697b38cb5d5554f3b2064..506702ad52b5d56095f2bf6a3fc6f806f2a8adac 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2016 by Contributors
  * \file graph.cc
  * \brief Utilities to get information about schedule graph.
  */
@@ -34,7 +33,7 @@ namespace tvm {
 namespace schedule {
 // key to specific tensor dimension.
 struct TensorDimKey {
-  FunctionRef f;
+  ir::FunctionRef f;
   int value_index;
   int dim;
   TensorDimKey() {}
index 7532f4bcd31c5011c14ea951a3d14a1d55826774..7e61479a5a48f92f7aaa352bf3350b1c7d3a3b9e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2016 by Contributors
  * \file schedule_lang.cc
  */
 #include <tvm/schedule.h>
@@ -813,34 +812,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 })
 .set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
     p->stream << "split(parent=";
-    p->print(op->parent);
+    p->Print(op->parent);
     p->stream << ", outer=";
-    p->print(op->outer);
+    p->Print(op->outer);
     p->stream << ", inner=";
-    p->print(op->inner);
+    p->Print(op->inner);
     p->stream << ')';
 })
 .set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
     p->stream << "split(";
     p->stream << "outer=";
-    p->print(op->outer);
+    p->Print(op->outer);
     p->stream << ", inner=";
-    p->print(op->inner);
+    p->Print(op->inner);
     p->stream << ", fused=";
-    p->print(op->fused);
+    p->Print(op->fused);
     p->stream << ')';
 })
 .set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
     p->stream << "rebase(";
     p->stream << "parent=";
-    p->print(op->parent);
+    p->Print(op->parent);
     p->stream << ", rebased=";
-    p->print(op->rebased);
+    p->Print(op->rebased);
     p->stream << ')';
 })
 .set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
     p->stream << "singleton(";
-    p->print(op->iter);
+    p->Print(op->iter);
     p->stream << ')';
 })
 .set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
index 5ecf4b764a29088d1acc5c72c4e1ff8f3cc91cc5..30972e762314903cf49e18f473cc7dfa748274d1 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 #include <tvm/expr_operator.h>
 
 namespace {
+using namespace tvm;
 using namespace tvm::ir;
-using namespace HalideIR::Internal;
-using namespace HalideIR;
 
 // replace variable to constant
 class IRVar2Const : public IRMutator {
  public:
-  VarExpr var;
+  Var var;
   int int_val;
   Expr Mutate(Expr expr) final {
     static const FMutateExpr& f = IRVar2Const::vtable_expr();
@@ -49,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
 .set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
     IRVar2Const* vm = static_cast<IRVar2Const*>(m);
     if (e.same_as(vm->var)) {
-      return IntImm::make(Int(32), vm->int_val);
+      return Expr(IntImm::make(Int(32), vm->int_val));
     } else {
       return e;
     }
@@ -58,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRVar2Const, vtable_expr)
 }  // namespace
 
 TEST(IRMutator, Basic) {
-  using namespace HalideIR::Internal;
+  using namespace tvm::ir;
   using namespace tvm;
   Var x("x"), y;
   auto z = x + y;
index 0b5168032ba79cd893640cda47bdd7b3d86ca788..dd9ef3ba40d0e4321191aa8e0ed33c7a565c2d6f 100644 (file)
@@ -23,8 +23,8 @@
 
 
 TEST(IRSSA, Convert) {
-  using namespace HalideIR::Internal;
   using namespace tvm;
+  using namespace tvm::ir;
   Var x("x"), y;
   Expr let = Let::make(x, 1, x + 1);
 
@@ -35,7 +35,7 @@ TEST(IRSSA, Convert) {
 }
 
 TEST(IRSSA, Basic) {
-  using namespace HalideIR::Internal;
+  using namespace tvm::ir;
   using namespace tvm;
   Var x("x"), y;
   auto z = Evaluate::make(x + y);
index 814febb74c979d8a51910f01cc1498b334070a75..079be65079ca91dd0db134d4046467fd90d8193c 100644 (file)
@@ -23,7 +23,6 @@
 #include <tvm/ir_pass.h>
 
 TEST(IRVisitor, CountVar) {
-  using namespace HalideIR::Internal;
   using namespace tvm;
   int n_var = 0;
   Var x("x"), y;
index 372a9e17a2c3a80b6aa1a1bdc421cad3d3915e18..7d7143c1382c27c295d35cfdcfed749c1468e677 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file topi/image/resize.h
  * \brief image resize constructors
  */
@@ -55,17 +54,17 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
                                  const Expr max_y, const Expr max_x) {
   auto in_y = indices[2];
   auto yf = tvm::floor(in_y);
-  auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
+  auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
 
-  auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
+  auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
   auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
   auto y_lerp = in_y - yf;
 
   auto in_x = indices[3];
   auto xf = tvm::floor(in_x);
-  auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
+  auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
 
-  auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
+  auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
   auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
   auto x_lerp = in_x - xf;
 
@@ -268,17 +267,17 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
     out_shape, [&](const Array<Var>& indices) {
     auto in_y = indices[1] * y_ratio;
     auto yf = tvm::floor(in_y);
-    auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
+    auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
 
-    auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
+    auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
     auto y1 = tvm::if_then_else((yc > other_y), other_y, yc);
     auto y_lerp  = in_y - yf;
 
     auto in_x = indices[2] * x_ratio;
     auto xf = tvm::floor(in_x);
-    auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
+    auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
 
-    auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
+    auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
     auto x1 = tvm::if_then_else((xc > other_x), other_x, xc);
     auto x_lerp  = in_x - xf;
 
index 9ea5e1e08c1fd3e77743e0230f32826658f49dc0..43711dadc2730b6b2a15509e55fcfd13fb434c95 100644 (file)
@@ -689,7 +689,7 @@ inline Tensor sequence_mask(const Tensor& data,
         auto bid = out_index[1 - axis];
         len_index.push_back(bid);
         Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
-                                     tvm::cast(data->dtype, Expr(mask_value)), data(out_index));
+                                     tvm::make_const(data->dtype, mask_value), data(out_index));
         return ret;
       }, name, tag);
   return out;