* [INFA][IR] Build and Evolve Low-level IR. Remove dep from HalideIR.
* Update include/tvm/node/ir_functor.h
Co-Authored-By: Jared Roesch <roeschinc@gmail.com>
* Update include/tvm/node/ir_functor.h
Co-Authored-By: Jared Roesch <roeschinc@gmail.com>
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE)
- add_definitions(-DHalide_SHARED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
endif(MSVC)
# add source group
-FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "3rdparty/HalideIR/src/*.cpp" "nnvm/src/*.cc")
-FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h" "3rdparty/HalideIR/src/*.h"
+FILE(GLOB_RECURSE GROUP_SOURCE "src/*.cc" "nnvm/src/*.cc")
+FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h"
"nnvm/src/*.h" "nnvm/include/*.h")
assign_source_group("Source" ${GROUP_SOURCE})
assign_source_group("Include" ${GROUP_INCLUDE})
src/lang/*.cc
src/pass/*.cc
src/op/*.cc
+ src/node/*.cc
src/schedule/*.cc
)
file(GLOB TOPI_SRCS
topi/src/*.cc
)
-file(GLOB_RECURSE HALIDEIR_SRCS
- 3rdparty/HalideIR/src/base/*.cpp
- 3rdparty/HalideIR/src/ir/*.cpp
- 3rdparty/HalideIR/src/tvm/*.cpp
-)
-list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
+
file(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
# Related headers
target_include_directories(
tvm
- PUBLIC "3rdparty/HalideIR/src"
PUBLIC "topi/include")
target_include_directories(
tvm_topi
FILES_MATCHING
PATTERN "*.h"
)
- install(
- DIRECTORY "3rdparty/HalideIR/src/." DESTINATION "include/HalideIR"
- FILES_MATCHING
- PATTERN "*.h"
- )
install(
DIRECTORY "3rdparty/dlpack/include/." DESTINATION "include"
FILES_MATCHING
# More target definitions
if(MSVC)
- target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
- target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
target_compile_definitions(nnvm_compiler PRIVATE -DNNVM_EXPORTS)
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
const std::unordered_map<const Variable*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
-using ExprIntSetMap = std::unordered_map<Expr, IntSet, ExprHash, ExprEqual>;
+using ExprIntSetMap = std::unordered_map<Expr, IntSet, NodeHash, NodeEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
}
template<>
-inline Type NullValue<Type>() {
- return Type(Type::Handle, 0, 0);
+inline DataType NullValue<DataType>() {
+ return DataType(kHandle, 0, 0);
}
/*! \brief Error thrown during attribute checking. */
if (!this->defined()) return -1;
const auto axes = operator->()->axes;
for (size_t i = 0; i < axes.size(); ++i) {
- if (axes[i]->var.get()->name_hint == axis.name()) return static_cast<int32_t>(i);
+ if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
}
return -1;
}
bool Contains(const LayoutAxis& axis) const {
if (!defined()) return false;
for (const IterVar var : operator->()->axes) {
- if (var->var.get()->name_hint == axis.name()) {
+ if (var->var->name_hint == axis.name()) {
return true;
}
}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file tvm/dtype.h
+ * \brief Data type used in IR.
+ */
+#ifndef TVM_DTYPE_H_
+#define TVM_DTYPE_H_
+
+#include "runtime/packed_func.h"
+
+namespace tvm {
+class Expr;
+
+/*!
+ * \brief Primitive data types in tvm.
+ */
+class DataType {
+ public:
+ /*! \brief default constructor */
+ DataType() {}
+ /*!
+ * \brief Constructor
+ * \param dtype The DLDataType
+ */
+ explicit DataType(DLDataType dtype)
+ : data_(dtype) {}
+ /*!
+ * \brief Constructor
+ * \param code The type code.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ */
+ DataType(int code, int bits, int lanes) {
+ data_.code = static_cast<uint8_t>(code);
+ data_.bits = static_cast<uint8_t>(bits);
+ data_.lanes = static_cast<uint16_t>(lanes);
+ }
+ /*! \return The type code. */
+ int code() const {
+ return static_cast<int>(data_.code);
+ }
+ /*! \return number of bits in the data. */
+ int bits() const {
+ return static_cast<int>(data_.bits);
+ }
+ /*! \return number of bytes to store each scalar. */
+ int bytes() const {
+ return (bits() + 7) / 8;
+ }
+ /*! \return number of lanes in the data. */
+ int lanes() const {
+ return static_cast<int>(data_.lanes);
+ }
+ /*! \return whether type is a scalar type. */
+ bool is_scalar() const {
+ return lanes() == 1;
+ }
+ /*! \return whether type is a scalar type. */
+ bool is_bool() const {
+ return code() == kDLUInt && bits() == 1;
+ }
+ /*! \return whether type is a float type. */
+ bool is_float() const {
+ return code() == kDLFloat;
+ }
+ /*! \return whether type is an int type. */
+ bool is_int() const {
+ return code() == kDLInt;
+ }
+ /*! \return whether type is an uint type. */
+ bool is_uint() const {
+ return code() == kDLUInt;
+ }
+ /*! \return whether type is a handle type. */
+ bool is_handle() const {
+ return code() == kHandle;
+ }
+ /*! \return whether type is a vector type. */
+ bool is_vector() const {
+ return lanes() > 1;
+ }
+ /*!
+ * \brief Create a new data type by change lanes to a specified value.
+ * \param lanes The target number of lanes.
+ * \return the result type.
+ */
+ DataType with_lanes(int lanes) const {
+ return DataType(data_.code, data_.bits, lanes);
+ }
+ /*!
+ * \brief Create a new data type by change bits to a specified value.
+ * \param bits The target number of bits.
+ * \return the result type.
+ */
+ DataType with_bits(int bits) const {
+ return DataType(data_.code, bits, data_.lanes);
+ }
+ /*!
+ * \brief Get the scalar version of the type.
+ * \return the result type.
+ */
+ DataType element_of() const {
+ return with_lanes(1);
+ }
+ // operator overloadings
+ bool operator==(const DataType& other) const {
+ return
+ data_.code == other.data_.code &&
+ data_.bits == other.data_.bits &&
+ data_.lanes == other.data_.lanes;
+ }
+ bool operator!=(const DataType& other) const {
+ return !operator==(other);
+ }
+ operator DLDataType () const {
+ return data_;
+ }
+ /*! \return the maximum possible value in this format. */
+ TVM_DLL Expr max() const;
+ /*! \return the minimum possible value in this format. */
+ TVM_DLL Expr min() const;
+
+ private:
+ DLDataType data_;
+};
+
+/*!
+ * \brief Construct an int type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes.
+ * \return The constructed data type.
+ */
+inline DataType Int(int bits, int lanes = 1) {
+ return DataType(kDLInt, bits, lanes);
+}
+
+/*!
+ * \brief Construct an uint type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType UInt(int bits, int lanes = 1) {
+ return DataType(kDLUInt, bits, lanes);
+}
+
+/*!
+ * \brief Construct a bool type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType Bool(int lanes = 1) {
+ return UInt(1, lanes);
+}
+
+/*!
+ * \brief Construct an uint type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType Float(int bits, int lanes = 1) {
+ return DataType(kDLFloat, bits, lanes);
+}
+
+/*!
+ * \brief Construct a handle type.
+ * \param bits The number of bits in the type.
+ * \param lanes The number of lanes
+ * \return The constructed data type.
+ */
+inline DataType Handle(int bits = 64, int lanes = 1) {
+ return DataType(kHandle, bits, lanes);
+}
+
+/*!
+ * \brief Get the corresponding type of TVMShapeIndex.
+ * \return The type of TVM shape index.
+ */
+inline DataType TVMShapeIndexType() {
+ if (std::is_signed<tvm_index_t>::value) {
+ return Int(sizeof(tvm_index_t) * 8);
+ } else {
+ return UInt(sizeof(tvm_index_t) * 8);
+ }
+}
+
+/*!
+ * \brief Convert DLDataType to DataType.
+ * \param t The original type.
+ * \return The conversion result.
+ */
+inline DataType TVMType2Type(DLDataType t) {
+ return DataType(t.code, t.bits, t.lanes);
+}
+
+/*!
+ * \brief Convert DataType to DataType.
+ * \param t The original type.
+ * \return The conversion result.
+ */
+inline DLDataType Type2TVMType(DataType t) {
+ return t.operator DLDataType();
+}
+
+/*!
+ * \brief Get the number of bytes needed in a vector.
+ * \param dtype The data type.
+ * \return Number of bytes needed.
+ */
+inline int GetVectorBytes(DataType dtype) {
+ int data_bits = dtype.bits() * dtype.lanes();
+ // allow bool to exist
+ if (dtype == Bool()) return 1;
+ CHECK_EQ(data_bits % 8, 0U)
+ << "Need to load/store by multiple of bytes";
+ return data_bits / 8;
+}
+
+// Overload print function.
+inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
+ using namespace tvm::runtime;
+ return os << dtype.operator DLDataType();
+}
+
+// Backward compatibility
+using Type = DataType;
+} // namespace tvm
+#endif // TVM_DTYPE_H_
-
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
-#include <ir/Expr.h>
-#include <ir/IRPrinter.h>
#include <string>
#include <algorithm>
#include <unordered_map>
#include "base.h"
+#include "dtype.h"
+#include "node/container.h"
+#include "node/ir_functor.h"
#include "runtime/c_runtime_api.h"
namespace tvm {
-using HalideIR::Type;
-using HalideIR::Float;
-using HalideIR::Bool;
-using HalideIR::Int;
-using HalideIR::UInt;
-using HalideIR::Handle;
-using HalideIR::ExprHash;
-using HalideIR::ExprEqual;
-
-using HalideIR::Expr;
-using HalideIR::VarExpr;
-using HalideIR::IR::RangeNode;
-using HalideIR::IR::FunctionRef;
-using HalideIR::IR::FunctionBaseNode;
-using HalideIR::Internal::IntImm;
-using HalideIR::Internal::Stmt;
-using HalideIR::Internal::IRPrinter;
-using HalideIR::Internal::Variable;
-
-inline Type TVMShapeIndexType() {
- if (std::is_signed<tvm_index_t>::value) {
- return Int(sizeof(tvm_index_t) * 8);
- } else {
- return UInt(sizeof(tvm_index_t) * 8);
+/*! \brief Base node of all expressions. */
+class ExprNode : public Node {
+ public:
+ /*! \brief The data type of the expression. */
+ DataType type;
+
+ static constexpr const char* _type_key = "Expr";
+ TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
+};
+
+/*! \brief Container of all expressions. */
+class Expr : public NodeRef {
+ public:
+ Expr() {}
+ explicit Expr(NodePtr<Node> ptr) : NodeRef(ptr) {}
+ /*!
+ * \brief construct from integer.
+ * \param value The value to be constructed.
+ */
+ TVM_DLL Expr(int32_t value); // NOLINT(*)
+ /*!
+ * \brief construct from float.
+ * \param value The value to be constructed.
+ */
+ TVM_DLL Expr(float value); // NOLINT(*)
+ /*!
+ * \brief construct from string.
+ * \param str The value to be constructed.
+ */
+ TVM_DLL Expr(std::string str); // NOLINT(*)
+
+ /*! \return the data type of this expression. */
+ DataType type() const {
+ return static_cast<const ExprNode*>(get())->type;
}
-}
-inline Type TVMType2Type(TVMType t) {
- return Type(static_cast<halideir_type_code_t>(t.code), t.bits, t.lanes);
-}
+ /*! \brief type indicate the container type */
+ using ContainerType = ExprNode;
+};
-inline TVMType Type2TVMType(Type t) {
- TVMType ret;
- ret.code = static_cast<uint8_t>(t.code());
- ret.bits = static_cast<uint8_t>(t.bits());
- ret.lanes = static_cast<uint16_t>(t.lanes());
- return ret;
-}
+/*! \brief Base node of all statements. */
+class StmtNode : public Node {
+ public:
+ static constexpr const char* _type_key = "Stmt";
+ TVM_DECLARE_BASE_NODE_INFO(StmtNode, Node);
+};
-// Get number of bytes considering vector type.
-inline int GetVectorBytes(Type dtype) {
- int data_bits = dtype.bits() * dtype.lanes();
- // allow bool to exist
- if (dtype == Bool()) return 1;
- CHECK_EQ(data_bits % 8, 0U)
- << "Need to load/store by multiple of bytes";
- return data_bits / 8;
-}
+/*! \brief Container of all statements */
+class Stmt : public NodeRef {
+ public:
+ TVM_DEFINE_NODE_REF_METHODS(Stmt, NodeRef, StmtNode);
+};
+
+class Var;
+/*!
+ * \brief A variable node in the IR.
+ *
+ * A vraible is uniquely identified by its address.
+ *
+ * Each variable is only binded once in the following nodes:
+ * - Allocate
+ * - For
+ * - Let
+ * - LetStmt
+ */
+class Variable : public ExprNode {
+ public:
+ /*!
+ * \brief The hint to the variable name.
+ * \note Each variable is uniquely identified by its address.
+ */
+ std::string name_hint;
+
+ static Var make(DataType dtype, std::string name_hint);
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("name", &name_hint);
+ }
+
+ static constexpr const char* _type_key = "Variable";
+ TVM_DECLARE_NODE_TYPE_INFO(Variable, ExprNode);
+};
/*! \brief a named variable in TVM */
-class Var : public HalideIR::VarExpr {
+class Var : public Expr {
public:
- EXPORT explicit Var(const std::string& name_hint = "v",
- Type t = Int(32)) : VarExpr(name_hint, t) {}
- explicit Var(NodePtr<Node> n) : VarExpr(n) {}
- explicit Var(VarExpr v) : VarExpr(v) {}
+ explicit Var(NodePtr<Node> n) : Expr(n) {}
+ TVM_DLL explicit Var(std::string name_hint = "v",
+ Type t = Int(32));
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
}
+ /*!
+ * \brief Get pointer to the internal value.
+ * \return the corresponding Variable.
+ */
+ const Variable* operator->() const {
+ return get();
+ }
+ /*!
+ * \brief Get pointer to the internal value.
+ * \return the corresponding Variable.
+ */
+ const Variable* get() const {
+ return static_cast<Variable*>(node_.get());
+ }
/*! \brief type indicate the container type */
using ContainerType = Variable;
};
+// Backward compatibility, will be removed later.
+using VarExpr = Var;
+using BaseExprNode = ExprNode;
+using ExprHash = NodeHash;
+using ExprEqual = NodeEqual;
+
+class Integer;
+/*! \brief ExprNode: constant integer. */
+class IntImm : public ExprNode {
+ public:
+ /*! \brief the Internal value. */
+ int64_t value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL static Integer make(DataType t, int64_t value);
+
+ static constexpr const char* _type_key = "IntImm";
+ TVM_DECLARE_NODE_TYPE_INFO(IntImm, ExprNode);
+};
/*!
* \brief Container of constant integer (IntImm).
using ContainerType = IntImm;
};
+/*! \brief range over one dimension */
+class RangeNode : public Node {
+ public:
+ /*! \brief beginning of the node */
+ Expr min;
+ /*! \brief the extend of range */
+ Expr extent;
+ /*! \brief constructor */
+ RangeNode() {}
+ RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
-/*! \brief container class of iteration variable. */
-class IterVarNode;
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("min", &min);
+ v->Visit("extent", &extent);
+ }
-/*!
- * \brief same as HalideIR::IR::Range
- * except it provide an constructor with (begin, end)
- *
- * \note Traditional Halide's Range have a constructor with
- * (begin, extent), which does not match the convention in e.g. python.
- * We decided to correct it by removing the constructor in HalideIR,
- * and add it back in TVM's range.
- */
-class Range : public HalideIR::IR::Range {
+ static constexpr const char* _type_key = "Range";
+ TVM_DECLARE_NODE_TYPE_INFO(RangeNode, Node);
+};
+
+/*! \brief Range constainer */
+class Range : public NodeRef {
public:
- /*! \brief constructor */
- Range() {}
- explicit Range(NodePtr<Node> n) : HalideIR::IR::Range(n) {}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
TVM_DLL Range(Expr begin, Expr end);
-
- TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
+ /*!
+ * \brief construct a new range with min and extent
+ * The corresponding constructor is removed,
+ * because that is counter convention of tradition meaning
+ * of range(begin, end)
+ *
+ * \param min The minimum range.
+ * \param extent The extent of the range.
+ */
+ static Range make_by_min_extent(Expr min, Expr extent);
+ // declare range.
+ TVM_DEFINE_NODE_REF_METHODS(Range, NodeRef, RangeNode);
};
+/*! \brief container class of iteration variable. */
+class IterVarNode;
+
using Region = Array<Range>;
/*!
using Domain = Array<Range>;
-// print functions for expr
-TVM_DLL std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
-
/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
-TVM_DLL Var var(const std::string& name_hint, Type t = Int(32));
+TVM_DLL Var var(std::string name_hint, Type t = Int(32));
/*
* \brief Template function to convert Map to unordered_map
}
return ret;
}
+
+// Printer infra.
+/*! \brief A Pretty printer class to print the IR. */
+class IRPrinter {
+ public:
+ /*! \brief The output stream */
+ std::ostream& stream;
+ /*! \brief The indentation level. */
+ int indent{0};
+ explicit IRPrinter(std::ostream& stream) // NOLINT(*)
+ : stream(stream) {}
+
+ /*! \brief The node to be printed. */
+ TVM_DLL void Print(const NodeRef& node);
+ /*! \brief Print indent to the stream */
+ TVM_DLL void PrintIndent();
+ // Allow registration to be printer.
+ using FType = IRFunctor<void(const NodeRef&, IRPrinter *)>;
+ TVM_DLL static FType& vtable();
+};
+
+// default print function for all nodes
+inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
+ IRPrinter(os).Print(n);
+ return os;
+}
} // namespace tvm
namespace std {
* specific language governing permissions and limitations
* under the License.
*/
-
/*!
* \file tvm/ir.h
* \brief Additional high level nodes in the IR
*/
+// Acknowledgement: Most low-level IR nodes originate from Halide.
+
#ifndef TVM_IR_H_
#define TVM_IR_H_
-#include <ir/Expr.h>
-#include <ir/IR.h>
#include <type_traits>
#include <string>
+#include <vector>
+#include <utility>
#include "base.h"
#include "expr.h"
#include "runtime/util.h"
namespace tvm {
namespace ir {
-using HalideIR::Internal::BaseExprNode;
-using HalideIR::Internal::ExprNode;
-using HalideIR::Internal::StmtNode;
-using HalideIR::Internal::IRNodeType;
-using HalideIR::Internal::ForType;
-using HalideIR::DeviceAPI;
+using IntImm = tvm::IntImm;
+using Variable = tvm::Variable;
+
+/*! \brief constant unsigned integer. */
+class UIntImm : public ExprNode {
+ public:
+ /*! \brief The constant value content. */
+ uint64_t value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL static Expr make(Type t, uint64_t value);
+
+ static constexpr const char* _type_key = "UIntImm";
+ TVM_DECLARE_NODE_TYPE_INFO(UIntImm, ExprNode);
+};
+
+/*! \brief Floating point constants. */
+class FloatImm : public ExprNode {
+ public:
+ /*! \brief The constant value content. */
+ double value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL static Expr make(Type t, double value);
+
+ static constexpr const char* _type_key = "FloatImm";
+ TVM_DECLARE_NODE_TYPE_INFO(FloatImm, ExprNode);
+};
+
+/*! \brief String constants, only used in asserts. */
+class StringImm : public ExprNode {
+ public:
+ /*! \brief The constant value content. */
+ std::string value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL Expr static make(std::string value);
+
+ static constexpr const char* _type_key = "StringImm";
+ TVM_DECLARE_NODE_TYPE_INFO(StringImm, ExprNode);
+};
+
+/*!
+ * \brief Cast value from one data type to another.
+ * \note The lanes of value should keep fixed.
+ */
+class Cast : public ExprNode {
+ public:
+ /*! \brief Original data type. */
+ Expr value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL static Expr make(Type t, Expr v);
+
+ static constexpr const char* _type_key = "Cast";
+ TVM_DECLARE_NODE_TYPE_INFO(Cast, ExprNode);
+};
+
+/*!
+ * \brief Base template to implement binary ops.
+ * \tparam T The type of the child class.
+ */
+template<typename T>
+class BinaryOpNode : public ExprNode {
+ public:
+ /*! \brief The left operand. */
+ Expr a;
+ /*! \brief The right operand. */
+ Expr b;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &(this->type));
+ v->Visit("a", &a);
+ v->Visit("b", &b);
+ }
+
+ static Expr make(Expr a, Expr b) {
+ CHECK(a.defined()) << "ValueError: a is undefined\n";
+ CHECK(b.defined()) << "ValueError: b is undefined\n";
+ CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
+ NodePtr<T> node = make_node<T>();
+ node->type = a.type();
+ node->a = std::move(a);
+ node->b = std::move(b);
+ return Expr(node);
+ }
+
+ TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
+};
+
+/*! \brief a + b */
+class Add : public BinaryOpNode<Add> {
+ public:
+ static constexpr const char* _type_key = "Add";
+};
+
+/*! \brief a - b */
+class Sub : public BinaryOpNode<Sub> {
+ public:
+ static constexpr const char* _type_key = "Sub";
+};
+
+/*! \brief a * b */
+class Mul : public BinaryOpNode<Mul> {
+ public:
+ static constexpr const char* _type_key = "Mul";
+};
+
+/*!
+ * \brief a / b in the C semnatics.
+ * \note For integer division, C standard uses trunc div.
+ */
+class Div : public BinaryOpNode<Div> {
+ public:
+ static constexpr const char* _type_key = "Div";
+};
+
+/*!
+ * \brief a % b in the C semnatics.
+ * \note For integer division, C standard uses trunc div.
+ */
+class Mod : public BinaryOpNode<Mod> {
+ public:
+ static constexpr const char* _type_key = "Mod";
+};
+
+/*! \brief min(a, b) */
+class Min : public BinaryOpNode<Min> {
+ public:
+ static constexpr const char* _type_key = "Min";
+};
+
+/*! \brief max(a, b) */
+class Max : public BinaryOpNode<Max> {
+ public:
+ static constexpr const char* _type_key = "Max";
+};
+
+/*!
+ * \brief Base template to implement comparison ops.
+ * \tparam T The type of the child class.
+ */
+template<typename T>
+class CmpOpNode : public ExprNode {
+ public:
+ /*! \brief The left operand. */
+ Expr a;
+ /*! \brief The right operand. */
+ Expr b;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &(this->type));
+ v->Visit("a", &a);
+ v->Visit("b", &b);
+ }
+
+ static Expr make(Expr a, Expr b) {
+ CHECK(a.defined()) << "ValueError: a is undefined\n";
+ CHECK(b.defined()) << "ValueError: b is undefined\n";
+ CHECK(a.type() == b.type()) << "TypeError: mismatched types\n";
+ NodePtr<T> node = make_node<T>();
+ node->type = Bool(a.type().lanes());
+ node->a = std::move(a);
+ node->b = std::move(b);
+ return Expr(node);
+ }
+
+ TVM_DECLARE_NODE_TYPE_INFO(T, ExprNode);
+};
+
+/*! \brief a == b */
+class EQ : public CmpOpNode<EQ> {
+ public:
+ static constexpr const char* _type_key = "EQ";
+};
+
+/*! \brief a != b */
+class NE : public CmpOpNode<NE> {
+ public:
+ static constexpr const char* _type_key = "NE";
+};
+
+/*! \brief a < b */
+class LT : public CmpOpNode<LT> {
+ public:
+ static constexpr const char* _type_key = "LT";
+};
+
+/*! \brief a <= b */
+struct LE : public CmpOpNode<LE> {
+ public:
+ static constexpr const char* _type_key = "LE";
+};
+
+/*! \brief a > b */
+class GT : public CmpOpNode<GT> {
+ public:
+ static constexpr const char* _type_key = "GT";
+};
+
+/*! \brief a >= b */
+class GE : public CmpOpNode<GE> {
+ public:
+ static constexpr const char* _type_key = "GE";
+};
+
+/*! \brief a && b */
+class And : public ExprNode {
+ public:
+ /*! \brief The left operand. */
+ Expr a;
+ /*! \brief The right operand. */
+ Expr b;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &(this->type));
+ v->Visit("a", &a);
+ v->Visit("b", &b);
+ }
+
+ TVM_DLL static Expr make(Expr a, Expr b);
+
+ static constexpr const char* _type_key = "And";
+ TVM_DECLARE_NODE_TYPE_INFO(And, ExprNode);
+};
+
+/*! \brief a || b */
+class Or : public ExprNode {
+ public:
+ /*! \brief The left operand. */
+ Expr a;
+ /*! \brief The right operand. */
+ Expr b;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("a", &a);
+ v->Visit("b", &b);
+ }
+
+ TVM_DLL static Expr make(Expr a, Expr b);
+
+ static constexpr const char* _type_key = "Or";
+ TVM_DECLARE_NODE_TYPE_INFO(Or, ExprNode);
+};
+
+/*! \brief !a */
+class Not : public ExprNode {
+ public:
+ /*! \brief The input operand. */
+ Expr a;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("a", &a);
+ }
+
+ TVM_DLL static Expr make(Expr a);
+
+ static constexpr const char* _type_key = "Not";
+ TVM_DECLARE_NODE_TYPE_INFO(Not, ExprNode);
+};
+
+/*!
+ * \brief return true_value if condition is true, otherwise return false_value.
+ * \note Both true_value and false_value could be evaluated
+ * regardless of the condition value.
+ * Do not use it to guard against out of bound access,
+ * please use if_then_else instead.
+ */
+class Select : public ExprNode {
+ public:
+ /*! \brief The condition */
+ Expr condition;
+ /*! \brief value to be returned when condition is true. */
+ Expr true_value;
+ /*! \brief value to be returned when condition is false. */
+ Expr false_value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("condition", &condition);
+ v->Visit("true_value", &true_value);
+ v->Visit("false_value", &false_value);
+ }
+
+ TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
+
+ static constexpr const char* _type_key = "Select";
+ TVM_DECLARE_NODE_TYPE_INFO(Select, ExprNode);
+};
+
+/*!
+ * \brief Load the value from buffer_var.
+ *
+ * Equivalent to ((DType*)buffer_var)[index]
+ * where DType is the type specified by type().element_of().
+ *
+ * For example, if type = float32x3, then the load will corresponds to
+ *
+ * \code
+ *
+ * auto buffer = static_cast<float*>(buffer_var);
+ * auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
+ *
+ * \endcode
+ */
+class Load : public ExprNode {
+ public:
+ /*! \brief The buffer variable. */
+ Var buffer_var;
+ /*! \brief The index locations to be loaded. */
+ Expr index;
+ /*! \brief The predicate to mask which lanes would be loaded. */
+ Expr predicate;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("buffer_var", &buffer_var);
+ v->Visit("index", &index);
+ v->Visit("predicate", &predicate);
+ }
+
+ TVM_DLL static Expr make(Type type, Var buffer_var, Expr index, Expr predicate);
+
+ static constexpr const char* _type_key = "Load";
+ TVM_DECLARE_NODE_TYPE_INFO(Load, ExprNode);
+};
+
+/*!
+ * \brief Construct a vector with lanes elements
+ * where its i-th element equals base + i * stride.
+ * This is useful to construct a index for a continuous vector load.
+ *
+ * Examples:
+ * - ramp(0, 1, 3) = [0, 1, 2]
+ * - ramp(1, 2, 4) = [1, 3, 5, 7]
+ */
+class Ramp : public ExprNode {
+ public:
+ /*! \brief The base value. */
+ Expr base;
+ /*! \brief The stride of each step. */
+ Expr stride;
+ /*! \brief Total number of lanes. */
+ int lanes;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("base", &base);
+ v->Visit("stride", &stride);
+ v->Visit("lanes", &lanes);
+ }
+
+ TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
+
+ static constexpr const char* _type_key = "Ramp";
+ TVM_DECLARE_NODE_TYPE_INFO(Ramp, ExprNode);
+};
+
+/*! \brief Create a vector where all the elements are value. */
+class Broadcast : public ExprNode {
+ public:
+ /*! \brief The base value. */
+ Expr value;
+ /*! \brief The numerb of lanes. */
+ int lanes;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("value", &value);
+ v->Visit("lanes", &lanes);
+ }
+
+ TVM_DLL static Expr make(Expr value, int lanes);
+
+ static constexpr const char* _type_key = "Broadcast";
+ TVM_DECLARE_NODE_TYPE_INFO(Broadcast, ExprNode);
+};
+
+/*!
+ * \brief Let binding. Bind var to value then evaluate body.
+ */
+class Let : public ExprNode {
+ public:
+ /*! \brief The variable. */
+ Var var;
+ /*! \brief The value to be binded. */
+ Expr value;
+ /*! \brief The result expression. */
+ Expr body;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("var", &var);
+ v->Visit("value", &value);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Expr make(Var var, Expr value, Expr body);
+
+ static constexpr const char* _type_key = "Let";
+ TVM_DECLARE_NODE_TYPE_INFO(Let, ExprNode);
+};
+
+// Call node, represent a function call or a multi-dimensional array load.
+//
+// TODO(tvm-team):
+// Refactor call with more explicit property registrations.
+// rather than calling a string symbol.
+// We should move most information into function itself and remove name.
+
+/*! \brief Base node of internal functions. */
+class FunctionBaseNode : public Node {
+ public:
+ /*! \return the name of the function */
+ virtual const std::string& func_name() const = 0;
+ /*! \return the number of outputs of this function */
+ virtual int num_outputs() const = 0;
+};
+
+/*! \brief reference to a function */
+class FunctionRef : public NodeRef {
+ public:
+ TVM_DEFINE_NODE_REF_METHODS(FunctionRef, NodeRef, FunctionBaseNode);
+};
+
+/*!
+ * \brief Call node.
+ */
+class Call : public ExprNode {
+ public:
+ /*! \brief Possible types of calls. */
+ enum CallType : int {
+ /*! \brief Extern "C" function. */
+ Extern = 0,
+ /*! \brief Extern CXX function. */
+ ExternCPlusPlus = 1,
+ /*! \brief Extern "C" without side-effect. */
+ PureExtern = 2,
+ /*! \brief Halide-style call, evaluates func(args). */
+ Halide = 3,
+ /*! \brief Intrinsic functions. */
+ Intrinsic = 4,
+ /*! \brief Intrinsic functions that are pure. */
+ PureIntrinsic = 5
+ };
+ /*! \brief The name of the function/intrinsic. */
+ std::string name;
+ /*! \brief The arguments. */
+ Array<Expr> args;
+ /*! \brief Type of calls. */
+ CallType call_type;
+ /*! \brief The function to be called. */
+ FunctionRef func;
+ /*! \brief The output value index if func's value is a tuple. */
+ int value_index{0};
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("dtype", &type);
+ v->Visit("name", &name);
+ v->Visit("args", &args);
+ v->Visit("call_type", &call_type);
+ v->Visit("func", &func);
+ v->Visit("value_index", &value_index);
+ }
+
+ TVM_DLL static Expr make(Type type,
+ std::string name,
+ Array<Expr> args,
+ CallType call_type,
+ FunctionRef func = FunctionRef(),
+ int value_index = 0);
+
+ /*! \return Whether call node is pure. */
+ bool is_pure() const {
+ return (call_type == PureExtern ||
+ call_type == PureIntrinsic ||
+ call_type == Halide);
+ }
+
+ /*!
+ * \return Whether call node corresponds to a defined intrinsic.
+ * \param intrin_name The name of the intrinsic.
+ */
+ bool is_intrinsic(const char* intrin_name) const {
+ return
+ ((call_type == Intrinsic ||
+ call_type == PureIntrinsic) &&
+ name == intrin_name);
+ }
+
+ static constexpr const char* _type_key = "Call";
+ TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);
+
+ // Build-in intrinsics
+ static constexpr const char* reinterpret = "reinterpret";
+ static constexpr const char* bitwise_and = "bitwise_and";
+ static constexpr const char* bitwise_not = "bitwise_not";
+ static constexpr const char* bitwise_xor = "bitwise_xor";
+ static constexpr const char* bitwise_or = "bitwise_or";
+ static constexpr const char* shift_left = "shift_left";
+ static constexpr const char* shift_right = "shift_right";
+ static constexpr const char* popcount = "popcount";
+ static constexpr const char* likely = "likely";
+ static constexpr const char* glsl_texture_store = "glsl_texture_store";
+ static constexpr const char* prefetch = "prefetch";
+};
+
+/*!
+ * \brief Shuffle instruction.
+ * vec = concat(vectors)
+ * result = (vec[indices[0]], vec[indices[1]] ...)
+ */
+class Shuffle : public ExprNode {
+ public:
+ /*! \brief the input vectors. */
+ Array<Expr> vectors;
+ /*! \brief The indices of each element. */
+ Array<Expr> indices;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("vectors", &vectors);
+ v->Visit("indices", &indices);
+ }
+
+ TVM_DLL static Expr make(Array<Expr> vectors, Array<Expr> indices);
+ TVM_DLL static Expr make_concat(Array<Expr> vectors);
+ TVM_DLL static Expr make_extract_element(Expr vector, int index);
-// Node container for CommReducer
-struct CommReducerNode;
+ static constexpr const char* _type_key = "Shuffle";
+ TVM_DECLARE_NODE_TYPE_INFO(Shuffle, ExprNode);
+};
-struct CommReducer : public NodeRef {
+// Reduce operator
+class CommReducerNode;
+
+class CommReducer : public NodeRef {
+ public:
CommReducer() {}
explicit CommReducer(NodePtr<Node> n) : NodeRef(n) {}
/*!
* \brief A commutative reducer node to represent a commutative
* binary operator with identity element
*/
-struct CommReducerNode : public Node {
+class CommReducerNode : public Node {
+ public:
/*! \brief The left argument of reducer */
Array<Var> lhs;
/*! \brief The right argument of reducer */
/*! \brief Function call operator to combine a and b */
Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
- TVM_DLL static CommReducer make(Array<Var> lhs, Array<Var> rhs,
- Array<Expr> result, Array<Expr> identity_element);
+ TVM_DLL static CommReducer make(Array<Var> lhs,
+ Array<Var> rhs,
+ Array<Expr> result,
+ Array<Expr> identity_element);
void VisitAttrs(AttrVisitor* v) final {
v->Visit("lhs", &lhs);
}
/*! \brief Reduction operator operator */
-struct Reduce : public ExprNode<Reduce> {
+class Reduce : public ExprNode {
+ public:
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
}
- static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
+
static constexpr const char* _type_key = "Reduce";
+ TVM_DECLARE_NODE_TYPE_INFO(Reduce, ExprNode);
};
/*! \brief Any shape. */
-struct Any : public ExprNode<Any> {
+class Any : public ExprNode {
+ public:
+ void VisitAttrs(AttrVisitor* v) final {}
+
TVM_DLL static Expr make();
- void VisitAttrs(AttrVisitor* v) final {}
- static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
static constexpr const char* _type_key = "Any";
+ TVM_DECLARE_NODE_TYPE_INFO(Any, ExprNode);
+};
+
+// Statements
+/*!
+ * \brief Let binding, bind var to value, then run body.
+ */
+class LetStmt : public StmtNode {
+ public:
+ /*! \brief The variable. */
+ Var var;
+ /*! \brief The value to be binded. */
+ Expr value;
+ /*! \brief The body block. */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("var", &var);
+ v->Visit("value", &value);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
+
+ static constexpr const char* _type_key = "LetStmt";
+ TVM_DECLARE_NODE_TYPE_INFO(LetStmt, StmtNode);
+};
+
+/*!
+ * \brief Define certain auxiliary attribute for the body to be a symbolic value.
+ * This provide auxiliary information for IR passes that transforms body.
+ *
+ * In terms of effect, this is equivalent to Block(Evaluate(value), body).
+ *
+ * Examples of possible usage:
+ * - Bound of function, variables.
+ * - Hint which block corresponds to a parallel region.
+ */
+class AttrStmt : public StmtNode {
+ public:
+ /*! \brief this is attribute about certain node */
+ NodeRef node;
+ /*! \brief the type key of the attribute */
+ std::string attr_key;
+ /*! \brief The attribute value, value is well defined at current scope. */
+ Expr value;
+ /*! \brief The body statement to be executed */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("node", &node);
+ v->Visit("attr_key", &attr_key);
+ v->Visit("value", &value);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Stmt make(NodeRef node,
+ std::string type_key,
+ Expr value,
+ Stmt body);
+
+ static constexpr const char* _type_key = "AttrStmt";
+ TVM_DECLARE_NODE_TYPE_INFO(AttrStmt, StmtNode);
+};
+
+/*!
+ * \brief Assert condition, if an error occurs, return the error message.
+ */
+class AssertStmt : public StmtNode {
+ public:
+ /*! \brief Condition to be checked. */
+ Expr condition;
+ /*! \brief Error message when assertion failed. */
+ Expr message;
+ /*!
+ * \brief Body which this assertion holds true.
+ * Will be executed after the assertion.
+ */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("condition", &condition);
+ v->Visit("message", &message);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
+
+ static constexpr const char* _type_key = "AssertStmt";
+ TVM_DECLARE_NODE_TYPE_INFO(AssertStmt, StmtNode);
+};
+
+// TODO(tvm-team): consider consolidate with AttrStmt.
+/*! \brief annotation node of producer/consumer relation. */
+class ProducerConsumer : public StmtNode {
+ public:
+ /*! \brief The corresponding tensor. */
+ FunctionRef func;
+ /*! \brief Whether the relation is producer. */
+ bool is_producer;
+ /*! \brief Body to be executed. */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("func", &func);
+ v->Visit("is_producer", &is_producer);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
+
+ static constexpr const char* _type_key = "ProducerConsumer";
+ TVM_DECLARE_NODE_TYPE_INFO(ProducerConsumer, StmtNode);
+};
+
+/*!
+ * \brief Store value to the buffer.
+ *
+ * Equivalent to ((DType*)buffer_var)[index] = value.
+ * where DType is the type specified by type().element_of().
+ *
+ * For example, if type = float32x3, then the load will corresponds to
+ *
+ * \code
+ *
+ * auto buffer = static_cast<float*>(buffer_var);
+ * buffer[index.v0] = value.v0;
+ * buffer[index.v1] = value.v1;
+ * buffer[index.v2] = value.v2;
+ *
+ * \endcode
+ * \sa Load
+ */
+class Store : public StmtNode {
+ public:
+ /*! \brief The buffer variable. */
+ Var buffer_var;
+ /*! \brief The value to be stored. */
+ Expr value;
+ /*! \brief The index locations to be stored. */
+ Expr index;
+ /*! \brief The predicate to mask which lanes would be stored. */
+ Expr predicate;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("buffer_var", &buffer_var);
+ v->Visit("value", &value);
+ v->Visit("index", &index);
+ v->Visit("predicate", &predicate);
+ }
+
+ TVM_DLL static Stmt make(Var buffer_var,
+ Expr value,
+ Expr index,
+ Expr predicate);
+
+ static constexpr const char* _type_key = "Store";
+ TVM_DECLARE_NODE_TYPE_INFO(Store, StmtNode);
+};
+
+/*!
+ * \brief Store value into mult-dimensional array defined by func.
+ */
+class Provide : public StmtNode {
+ public:
+ /*! \brief The function to be updated. */
+ FunctionRef func;
+ /*! \brief The output value index if func's value is a tuple. */
+ int value_index{0};
+ /*! \brief The value to be stored. */
+ Expr value;
+ /*! \brief The index arguments of the function. */
+ Array<Expr> args;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("func", &func);
+ v->Visit("value_index", &value_index);
+ v->Visit("value", &value);
+ v->Visit("args", &args);
+ }
+
+ TVM_DLL static Stmt make(FunctionRef func,
+ int value_index,
+ Expr value,
+ Array<Expr> args);
+
+ static constexpr const char* _type_key = "Provide";
+ TVM_DECLARE_NODE_TYPE_INFO(Provide, StmtNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class Allocate : public StmtNode {
+ public:
+ /*! \brief The buffer variable. */
+ Var buffer_var;
+ /*! \brief The type of the buffer. */
+ DataType type;
+ /*! \brief The extents of the buffer. */
+ Array<Expr> extents;
+ /*! \brief Only allocate buffer when condition is satisfied. */
+ Expr condition;
+ /*! \brief The body to be executed. */
+ Stmt body;
+ // The following two fields are deprecated
+ // kept for backward compatibility and will be refactored later.
+ Expr new_expr;
+ std::string free_function;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("buffer_var", &buffer_var);
+ v->Visit("dtype", &type);
+ v->Visit("extents", &extents);
+ v->Visit("condition", &condition);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Stmt make(Var buffer_var,
+ DataType type,
+ Array<Expr> extents,
+ Expr condition,
+ Stmt body,
+ Expr new_expr = Expr(),
+ std::string free_function = std::string());
+
+ /*!
+ * \brief If the buffer size is constant, return the size.
+ * Otherwise return 0.
+ * \return The result.
+ */
+ int32_t constant_allocation_size() const {
+ return constant_allocation_size(extents);
+ }
+ /*!
+ * \brief If the buffer size is constant, return the size.
+ * Otherwise return 0.
+ * \param extents The extents of the buffer.
+ * \return The result.
+ */
+ TVM_DLL static int32_t constant_allocation_size(
+ const Array<Expr>& extents);
+
+ static constexpr const char* _type_key = "Allocate";
+ TVM_DECLARE_NODE_TYPE_INFO(Allocate, StmtNode);
+};
+
+/*! \brief Free the resources in the buffer before the scope ends. */
+class Free : public StmtNode {
+ public:
+ /*! \brief The buffer variable. */
+ Var buffer_var;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("buffer_var", &buffer_var);
+ }
+
+ TVM_DLL static Stmt make(Var buffer_var);
+
+ static constexpr const char* _type_key = "Free";
+ TVM_DECLARE_NODE_TYPE_INFO(Free, StmtNode);
+};
+
+/*!
+ * \brief Annotate the bounds where func need to be written and read in body.
+ * We will need to allocate space for the corresponding regions.
+ */
+class Realize : public StmtNode {
+ public:
+ /*! \brief The function to be realized. */
+ FunctionRef func;
+ /*! \brief The output value index if func's value is a tuple. */
+ int value_index;
+ /*! \brief The data type of the array. */
+ DataType type;
+ /*! \brief Bounds to be realized. */
+ Region bounds;
+ /*! \brief Only realize if condition holds. */
+ Expr condition;
+ /*! \brief The body of realization. */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("func", &func);
+ v->Visit("value_index", &value_index);
+ v->Visit("dtype", &type);
+ v->Visit("bounds", &bounds);
+ v->Visit("condition", &condition);
+ v->Visit("body", &body);
+ }
+
+ TVM_DLL static Stmt make(FunctionRef func,
+ int value_index,
+ DataType type,
+ Region bounds,
+ Expr condition,
+ Stmt body);
+
+ static constexpr const char* _type_key = "Realize";
+ TVM_DECLARE_NODE_TYPE_INFO(Realize, StmtNode);
+};
+
+/*!
+ * \brief A sequence of statements.
+ */
+class Block : public StmtNode {
+ public:
+ /*! \brief The first statement. */
+ Stmt first;
+ /*! \brief The restof statments. */
+ Stmt rest;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("first", &first);
+ v->Visit("rest", &rest);
+ }
+
+ TVM_DLL static Stmt make(Stmt first, Stmt rest);
+ TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);
+
+ static constexpr const char* _type_key = "Block";
+ TVM_DECLARE_NODE_TYPE_INFO(Block, StmtNode);
+};
+
+/*!
+ * \brief IfThenElse statment.
+ */
+class IfThenElse : public StmtNode {
+ public:
+ /*! \brief The condition. */
+ Expr condition;
+ /*! \brief The branch to be executed when condition is true. */
+ Stmt then_case;
+ /*! \brief The branch to be executed when condition is false, can be null. */
+ Stmt else_case;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("condition", &condition);
+ v->Visit("then_case", &then_case);
+ v->Visit("else_case", &else_case);
+ }
+
+ TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
+
+ static constexpr const char* _type_key = "IfThenElse";
+ TVM_DECLARE_NODE_TYPE_INFO(IfThenElse, StmtNode);
+};
+
+/*!
+ * \brief Evaluates an expression.
+ * This is mostly used for putting a Call node into Stmt.
+ *
+ * If value do not have side-effect, this node can be safely removed.
+ */
+class Evaluate : public StmtNode {
+ public:
+ /*! \brief The expression to be evaluated. */
+ Expr value;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("value", &value);
+ }
+
+ TVM_DLL static Stmt make(Expr v);
+
+ static constexpr const char* _type_key = "Evaluate";
+ TVM_DECLARE_NODE_TYPE_INFO(Evaluate, StmtNode);
+};
+
+/*! \brief Additional annotation of for loop. */
+enum class ForType : int {
+ /*! \brief serial execution. */
+ Serial = 0,
+ /*! \brief parallel execution on CPU. */
+ Parallel = 1,
+ /*! \brief Vector SIMD loop annotaion. */
+ Vectorized = 2,
+ /*! \brief Unroll annotation. */
+ Unrolled = 3
+};
+
+// Kevice api of for loop
+// kept for backward compatibility
+// consider refactor and remove later.
+enum class DeviceAPI: int {
+ None = 0
+};
+
+/*!
+ * \brief A for loop, with poissible type annotations.
+ *
+ * \code
+ *
+ * for (loop_var = min; loop_var < min + extent; ++loop_var) {
+ * // body
+ * }
+ * \endcode
+ */
+class For : public StmtNode {
+ public:
+ /*! \brief The loop variable. */
+ Var loop_var;
+ /*! \brief The minimum value of iteration. */
+ Expr min;
+ /*! \brief The extent of the iteration. */
+ Expr extent;
+ /*! \brief The type of the for loop. */
+ ForType for_type;
+ /*!
+ * \brief Deprecated, reserved for backward compatibility.
+ * Consider refactor and remove later.
+ */
+ DeviceAPI device_api;
+ /*! \brief The body of the for loop. */
+ Stmt body;
+
+ TVM_DLL static Stmt make(Var loop_var,
+ Expr min,
+ Expr extent,
+ ForType for_type,
+ DeviceAPI device_api,
+ Stmt body);
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("loop_var", &loop_var);
+ v->Visit("min", &min);
+ v->Visit("extent", &extent);
+ v->Visit("for_type", &for_type);
+ v->Visit("device_api", &device_api);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "For";
+ TVM_DECLARE_NODE_TYPE_INFO(For, StmtNode);
+};
+
+/*!
+ * \brief A prefetch hint of func.
+ */
+class Prefetch : public StmtNode {
+ public:
+ /*! \brief The function to be prefetched. */
+ FunctionRef func;
+ /*! \brief The output value index if func's value is a tuple. */
+ int value_index;
+ /*! \brief The data type of the array. */
+ DataType type;
+ /*! \brief Bounds to be prefetched. */
+ Region bounds;
+
+ void VisitAttrs(AttrVisitor* v) final {
+ v->Visit("func", &func);
+ v->Visit("value_index", &value_index);
+ v->Visit("type", &type);
+ v->Visit("bounds", &bounds);
+ }
+
+ TVM_DLL static Stmt make(FunctionRef func,
+ int value_index,
+ DataType type,
+ Region bounds);
+
+ static constexpr const char* _type_key = "Prefetch";
+ TVM_DECLARE_NODE_TYPE_INFO(Prefetch, StmtNode);
};
/*!
} // namespace intrinsic
-// Reuse IR node defintiion from HalideIR
-using HalideIR::Internal::IntImm;
-using HalideIR::Internal::UIntImm;
-using HalideIR::Internal::FloatImm;
-using HalideIR::Internal::StringImm;
-using HalideIR::Internal::Cast;
-using HalideIR::Internal::Add;
-using HalideIR::Internal::Sub;
-using HalideIR::Internal::Mul;
-using HalideIR::Internal::Div;
-using HalideIR::Internal::Mod;
-using HalideIR::Internal::Min;
-using HalideIR::Internal::Max;
-using HalideIR::Internal::EQ;
-using HalideIR::Internal::NE;
-using HalideIR::Internal::LT;
-using HalideIR::Internal::LE;
-using HalideIR::Internal::GT;
-using HalideIR::Internal::GE;
-using HalideIR::Internal::And;
-using HalideIR::Internal::Or;
-using HalideIR::Internal::Not;
-using HalideIR::Internal::Select;
-using HalideIR::Internal::Load;
-using HalideIR::Internal::Ramp;
-using HalideIR::Internal::Broadcast;
-using HalideIR::Internal::Call;
-using HalideIR::Internal::Let;
-using HalideIR::Internal::LetStmt;
-using HalideIR::Internal::AttrStmt;
-using HalideIR::Internal::AssertStmt;
-using HalideIR::Internal::ProducerConsumer;
-using HalideIR::Internal::For;
-using HalideIR::Internal::Store;
-using HalideIR::Internal::Provide;
-using HalideIR::Internal::Allocate;
-using HalideIR::Internal::Free;
-using HalideIR::Internal::Realize;
-using HalideIR::Internal::Prefetch;
-using HalideIR::Internal::Block;
-using HalideIR::Internal::IfThenElse;
-using HalideIR::Internal::Evaluate;
-using HalideIR::Internal::Shuffle;
-
/*!
* \brief Create a type annotation expression
* \param dtype The data type
"type_annotation", {},
ir::Call::PureIntrinsic);
}
+
+// overload printing of for type.
+TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
+
} // namespace ir
} // namespace tvm
#define TVM_IR_MUTATOR_H_
#include <unordered_map>
+#include <utility>
#include "expr.h"
#include "ir.h"
#include "tvm/node/ir_functor.h"
#ifndef TVM_LOWERED_FUNC_H_
#define TVM_LOWERED_FUNC_H_
-#include <ir/FunctionBase.h>
#include <string>
#include "base.h"
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
-class LoweredFunc : public FunctionRef {
+class LoweredFunc : public ir::FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(NodePtr<Node> n) : FunctionRef(n) {}
};
/*! \brief Node container of LoweredFunc */
-class LoweredFuncNode : public FunctionBaseNode {
+class LoweredFuncNode : public ir::FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/container.h
+ * \brief Array/Map container in the DSL graph.
+ */
+#ifndef TVM_NODE_CONTAINER_H_
+#define TVM_NODE_CONTAINER_H_
+
+#include <type_traits>
+#include <vector>
+#include <initializer_list>
+#include <unordered_map>
+#include <utility>
+#include <string>
+#include "node.h"
+#include "memory.h"
+
+namespace tvm {
+
+/*! \brief array node content in array */
+class ArrayNode : public Node {
+ public:
+ /*! \brief the data content */
+ std::vector<NodePtr<Node> > data;
+
+ void VisitAttrs(AttrVisitor* visitor) final {
+ // Visitor to array have no effect.
+ }
+
+ static constexpr const char* _type_key = "Array";
+ TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node);
+};
+
+/*! \brief map node content */
+class MapNode : public Node {
+ public:
+ void VisitAttrs(AttrVisitor* visitor) final {
+ // Visitor to map have no effect.
+ }
+ // hash function
+ struct Hash {
+ size_t operator()(const NodePtr<Node>& n) const {
+ return std::hash<Node*>()(n.get());
+ }
+ };
+ // comparator
+ struct Equal {
+ bool operator()(
+ const NodePtr<Node>& a,
+ const NodePtr<Node>& b) const {
+ return a.get() == b.get();
+ }
+ };
+
+ /*! \brief The corresponding conatiner type */
+ using ContainerType = std::unordered_map<
+ NodePtr<Node>,
+ NodePtr<Node>,
+ Hash, Equal>;
+
+ /*! \brief the data content */
+ ContainerType data;
+
+ static constexpr const char* _type_key = "Map";
+ TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node);
+};
+
+
+/*! \brief specialized map node with string as key */
+class StrMapNode : public Node {
+ public:
+ void VisitAttrs(AttrVisitor* visitor) final {
+ // Visitor to map have no effect.
+ }
+ /*! \brief The corresponding conatiner type */
+ using ContainerType = std::unordered_map<
+ std::string,
+ NodePtr<Node> >;
+
+ /*! \brief the data content */
+ ContainerType data;
+
+ static constexpr const char* _type_key = "StrMap";
+ TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node);
+};
+
+/*!
+ * \brief iterator adapter that adapts TIter to return another type.
+ * \tparam Converter a struct that contains converting function
+ * \tparam TIter the content iterator type.
+ */
+template<typename Converter,
+ typename TIter>
+class IterAdapter {
+ public:
+ explicit IterAdapter(TIter iter) : iter_(iter) {}
+ inline IterAdapter& operator++() { // NOLINT(*)
+ ++iter_;
+ return *this;
+ }
+ inline IterAdapter& operator++(int) { // NOLINT(*)
+ ++iter_;
+ return *this;
+ }
+ inline IterAdapter operator+(int offset) const { // NOLINT(*)
+ return IterAdapter(iter_ + offset);
+ }
+ inline bool operator==(IterAdapter other) const {
+ return iter_ == other.iter_;
+ }
+ inline bool operator!=(IterAdapter other) const {
+ return !(*this == other);
+ }
+ inline const typename Converter::ResultType operator*() const {
+ return Converter::convert(*iter_);
+ }
+
+ private:
+ TIter iter_;
+};
+
+/*!
+ * \brief Array container of NodeRef in DSL graph.
+ * Array implements copy on write semantics, which means array is mutable
+ * but copy will happen when array is referenced in more than two places.
+ *
+ * operator[] only provide const acces, use Set to mutate the content.
+ * \tparam T The content NodeRef type.
+ */
+template<typename T,
+ typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
+class Array : public NodeRef {
+ public:
+ /*!
+ * \brief default constructor
+ */
+ Array() {
+ node_ = make_node<ArrayNode>();
+ }
+ /*!
+ * \brief move constructor
+ * \param other source
+ */
+ Array(Array<T> && other) { // NOLINT(*)
+ node_ = std::move(other.node_);
+ }
+ /*!
+ * \brief copy constructor
+ * \param other source
+ */
+ Array(const Array<T> &other) : NodeRef(other.node_) { // NOLINT(*)
+ }
+ /*!
+ * \brief constructor from pointer
+ * \param n the container pointer
+ */
+ explicit Array(NodePtr<Node> n) : NodeRef(n) {}
+ /*!
+ * \brief constructor from iterator
+ * \param begin begin of iterator
+ * \param end end of iterator
+ * \tparam IterType The type of iterator
+ */
+ template<typename IterType>
+ Array(IterType begin, IterType end) {
+ assign(begin, end);
+ }
+ /*!
+ * \brief constructor from initializer list
+ * \param init The initalizer list
+ */
+ Array(std::initializer_list<T> init) { // NOLINT(*)
+ assign(init.begin(), init.end());
+ }
+ /*!
+ * \brief constructor from vector
+ * \param init The vector
+ */
+ Array(const std::vector<T>& init) { // NOLINT(*)
+ assign(init.begin(), init.end());
+ }
+ /*!
+ * \brief Constructs a container with n elements. Each element is a copy of val
+ * \param n The size of the container
+ * \param val The init value
+ */
+ explicit Array(size_t n, const T& val) {
+ auto tmp_node = make_node<ArrayNode>();
+ for (size_t i = 0; i < n; ++i) {
+ tmp_node->data.push_back(val.node_);
+ }
+ node_ = std::move(tmp_node);
+ }
+ /*!
+ * \brief move assign operator
+ * \param other The source of assignment
+ * \return reference to self.
+ */
+ Array<T>& operator=(Array<T> && other) {
+ node_ = std::move(other.node_);
+ return *this;
+ }
+ /*!
+ * \brief copy assign operator
+ * \param other The source of assignment
+ * \return reference to self.
+ */
+ Array<T>& operator=(const Array<T> & other) {
+ node_ = other.node_;
+ return *this;
+ }
+ /*!
+ * \brief reset the array to content from iterator.
+ * \param begin begin of iterator
+ * \param end end of iterator
+ * \tparam IterType The type of iterator
+ */
+ template<typename IterType>
+ void assign(IterType begin, IterType end) {
+ auto n = make_node<ArrayNode>();
+ for (IterType it = begin; it != end; ++it) {
+ n->data.push_back((*it).node_);
+ }
+ node_ = std::move(n);
+ }
+ /*!
+ * \brief Read i-th element from array.
+ * \param i The index
+ * \return the i-th element.
+ */
+ inline const T operator[](size_t i) const {
+ return T(static_cast<const ArrayNode*>(node_.get())->data[i]);
+ }
+ /*! \return The size of the array */
+ inline size_t size() const {
+ if (node_.get() == nullptr) return 0;
+ return static_cast<const ArrayNode*>(node_.get())->data.size();
+ }
+ /*!
+ * \brief copy on write semantics
+ * Do nothing if current handle is the unique copy of the array.
+ * Otherwise make a new copy of the array to ensure the current handle
+ * hold a unique copy.
+ *
+ * \return Handle to the internal node container(which ganrantees to be unique)
+ */
+ inline ArrayNode* CopyOnWrite() {
+ if (node_.get() == nullptr || !node_.unique()) {
+ NodePtr<ArrayNode> n = make_node<ArrayNode>();
+ n->data = static_cast<ArrayNode*>(node_.get())->data;
+ NodePtr<Node>(std::move(n)).swap(node_);
+ }
+ return static_cast<ArrayNode*>(node_.get());
+ }
+ /*!
+ * \brief push a new item to the back of the list
+ * \param item The item to be pushed.
+ */
+ inline void push_back(const T& item) {
+ ArrayNode* n = this->CopyOnWrite();
+ n->data.push_back(item.node_);
+ }
+ /*!
+ * \brief set i-th element of the array.
+ * \param i The index
+ * \param value The value to be setted.
+ */
+ inline void Set(size_t i, const T& value) {
+ ArrayNode* n = this->CopyOnWrite();
+ n->data[i] = value.node_;
+ }
+ /*! \return whether array is empty */
+ inline bool empty() const {
+ return size() == 0;
+ }
+ /*! \brief specify container node */
+ using ContainerType = ArrayNode;
+
+ struct Ptr2NodeRef {
+ using ResultType = T;
+ static inline T convert(const NodePtr<Node>& n) {
+ return T(n);
+ }
+ };
+ using iterator = IterAdapter<Ptr2NodeRef,
+ std::vector<NodePtr<Node> >::const_iterator>;
+
+ using reverse_iterator = IterAdapter<
+ Ptr2NodeRef,
+ std::vector<NodePtr<Node> >::const_reverse_iterator>;
+
+ /*! \return begin iterator */
+ inline iterator begin() const {
+ return iterator(static_cast<const ArrayNode*>(node_.get())->data.begin());
+ }
+ /*! \return end iterator */
+ inline iterator end() const {
+ return iterator(static_cast<const ArrayNode*>(node_.get())->data.end());
+ }
+ /*! \return rbegin iterator */
+ inline reverse_iterator rbegin() const {
+ return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rbegin());
+ }
+ /*! \return rend iterator */
+ inline reverse_iterator rend() const {
+ return reverse_iterator(static_cast<const ArrayNode*>(node_.get())->data.rend());
+ }
+};
+
+/*!
+ * \brief Map container of NodeRef->NodeRef in DSL graph.
+ * Map implements copy on write semantics, which means map is mutable
+ * but copy will happen when array is referenced in more than two places.
+ *
+ * operator[] only provide const acces, use Set to mutate the content.
+ * \tparam K The key NodeRef type.
+ * \tparam V The value NodeRef type.
+ */
+template<typename K,
+ typename V,
+ typename = typename std::enable_if<
+ std::is_base_of<NodeRef, K>::value ||
+ std::is_base_of<std::string, K>::value >::type,
+ typename = typename std::enable_if<std::is_base_of<NodeRef, V>::value>::type>
+class Map : public NodeRef {
+ public:
+ /*!
+ * \brief default constructor
+ */
+ Map() {
+ node_ = make_node<MapNode>();
+ }
+ /*!
+ * \brief move constructor
+ * \param other source
+ */
+ Map(Map<K, V> && other) { // NOLINT(*)
+ node_ = std::move(other.node_);
+ }
+ /*!
+ * \brief copy constructor
+ * \param other source
+ */
+ Map(const Map<K, V> &other) : NodeRef(other.node_) { // NOLINT(*)
+ }
+ /*!
+ * \brief constructor from pointer
+ * \param n the container pointer
+ */
+ explicit Map(NodePtr<Node> n) : NodeRef(n) {}
+ /*!
+ * \brief constructor from iterator
+ * \param begin begin of iterator
+ * \param end end of iterator
+ * \tparam IterType The type of iterator
+ */
+ template<typename IterType>
+ Map(IterType begin, IterType end) {
+ assign(begin, end);
+ }
+ /*!
+ * \brief constructor from initializer list
+ * \param init The initalizer list
+ */
+ Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
+ assign(init.begin(), init.end());
+ }
+ /*!
+ * \brief constructor from vector
+ * \param init The vector
+ */
+ template<typename Hash, typename Equal>
+ Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
+ assign(init.begin(), init.end());
+ }
+ /*!
+ * \brief move assign operator
+ * \param other The source of assignment
+ * \return reference to self.
+ */
+ Map<K, V>& operator=(Map<K, V> && other) {
+ node_ = std::move(other.node_);
+ return *this;
+ }
+ /*!
+ * \brief copy assign operator
+ * \param other The source of assignment
+ * \return reference to self.
+ */
+ Map<K, V>& operator=(const Map<K, V> & other) {
+ node_ = other.node_;
+ return *this;
+ }
+ /*!
+ * \brief reset the array to content from iterator.
+ * \param begin begin of iterator
+ * \param end end of iterator
+ * \tparam IterType The type of iterator
+ */
+ template<typename IterType>
+ void assign(IterType begin, IterType end) {
+ NodePtr<MapNode> n = make_node<MapNode>();
+ for (IterType i = begin; i != end; ++i) {
+ n->data.emplace(std::make_pair(i->first.node_,
+ i->second.node_));
+ }
+ node_ = std::move(n);
+ }
+ /*!
+ * \brief Read element from map.
+ * \param key The key
+ * \return the corresonding element.
+ */
+ inline const V operator[](const K& key) const {
+ return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
+ }
+ /*!
+ * \brief Read element from map.
+ * \param key The key
+ * \return the corresonding element.
+ */
+ inline const V at(const K& key) const {
+ return V(static_cast<const MapNode*>(node_.get())->data.at(key.node_));
+ }
+ /*! \return The size of the array */
+ inline size_t size() const {
+ if (node_.get() == nullptr) return 0;
+ return static_cast<const MapNode*>(node_.get())->data.size();
+ }
+ /*! \return The number of elements of the key */
+ inline size_t count(const K& key) const {
+ if (node_.get() == nullptr) return 0;
+ return static_cast<const MapNode*>(node_.get())->data.count(key.node_);
+ }
+ /*!
+ * \brief copy on write semantics
+ * Do nothing if current handle is the unique copy of the array.
+ * Otherwise make a new copy of the array to ensure the current handle
+ * hold a unique copy.
+ *
+ * \return Handle to the internal node container(which ganrantees to be unique)
+ */
+ inline MapNode* CopyOnWrite() {
+ if (node_.get() == nullptr || !node_.unique()) {
+ NodePtr<MapNode> n = make_node<MapNode>();
+ n->data = static_cast<const MapNode*>(node_.get())->data;
+ NodePtr<Node>(std::move(n)).swap(node_);
+ }
+ return static_cast<MapNode*>(node_.get());
+ }
+ /*!
+ * \brief set the Map.
+ * \param key The index key.
+ * \param value The value to be setted.
+ */
+ inline void Set(const K& key, const V& value) {
+ MapNode* n = this->CopyOnWrite();
+ n->data[key.node_] = value.node_;
+ }
+
+ /*! \return whether array is empty */
+ inline bool empty() const {
+ return size() == 0;
+ }
+ /*! \brief specify container node */
+ using ContainerType = MapNode;
+
+ struct Ptr2NodeRef {
+ using ResultType = std::pair<K, V>;
+ static inline ResultType convert(const std::pair<
+ NodePtr<Node>,
+ NodePtr<Node> >& n) {
+ return std::make_pair(K(n.first), V(n.second));
+ }
+ };
+
+ using iterator = IterAdapter<
+ Ptr2NodeRef, MapNode::ContainerType::const_iterator>;
+
+ /*! \return begin iterator */
+ inline iterator begin() const {
+ return iterator(static_cast<const MapNode*>(node_.get())->data.begin());
+ }
+ /*! \return end iterator */
+ inline iterator end() const {
+ return iterator(static_cast<const MapNode*>(node_.get())->data.end());
+ }
+ /*! \return begin iterator */
+ inline iterator find(const K& key) const {
+ return iterator(static_cast<const MapNode*>(node_.get())->data.find(key.node_));
+ }
+};
+
+// specialize of string map
+template<typename V, typename T1, typename T2>
+class Map<std::string, V, T1, T2> : public NodeRef {
+ public:
+ // for code reuse
+ Map() {
+ node_ = make_node<StrMapNode>();
+ }
+ Map(Map<std::string, V> && other) { // NOLINT(*)
+ node_ = std::move(other.node_);
+ }
+ Map(const Map<std::string, V> &other) : NodeRef(other.node_) { // NOLINT(*)
+ }
+ explicit Map(NodePtr<Node> n) : NodeRef(n) {}
+ template<typename IterType>
+ Map(IterType begin, IterType end) {
+ assign(begin, end);
+ }
+ Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
+ assign(init.begin(), init.end());
+ }
+
+ template<typename Hash, typename Equal>
+ Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
+ assign(init.begin(), init.end());
+ }
+ Map<std::string, V>& operator=(Map<std::string, V> && other) {
+ node_ = std::move(other.node_);
+ return *this;
+ }
+ Map<std::string, V>& operator=(const Map<std::string, V> & other) {
+ node_ = other.node_;
+ return *this;
+ }
+ template<typename IterType>
+ void assign(IterType begin, IterType end) {
+ auto n = make_node<StrMapNode>();
+ for (IterType i = begin; i != end; ++i) {
+ n->data.emplace(std::make_pair(i->first,
+ i->second.node_));
+ }
+ node_ = std::move(n);
+ }
+ inline const V operator[](const std::string& key) const {
+ return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
+ }
+ inline const V at(const std::string& key) const {
+ return V(static_cast<const StrMapNode*>(node_.get())->data.at(key));
+ }
+ inline size_t size() const {
+ if (node_.get() == nullptr) return 0;
+ return static_cast<const StrMapNode*>(node_.get())->data.size();
+ }
+ inline size_t count(const std::string& key) const {
+ if (node_.get() == nullptr) return 0;
+ return static_cast<const StrMapNode*>(node_.get())->data.count(key);
+ }
+ inline StrMapNode* CopyOnWrite() {
+ if (node_.get() == nullptr || !node_.unique()) {
+ NodePtr<StrMapNode> n = make_node<StrMapNode>();
+ n->data = static_cast<const StrMapNode*>(node_.get())->data;
+ NodePtr<Node>(std::move(n)).swap(node_);
+ }
+ return static_cast<StrMapNode*>(node_.get());
+ }
+ inline void Set(const std::string& key, const V& value) {
+ StrMapNode* n = this->CopyOnWrite();
+ n->data[key] = value.node_;
+ }
+ inline bool empty() const {
+ return size() == 0;
+ }
+ using ContainerType = StrMapNode;
+
+ struct Ptr2NodeRef {
+ using ResultType = std::pair<std::string, V>;
+ static inline ResultType convert(const std::pair<
+ std::string,
+ NodePtr<Node> >& n) {
+ return std::make_pair(n.first, V(n.second));
+ }
+ };
+
+ using iterator = IterAdapter<
+ Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>;
+
+ /*! \return begin iterator */
+ inline iterator begin() const {
+ return iterator(static_cast<const StrMapNode*>(node_.get())->data.begin());
+ }
+ /*! \return end iterator */
+ inline iterator end() const {
+ return iterator(static_cast<const StrMapNode*>(node_.get())->data.end());
+ }
+ /*! \return begin iterator */
+ inline iterator find(const std::string& key) const {
+ return iterator(static_cast<const StrMapNode*>(node_.get())->data.find(key));
+ }
+};
+
+} // namespace tvm
+#endif // TVM_NODE_CONTAINER_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/ir_functor.h
+ * \brief Defines the IRFunctor data structures.
+ */
+#ifndef TVM_NODE_IR_FUNCTOR_H_
+#define TVM_NODE_IR_FUNCTOR_H_
+
+#include <dmlc/logging.h>
+#include <string>
+#include <vector>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <functional>
+#include "node.h"
+
+namespace tvm {
+/*!
+ * \brief A dynamically dispatched functor on NodeRef in the first argument.
+ *
+ * \code
+ * IRFunctor<std::string (const NodeRef& n, std::string prefix)> tostr;
+ * tostr.set_dispatch<Add>([](const Add* op, std::string prefix) {
+ * return prefix + "Add";
+ * });
+ * tostr.set_dispatch<IntImm>([](const IntImm* op) {
+ * return prefix + "IntImm"
+ * });
+ *
+ * Expr x = make_const(1);
+ * Expr y = x + x;
+ * // dispatch to IntImm, outputs "MyIntImm"
+ * LOG(INFO) << tostr(x, "My");
+ * // dispatch to IntImm, outputs "MyAdd"
+ * LOG(INFO) << tostr(y, "My");
+ * \endcode
+ *
+ * \tparam FType function signiture
+ * This type if only defined for FType with function signature
+ */
+template<typename FType>
+class IRFunctor;
+
+template<typename R, typename ...Args>
+class IRFunctor<R(const NodeRef& n, Args...)> {
+ private:
+ using Function = std::function<R (const NodeRef&n, Args...)>;
+ using TSelf = IRFunctor<R (const NodeRef& n, Args...)>;
+ /*! \brief internal function table */
+ std::vector<Function> func_;
+
+ public:
+ /*! \brief the result type of this functor */
+ using result_type = R;
+ /*!
+ * \brief Whether the functor can dispatch the corresponding Node
+ * \param n The node to be dispatched
+ * \return Whether dispatching function is registered for n's type.
+ */
+ inline bool can_dispatch(const NodeRef& n) const {
+ uint32_t type_index = n.type_index();
+ return type_index < func_.size() && func_[type_index] != nullptr;
+ }
+ /*!
+ * \brief invoke the functor , dispatch on type of n
+ * \param n The Node argument
+ * \param args The additional arguments
+ * \return The result.
+ */
+ inline R operator()(const NodeRef& n, Args... args) const {
+ uint32_t type_index = n.type_index();
+ CHECK(type_index < func_.size() &&
+ func_[type_index] != nullptr)
+ << "IRFunctor calls un-registered function on type "
+ << Node::TypeIndex2Key(type_index);
+ return func_[type_index](n, std::forward<Args>(args)...);
+ }
+ /*!
+ * \brief set the dispacher for type TNode
+ * \param f The function to be set.
+ * \tparam TNode the type of Node to be dispatched.
+ * \return reference to self.
+ */
+ template<typename TNode>
+ inline TSelf& set_dispatch(Function f) { // NOLINT(*)
+ uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
+ if (func_.size() <= tindex) {
+ func_.resize(tindex + 1, nullptr);
+ }
+ CHECK(func_[tindex] == nullptr)
+ << "Dispatch for " << Node::TypeIndex2Key(tindex)
+ << " is already set";
+ func_[tindex] = f;
+ return *this;
+ }
+ /*!
+ * \brief set the dispacher for type TNode
+ * This allows f to used detailed const Node pointer to replace NodeRef
+ *
+ * \param f The function to be set.
+ * \tparam TNode the type of Node to be dispatched.
+ * \return reference to self.
+ */
+ template<typename TNode>
+ inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
+ Function fun = [f](const NodeRef& n, Args... args) {
+ return f(static_cast<const TNode*>(n.node_.get()),
+ std::forward<Args>(args)...);
+ };
+ return this->set_dispatch<TNode>(fun);
+ }
+ /*!
+ * \brief unset the dispacher for type TNode
+ *
+ * \tparam TNode the type of Node to be dispatched.
+ * \return reference to self.
+ */
+ template<typename TNode>
+ inline TSelf& clear_dispatch() { // NOLINT(*)
+ uint32_t tindex = Node::TypeKey2Index(TNode::_type_key);
+ CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
+ func_[tindex] = nullptr;
+ return *this;
+ }
+};
+
+#if defined(__GNUC__)
+#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
+#else
+#define TVM_ATTRIBUTE_UNUSED
+#endif
+
+/*! \brief helper macro to generate string concat */
+#define TVM_STR_CONCAT_(__x, __y) __x##__y
+#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
+
+#define TVM_REGISTER_VAR_DEF(ClsName) \
+ static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName
+
+/*!
+ * \brief Useful macro to set IRFunctor dispatch in a global static field.
+ *
+ * \code
+ * // Use IRFunctor to implement IRPrinter similar to Visitor Pattern.
+ * // vtable allows easy patch in of new Node types, without changing
+ * // interface of IRPrinter.
+ *
+ * class IRPrinter {
+ * public:
+ * std::ostream& stream;
+ * // the dispatch function.
+ * void print(Expr e) {
+ * const static FType& f = *vtable();
+ * f(e, this);
+ * }
+ *
+ * using FType = IRFunctor<void (const NodeRef&, IRPrinter *)>;
+ * // function to return global function table
+ * static FType& vtable();
+ * };
+ *
+ * // in cpp/cc file
+ * IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
+ * static FType inst; return inst;
+ * }
+ *
+ * TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+ * .set_dispatch<Add>([](const Add* n, IRPrinter* p) {
+ * p->print(n->a);
+ * p->stream << '+'
+ * p->print(n->b);
+ * });
+ *
+ *
+ * \endcode
+ *
+ * \param ClsName The name of the class
+ * \param FField The static function that returns a singleton of IRFunctor.
+ */
+#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
+ TVM_STR_CONCAT(TVM_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
+ ClsName::FField()
+
+ /*!
+ * \brief A container for a list of callbacks. All callbacks are invoked when
+ * the object is destructed.
+ */
+class IRFunctorCleanList {
+ public:
+ ~IRFunctorCleanList() {
+ for (auto &f : clean_items) {
+ f();
+ }
+ }
+
+ void append(std::function<void()> func) {
+ clean_items.push_back(func);
+ }
+
+ private:
+ std::vector< std::function<void()> > clean_items;
+};
+
+/*!
+* \brief A wrapper around IRFunctor that will record calls to set_dispatch
+* and make a corresponding call to clear_dispatch when the last copy of
+* the IRFunctorStaticRegistry is destructed. When assigned to a static variable,
+* this can be used by NNVM and other libraries to unregister callbacks when
+* the library is unloaded. This prevents crashes when the underlying IRFunctor
+* is destructed as it will no longer contain std::function instances allocated
+* by a library that has been unloaded.
+*/
+template<typename FType>
+class IRFunctorStaticRegistry;
+
+template<typename R, typename ...Args>
+class IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> {
+ private:
+ IRFunctor<R(const NodeRef& n, Args...)> *irf_;
+ std::shared_ptr<IRFunctorCleanList> free_list;
+
+ using TSelf = IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>;
+
+ public:
+ IRFunctorStaticRegistry(IRFunctor<R(const NodeRef& n, Args...)> *irf) {
+ irf_ = irf;
+ free_list = std::make_shared<IRFunctorCleanList>();
+ }
+
+ template<typename TNode>
+ inline TSelf& set_dispatch(std::function<R(const TNode* n, Args...)> f) { // NOLINT(*)
+ irf_->template set_dispatch<TNode>(f);
+ auto irf_copy = irf_;
+ free_list.get()->append([irf_copy] {
+ irf_copy->template clear_dispatch<TNode>();
+ });
+ return *this;
+ }
+};
+
+/*!
+* \brief Helper function for constructing an IRFunctorStaticRegistry. This allows
+* the compiler to deduce the template types.
+*/
+template<typename R, typename ...Args>
+IRFunctorStaticRegistry<R(const NodeRef& n, Args...)> MakeIRFunctorStaticRegistry(
+ IRFunctor<R(const NodeRef& n, Args...)> *irf) {
+ return IRFunctorStaticRegistry<R(const NodeRef& n, Args...)>(irf);
+}
+
+#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \
+ static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName
+
+/*!
+* \brief Macro to set IRFunctor dispatch in a global static field using an IRFunctorStaticRegistry.
+* Usage is exactly the same as TVM_STATIC_IR_FUNCTOR. Libraries should use this instead of
+* TVM_STATIC_IR_FUNCTOR.
+*/
+#define TVM_STATIC_IR_FUNCTOR_REGISTER(ClsName, FField) \
+ TVM_STR_CONCAT(TVM_AUTO_REGISTER_VAR_DEF(ClsName), __COUNTER__) = \
+ MakeIRFunctorStaticRegistry(&ClsName::FField())
+
+} // namespace tvm
+#endif // TVM_NODE_IR_FUNCTOR_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/memory.h
+ * \brief Node memory management.
+ */
+#ifndef TVM_NODE_MEMORY_H_
+#define TVM_NODE_MEMORY_H_
+
+#include <utility>
+#include "node.h"
+
+namespace tvm {
+/*!
+ * \brief Allocate a node object.
+ * \param args arguments to the constructor.
+ * \tparam T the node type.
+ * \return The NodePtr to the allocated object.
+ */
+template<typename T, typename... Args>
+inline NodePtr<T> make_node(Args&&... args);
+
+// Detail implementations after this
+//
+// The current design allows swapping the
+// allocator pattern when necessary.
+//
+// Possible future allocator optimizations:
+// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
+// - Thread-local object pools: one pool per size and alignment requirement.
+// - Can specialize by type of object to give the specific allocator to each object.
+//
+template<typename T>
+class SimpleNodeAllocator {
+ public:
+ template<typename... Args>
+ static T* New(Args&&... args) {
+ return new T(std::forward<Args>(args)...);
+ }
+ static NodeBase::FDeleter Deleter() {
+ return Deleter_;
+ }
+
+ private:
+ static void Deleter_(NodeBase* ptr) {
+ delete static_cast<T*>(ptr);
+ }
+};
+
+template<typename T, typename... Args>
+inline NodePtr<T> make_node(Args&&... args) {
+ using Allocator = SimpleNodeAllocator<T>;
+ static_assert(std::is_base_of<NodeBase, T>::value,
+ "make_node can only be used to create NodeBase");
+ T* node = Allocator::New(std::forward<Args>(args)...);
+ node->deleter_ = Allocator::Deleter();
+ return NodePtr<T>(node);
+}
+
+} // namespace tvm
+#endif // TVM_NODE_MEMORY_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/node.h
+ * \brief Node system data structure.
+ */
+#ifndef TVM_NODE_NODE_H_
+#define TVM_NODE_NODE_H_
+
+#include <dmlc/logging.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/node_base.h>
+#include <string>
+#include <vector>
+#include <utility>
+#include <type_traits>
+
+
+namespace tvm {
+// forward declaration
+class DataType;
+class Node;
+class NodeRef;
+
+namespace runtime {
+// forward declaration
+class NDArray;
+// forward declaration
+class Object;
+} // namespace runtime
+
+/*!
+ * \brief Visitor class to each node content.
+ * The content is going to be called for each field.
+ */
+class TVM_DLL AttrVisitor {
+ public:
+//! \cond Doxygen_Suppress
+ virtual ~AttrVisitor() = default;
+ virtual void Visit(const char* key, double* value) = 0;
+ virtual void Visit(const char* key, int64_t* value) = 0;
+ virtual void Visit(const char* key, uint64_t* value) = 0;
+ virtual void Visit(const char* key, int* value) = 0;
+ virtual void Visit(const char* key, bool* value) = 0;
+ virtual void Visit(const char* key, std::string* value) = 0;
+ virtual void Visit(const char* key, void** value) = 0;
+ virtual void Visit(const char* key, DataType* value) = 0;
+ virtual void Visit(const char* key, NodeRef* value) = 0;
+ virtual void Visit(const char* key, runtime::NDArray* value) = 0;
+ virtual void Visit(const char* key, runtime::Object* value) = 0;
+ template<typename ENum,
+ typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+ void Visit(const char* key, ENum* ptr) {
+ static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
+ "declare enum to be enum int to use visitor");
+ this->Visit(key, reinterpret_cast<int*>(ptr));
+ }
+//! \endcond
+};
+
+/*!
+ * \brief base class of node container in DSL AST.
+ */
+class TVM_DLL Node : public NodeBase {
+ public:
+ /*! \brief virtual destructor */
+ virtual ~Node() {}
+ /*! \return The unique type key of the node */
+ virtual const char* type_key() const = 0;
+ /*!
+ * \brief Apply visitor to each field of the Node
+ * Visitor could mutate the content of the node.
+ * override if Node contains attribute fields.
+ * \param visitor The visitor
+ */
+ virtual void VisitAttrs(AttrVisitor* visitor) {}
+ /*! \return the type index of the node */
+ virtual const uint32_t type_index() const = 0;
+ /*!
+ * \brief Whether this node derives from node with type_index=tid.
+ * Implemented by TVM_DECLARE_NODE_TYPE_INFO
+ *
+ * \param tid The type index.
+ * \return the check result.
+ */
+ virtual const bool _DerivedFrom(uint32_t tid) const;
+ /*!
+ * \brief get a runtime unique type index given a type key
+ * \param type_key Type key of a type.
+ * \return the corresponding type index.
+ */
+ static uint32_t TypeKey2Index(const char* type_key);
+ /*!
+ * \brief get type key from type index.
+ * \param index The type index
+ * \return the corresponding type key.
+ */
+ static const char* TypeIndex2Key(uint32_t index);
+ /*!
+ * \return whether the type is derived from
+ */
+ template<typename T>
+ inline bool derived_from() const;
+ /*!
+ * \return whether the node is of type T
+ * \tparam The type to be checked.
+ */
+ template<typename T>
+ inline bool is_type() const;
+ /*!
+ * \brief Get a NodePtr that holds reference to this Node.
+ * \return the NodePtr
+ */
+ inline NodePtr<Node> GetNodePtr() const;
+ // node ref can see this
+ friend class NodeRef;
+ static constexpr const char* _type_key = "Node";
+};
+
+/*! \brief Base class of all node reference object */
+class NodeRef {
+ public:
+ /*! \brief type indicate the container type */
+ using ContainerType = Node;
+ /*!
+ * \brief Comparator
+ * \param other Another node ref.
+ * \return the compare result.
+ */
+ inline bool operator==(const NodeRef& other) const;
+ /*!
+ * \brief Comparator
+ * \param other Another node ref.
+ * \return the compare result.
+ */
+ inline bool same_as(const NodeRef& other) const;
+ /*!
+ * \brief Comparator
+ * \param other Another node ref.
+ * \return the compare result.
+ */
+ inline bool operator<(const NodeRef& other) const;
+ /*!
+ * \brief Comparator
+ * \param other Another node ref.
+ * \return the compare result.
+ */
+ inline bool operator!=(const NodeRef& other) const;
+ /*! \return the hash function for NodeRef */
+ inline size_t hash() const;
+ /*! \return whether the expression is null */
+ inline bool defined() const;
+ /*! \return the internal type index of IRNode */
+ inline uint32_t type_index() const;
+ /*! \return the internal node pointer */
+ inline const Node* get() const;
+ /*! \return the internal node pointer */
+ inline const Node* operator->() const;
+ /*!
+ * \brief Downcast this ir node to its actual type (e.g. Add, or
+ * Select). This returns nullptr if the node is not of the requested
+ * type. Example usage:
+ *
+ * if (const Add *add = node->as<Add>()) {
+ * // This is an add node
+ * }
+ * \tparam T the target type, must be subtype of IRNode
+ */
+ template<typename T>
+ inline const T *as() const;
+ /*!
+ * \brief A more powerful version of as that also works with
+ * intermediate base types.
+ * \tparam T the target type, must be subtype of IRNode
+ */
+ template<typename T>
+ inline const T *as_derived() const;
+ /*! \brief default constructor */
+ NodeRef() = default;
+ explicit NodeRef(NodePtr<Node> node) : node_(node) {}
+ /*! \brief the internal node object, do not touch */
+ NodePtr<Node> node_;
+};
+
+/*!
+ * \brief Get a reference type from a Node ptr type
+ *
+ * It is always important to get a reference type
+ * if we want to return a value as reference or keep
+ * the node alive beyond the scope of the function.
+ *
+ * \param ptr The node pointer
+ * \tparam RefType The reference type
+ * \tparam NodeType The node type
+ * \return The corresponding RefType
+ */
+template <typename RefType, typename NodeType>
+inline RefType GetRef(const NodeType* ptr);
+
+/*!
+ * \brief Downcast a base reference type to a more specific type.
+ *
+ * \param ref The inptut reference
+ * \return The corresponding SubRef.
+ * \tparam SubRef The target specific reference type.
+ * \tparam BaseRef the current reference type.
+ */
+template <typename SubRef, typename BaseRef>
+inline SubRef Downcast(BaseRef ref);
+
+/*!
+ * \brief helper macro to declare type information in a base node.
+ */
+#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
+ const bool _DerivedFrom(uint32_t tid) const override { \
+ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
+ if (tidx == tid) return true; \
+ return Parent::_DerivedFrom(tid); \
+ }
+
+/*!
+ * \brief helper macro to declare type information in a terminal node
+ */
+#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \
+ const char* type_key() const final { \
+ return TypeName::_type_key; \
+ } \
+ const uint32_t type_index() const final { \
+ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
+ return tidx; \
+ } \
+ const bool _DerivedFrom(uint32_t tid) const final { \
+ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
+ if (tidx == tid) return true; \
+ return Parent::_DerivedFrom(tid); \
+ }
+
+// implementations of inline functions after this
+template<typename T>
+inline bool Node::derived_from() const {
+ // use static field so query only happens once.
+ static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
+ return this->_DerivedFrom(type_id);
+}
+
+
+template<typename T>
+inline bool Node::is_type() const {
+ // use static field so query only happens once.
+ static uint32_t type_id = Node::TypeKey2Index(T::_type_key);
+ return type_id == this->type_index();
+}
+
+
+inline NodePtr<Node> Node::GetNodePtr() const {
+ return NodePtr<Node>(const_cast<Node*>(this));
+}
+
+template <typename RefType, typename NodeType>
+inline RefType GetRef(const NodeType* ptr) {
+ static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
+ "Can only cast to the ref of same container type");
+ return RefType(ptr->GetNodePtr());
+}
+
+template <typename SubRef, typename BaseRef>
+inline SubRef Downcast(BaseRef ref) {
+ CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
+ ref->template derived_from<typename SubRef::ContainerType>())
+ << "Downcast from " << ref->type_key() << " to "
+ << SubRef::ContainerType::_type_key << " failed.";
+ return SubRef(std::move(ref.node_));
+}
+
+inline const Node* NodeRef::get() const {
+ return node_.get();
+}
+
+inline const Node* NodeRef::operator->() const {
+ return node_.get();
+}
+
+inline bool NodeRef::defined() const {
+ return node_.get() != nullptr;
+}
+
+inline bool NodeRef::operator==(const NodeRef& other) const {
+ return node_.get() == other.node_.get();
+}
+
+inline bool NodeRef::same_as(const NodeRef& other) const {
+ return node_.get() == other.node_.get();
+}
+
+inline bool NodeRef::operator<(const NodeRef& other) const {
+ return node_.get() < other.node_.get();
+}
+
+inline bool NodeRef::operator!=(const NodeRef& other) const {
+ return node_.get() != other.node_.get();
+}
+
+inline size_t NodeRef::hash() const {
+ return std::hash<Node*>()(node_.get());
+}
+
+inline uint32_t NodeRef::type_index() const {
+ CHECK(node_.get() != nullptr)
+ << "null type";
+ return get()->type_index();
+}
+
+template<typename T>
+inline const T* NodeRef::as() const {
+ const Node* ptr = static_cast<const Node*>(get());
+ if (ptr && ptr->is_type<T>()) {
+ return static_cast<const T*>(ptr);
+ }
+ return nullptr;
+}
+
+template<typename T>
+inline const T* NodeRef::as_derived() const {
+ const Node* ptr = static_cast<const Node*>(get());
+ if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
+ return static_cast<const T*>(ptr);
+ }
+ return nullptr;
+}
+
+/*! \brief The hash function for nodes */
+struct NodeHash {
+ size_t operator()(const NodeRef& a) const {
+ return a.hash();
+ }
+};
+
+/*! \brief The equal comparator for nodes */
+struct NodeEqual {
+ bool operator()(const NodeRef& a, const NodeRef& b) const {
+ return a.get() == b.get();
+ }
+};
+} // namespace tvm
+#endif // TVM_NODE_NODE_H_
/*!
* \brief Base class of all operation nodes
*/
-class OperationNode : public FunctionBaseNode {
+class OperationNode : public ir::FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
- EXPORT static Operation make(std::string name,
+ TVM_DLL static Operation make(std::string name,
std::string tag,
Map<std::string, NodeRef> attrs,
Array<Tensor> inputs,
v->Visit("axis", &axis);
v->Visit("body", &body);
}
- EXPORT static Operation make(std::string name,
- std::string tag,
- Map<std::string, NodeRef> attrs,
- Array<Tensor> inputs,
- Array<Tensor> outputs,
- Stmt body);
+ TVM_DLL static Operation make(std::string name,
+ std::string tag,
+ Map<std::string, NodeRef> attrs,
+ Array<Tensor> inputs,
+ Array<Tensor> outputs,
+ Stmt body);
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_NODE_TYPE_INFO(HybridOpNode, OperationNode);
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
- if (!NodeTypeChecker<T>::Check(p.get())) return false;
+ if (!NodeTypeChecker<T>::Check(p.get())) {
+ return false;
+ }
}
return true;
}
return TNodeRef(sptr);
}
-inline TVMArgValue::operator HalideIR::Expr() const {
+inline TVMArgValue::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
}
// type related stuffs
-inline TVMRetValue& TVMRetValue::operator=(const HalideIR::Type& t) {
- return this->operator=(Type2TVMType(t));
+inline TVMRetValue& TVMRetValue::operator=(const DataType& t) {
+ return this->operator=(t.operator DLDataType());
}
-inline TVMRetValue::operator HalideIR::Type() const {
- return TVMType2Type(operator TVMType());
+inline TVMRetValue::operator tvm::DataType() const {
+ return DataType(operator DLDataType());
}
-inline TVMArgValue::operator HalideIR::Type() const {
- return TVMType2Type(operator TVMType());
+inline TVMArgValue::operator tvm::DataType() const {
+ return DataType(operator DLDataType());
}
inline void TVMArgsSetter::operator()(
- size_t i, const HalideIR::Type& t) const {
- this->operator()(i, Type2TVMType(t));
+ size_t i, const DataType& t) const {
+ this->operator()(i, t.operator DLDataType());
}
} // namespace runtime
} // namespace tvm
#include "object.h"
#include "node_base.h"
-namespace HalideIR {
-// Forward declare type for extensions
-// The header works fine without depending on this.
-struct Type;
-struct Expr;
-}
-
-
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
namespace tvm {
// forward declarations
class Integer;
+class DataType;
+class Expr;
namespace runtime {
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline bool IsNodeType() const;
- inline operator HalideIR::Type() const;
- inline operator HalideIR::Expr() const;
+ inline operator tvm::DataType() const;
+ inline operator tvm::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr();
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const NodePtr<Node>& other);
// type related
- inline operator HalideIR::Type() const;
- inline TVMRetValue& operator=(const HalideIR::Type& other);
+ inline operator tvm::DataType() const;
+ inline TVMRetValue& operator=(const tvm::DataType& other);
private:
template<typename T>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
- inline void operator()(size_t i, const HalideIR::Type& t) const;
+ inline void operator()(size_t i, const tvm::DataType& t) const;
private:
/*! \brief The values fields */
* \brief set the memory scope of the stage
* \param scope The memory scope.
*/
- EXPORT Stage& set_scope(std::string scope); // NOLINT(*)
+ TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
/*!
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
- EXPORT Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
+ TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline.
* \return reference to self.
*/
- EXPORT Stage& compute_inline(); // NOLINT(*)
+ TVM_DLL Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at group root.
* \return reference to self.
*/
- EXPORT Stage& compute_root(); // NOLINT(*)
+ TVM_DLL Stage& compute_root(); // NOLINT(*)
/*!
* \brief Bind the IterVar to thread index.
*
* \param thread_ivar The thread axis to be bound.
* \return reference to self.
*/
- EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar);
+ TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
/*!
* \brief Set the predicate to determine whether a store to the array should be performed.
* Use this when there are multiple threads performing the same store and we only
* \param predicate The condition to be checked.
* \return reference to self.
*/
- EXPORT Stage& set_store_predicate(Expr predicate);
+ TVM_DLL Stage& set_store_predicate(Expr predicate);
/*!
* \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage.
* This is a beta feature.
* \return reference to self.
*/
- EXPORT Stage& env_threads(Array<IterVar> threads);
+ TVM_DLL Stage& env_threads(Array<IterVar> threads);
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
* \param p_inner The result inner domain.
* \return reference to self.
*/
- EXPORT Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Split the iteration with given number of parts.
*
* \param p_inner The result inner domain.
* \return reference to self.
*/
- EXPORT Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
* \param p_target The result target domain.
* \return reference to self.
*/
- EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
+ TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Fuse all the axes together into a single axis.
*
*
* \return reference to self.
*/
- EXPORT Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
+ TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
*/
- EXPORT Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
+ TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
/*!
* \brief Perform tiling on two dimensions
* The final loop order from outmost to inner most are
* \param p_y_inner Inner axis of y dimension
* \return reference to self.
*/
- EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
+ TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
* \param var The axis to be vectorized.
* \return reference to self.
*/
- EXPORT Stage& vectorize(IterVar var); // NOLINT(*)
+ TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Replace computation of the current stage by tensor intrinsic f.
* \param var The axis marks beginning of tensorization.
* \param f The Tensor compute intrinsics.
* \return reference to self.
*/
- EXPORT Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
+ TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
*/
- EXPORT Stage& unroll(IterVar var); // NOLINT(*)
+ TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
- EXPORT Stage& parallel(IterVar var); // NOLINT(*)
+ TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief Annotate the iteration with pragma
*
*
* \return reference to self.
*/
- EXPORT Stage& pragma(IterVar var,
+ TVM_DLL Stage& pragma(IterVar var,
const std::string& pragma_type,
const Expr& pragma_value = Expr()); // NOLINT(*)
/*!
* \param offset the number of iterations be to fetched in advance
* \return reference to self
*/
- EXPORT Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
+ TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
/*!
* \brief Set alignment requirement for specific dimension.
*
* \param offset The required offset factor.
* \return reference to self
*/
- EXPORT Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
+ TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief Compute current stage with double buffering.
* \return reference to self.
*/
- EXPORT Stage& double_buffer(); // NOLINT(*)
+ TVM_DLL Stage& double_buffer(); // NOLINT(*)
/*!
* \brief Schedule for OpenGL fragment shader.
* \return reference to self.
* \brief Get the stage corresponds to the op
* \param op The operation.
*/
- EXPORT Stage operator[](const Operation& op);
+ TVM_DLL Stage operator[](const Operation& op);
/*!
* \brief Short hand for getting the stage of tensor's operation.
* \param tensor The tensor
* \return The stage corresponding to the tensor's op
*/
- EXPORT Stage operator[](const Tensor& tensor) {
+ TVM_DLL Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \param include_inputs Whether include inputs if they are reachable from outputs.
* \return The new grouped stage.
*/
- EXPORT Stage create_group(const Array<Tensor>& outputs,
+ TVM_DLL Stage create_group(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs = false);
/*!
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
- EXPORT Tensor cache_read(const Tensor& tensor,
+ TVM_DLL Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
* \param scope The scope of the storage.
* \return The created tensor.
*/
- EXPORT Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+ TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
/*!
* \brief Create a cache write tensor for producing tensor.
* The the tensor will take over body of original tensor op.
* \param scope The scope of the storage.
* \return The created tensor.
*/
- EXPORT Tensor cache_write(const Tensor& tensor, const std::string& scope);
+ TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generated the new tensor with axis
* \param factor_axis The position where the new axis is placed.
* \return The created factored tensors.
*/
- EXPORT Array<Tensor> rfactor(const Tensor& tensor,
+ TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis,
int factor_axis = 0);
/*!
* \param op The candidate Operation.
* \return true if the schedule has the Operation. Otherwise, false.
*/
- EXPORT bool Contain(const Operation& op) const;
+ TVM_DLL bool Contain(const Operation& op) const;
/*!
* \brief Check if the schedule contains a Tensor.
* \param tensor The candidate tensor.
* \return true if the schedule has the tensor. Otherwise, false.
*/
- EXPORT bool Contain(const Tensor& tensor) const {
+ TVM_DLL bool Contain(const Tensor& tensor) const {
return Contain(tensor->op);
}
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
- EXPORT static Schedule make(Array<Operation> ops);
+ TVM_DLL static Schedule make(Array<Operation> ops);
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
*
* \param sch The schedule to be inlined.
*/
-EXPORT void AutoInlineInjective(Schedule sch);
+TVM_DLL void AutoInlineInjective(Schedule sch);
} // namespace schedule
} // namespace tvm
#ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_
-#include <ir/FunctionBase.h>
#include <tvm/node/container.h>
#include <string>
#include <vector>
// internal node container for Operation
class OperationNode;
-using HalideIR::IR::FunctionRef;
-
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
};
/*! \brief Operation that produces tensors */
-class Operation : public FunctionRef {
+class Operation : public ir::FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
TVM_REGISTER_API("make.For")
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
VarExpr loop_var, Expr min, Expr extent,
- int for_type, int device_api, Stmt body
-) {
+ int for_type, int device_api, Stmt body) {
return For::make(loop_var,
- min,
- extent,
- static_cast<ForType>(for_type),
- static_cast<HalideIR::DeviceAPI>(device_api),
- body);
+ min,
+ extent,
+ static_cast<ForType>(for_type),
+ static_cast<DeviceAPI>(device_api),
+ body);
});
TVM_REGISTER_API("make.Load")
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2016 by Contributors
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/
namespace tvm {
TVM_REGISTER_API("_min_value")
-.set_body_method(&Type::min);
+.set_body_method(&DataType::min);
TVM_REGISTER_API("_max_value")
-.set_body_method(&Type::max);
+.set_body_method(&DataType::max);
TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// overrides
void VisitAttrs(tvm::AttrVisitor* v) final {
}
- void accept(HalideIR::Internal::IRVisitor* v, const Expr& e) const final {
- LOG(FATAL) << "not supported";
- }
- IRNodeType type_info() const final {
- return IRNodeType::ExtensionExpr;
- }
static constexpr const char* _type_key = "arith.CanonicalExpr";
TVM_DECLARE_BASE_NODE_INFO(CanonicalExprNode, BaseExprNode);
// Override visitor behaviors
Entry VisitExprDefault_(const Node* op) final {
return Everything(
- static_cast<const ir::BaseExprNode*>(op)->type);
+ static_cast<const ExprNode*>(op)->type);
}
Entry VisitExpr(const Expr& expr) final {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
switch (t.code()) {
- case halideir_type_int:
+ case kDLInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
os << "int";
break;
- case halideir_type_uint:
+ case kDLUInt:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
os << "uint";
break;
- case halideir_type_float:
+ case kDLFloat:
CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
os << "float";
break;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/expr_operator.h>
-#include <ir/IRPrinter.h>
#include <memory>
+#include <limits>
namespace tvm {
-using HalideIR::IR::RangeNode;
+// maximum and min values
+Expr DataType::max() const {
+ using namespace ir;
+ CHECK_EQ(lanes(), 1);
+ if (is_int()) {
+ if (bits() == 64) {
+ return IntImm::make(*this, std::numeric_limits<int64_t>::max());
+ } else if (bits() < 64) {
+ int64_t val = 1;
+ val = (val << (bits() - 1)) - 1;
+ return IntImm::make(*this, val);
+ }
+ } else if (is_uint()) {
+ if (bits() == 64) {
+ return UIntImm::make(*this, std::numeric_limits<uint64_t>::max());
+ } else if (bits() < 64) {
+ uint64_t val = 1;
+ val = (val << static_cast<uint64_t>(bits())) - 1;
+ return UIntImm::make(*this, val);
+ }
+ } else if (is_float()) {
+ if (bits() == 64) {
+ return FloatImm::make(*this, std::numeric_limits<double>::max());
+ } else if (bits() == 32) {
+ return FloatImm::make(*this, std::numeric_limits<float>::max());
+ } else if (bits() == 16) {
+ return FloatImm::make(*this, 65504.0);
+ }
+ }
+ LOG(FATAL) << "Cannot decide max_value for type" << *this;
+ return Expr();
+}
+
+Expr DataType::min() const {
+ using namespace ir;
+ CHECK_EQ(lanes(), 1);
+ if (is_int()) {
+ if (bits() == 64) {
+ return IntImm::make(*this, std::numeric_limits<int64_t>::lowest());
+ } else if (bits() < 64) {
+ int64_t val = 1;
+ val = -(val << (bits() - 1));
+ return IntImm::make(*this, val);
+ }
+ } else if (is_uint()) {
+ return UIntImm::make(*this, 0);
+ } else if (is_float()) {
+ if (bits() == 64) {
+ return FloatImm::make(*this, std::numeric_limits<double>::lowest());
+ } else if (bits() == 32) {
+ return FloatImm::make(*this, std::numeric_limits<float>::lowest());
+ } else if (bits() == 16) {
+ return FloatImm::make(*this, -65504.0);
+ }
+ }
+ LOG(FATAL) << "Cannot decide min_value for type" << *this;
+ return Expr();
+}
+
+Expr::Expr(int32_t value)
+ : Expr(IntImm::make(Int(32), value)) {}
+
+Expr::Expr(float value)
+ : Expr(ir::FloatImm::make(Float(32), value)) {}
+
+Expr::Expr(std::string str)
+ : Expr(ir::StringImm::make(str)) {}
+
+Var::Var(std::string name_hint, DataType t)
+ : Var(Variable::make(t, name_hint)) {}
+
+Var Variable::make(DataType t, std::string name_hint) {
+ NodePtr<Variable> node = make_node<Variable>();
+ node->type = t;
+ node->name_hint = std::move(name_hint);
+ return Var(node);
+}
Range::Range(Expr begin, Expr end)
: Range(make_node<RangeNode>(
is_zero(begin) ? end : (end - begin))) {
}
+Integer IntImm::make(Type t, int64_t value) {
+ CHECK(t.is_int() && t.is_scalar())
+ << "ValueError: IntImm can only take scalar.";
+ NodePtr<IntImm> node = make_node<IntImm>();
+ node->type = t;
+ node->value = value;
+ return Integer(node);
+}
+
Range Range::make_by_min_extent(Expr min, Expr extent) {
- return Range(make_node<HalideIR::IR::RangeNode>(min, extent));
+ return Range(make_node<RangeNode>(min, extent));
}
-IterVar IterVarNode::make(Range dom, Var var,
- IterVarType t, std::string thread_tag) {
+IterVar IterVarNode::make(Range dom,
+ Var var,
+ IterVarType t,
+ std::string thread_tag) {
NodePtr<IterVarNode> n = make_node<IterVarNode>();
n->dom = dom;
n->var = var;
dom, Var(name), kCommReduce);
}
-std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
- IRPrinter(os).print(n);
- return os;
-}
-
void Dump(const NodeRef& n) {
std::cerr << n << "\n";
}
-Var var(const std::string& name_hint, Type t) {
+Var var(std::string name_hint, Type t) {
return Var(name_hint, t);
}
+void IRPrinter::Print(const NodeRef& ir) {
+ static const FType& f = vtable();
+ if (!ir.defined()) {
+ stream << "(nullptr)";
+ } else {
+ if (f.can_dispatch(ir)) {
+ f(ir, this);
+ } else {
+ // default value, output type key and addr.
+ stream << ir->type_key() << "(" << ir.get() << ")";
+ }
+ }
+}
+
+void IRPrinter::PrintIndent() {
+ for (int i = 0; i < indent; ++i) {
+ stream << ' ';
+ }
+}
+
+IRPrinter::FType& IRPrinter::vtable() {
+ static FType inst;
+ return inst;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<IntImm>([](const IntImm *op, IRPrinter *p) {
+ if (op->type == Int(32)) {
+ p->stream << op->value;
+ } else {
+ p->stream << "(" << op->type << ")" << op->value;
+ }
+ });
+
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
p->stream << "iter_var(";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<RangeNode>([](const HalideIR::IR::RangeNode *op, IRPrinter *p) {
+.set_dispatch<RangeNode>([](const RangeNode* op, IRPrinter* p) {
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
-
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
*/
/*!
- * Copyright (c) 2016 by Contributors
* \file ir.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
-#include <ir/IR.h>
-#include <ir/IRPrinter.h>
#include <memory>
#include "../pass/ir_util.h"
-namespace HalideIR {
-namespace Internal {
+namespace tvm {
+namespace ir {
-using tvm::ir::CommReducerNode;
-using tvm::ir::Reduce;
-using tvm::ir::Any;
-using tvm::ir::AttrStmt;
+// constructors
+Expr UIntImm::make(DataType t, uint64_t value) {
+ CHECK(t.is_uint() && t.lanes() == 1)
+ << "ValueError: UIntImm can only take scalar";
+ NodePtr<UIntImm> node = make_node<UIntImm>();
+ node->type = t;
+ node->value = value;
+ return Expr(node);
+}
-template<>
-void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
- LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
+Expr FloatImm::make(DataType t, double value) {
+ CHECK_EQ(t.lanes(), 1)
+ << "ValueError: FloatImm can only take scalar";
+ NodePtr<FloatImm> node = make_node<FloatImm>();
+ node->type = t;
+ node->value = value;
+ return Expr(node);
}
-template<>
-void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
- LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
+Expr StringImm::make(std::string value) {
+ NodePtr<StringImm> node = make_node<StringImm>();
+ node->type = Handle();
+ node->value = std::move(value);
+ return Expr(node);
}
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
- p->stream << "?";
-});
+Expr Cast::make(DataType t, Expr value) {
+ CHECK(value.defined());
+ CHECK_EQ(t.lanes(), value.type().lanes());
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
- p->stream << "reduce(combiner="
- << op->combiner;
- p->stream << ", source=" << op->source;
- p->stream << ", axis=" << op->axis;
- p->stream << ", where=" << op->condition;
- p->stream << ", value_index=" << op->value_index;
- p->stream << ")";
-});
+ NodePtr<Cast> node = make_node<Cast>();
+ node->type = t;
+ node->value = std::move(value);
+ return Expr(node);
+}
-TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) {
- p->stream << "comm_reducer(result=" << op->result
- << ", lhs=" << op->lhs
- << ", rhs=" << op->rhs
- << ", identity_element=" << op->identity_element
- << ")";
-});
-} // namespace Internal
-} // namespace HalideIR
-namespace tvm {
-namespace ir {
+Expr And::make(Expr a, Expr b) {
+ CHECK(a.defined()) << "ValueError: a is undefined";
+ CHECK(b.defined()) << "ValueError: b is undefined";
+ CHECK(a.type().is_bool());
+ CHECK(b.type().is_bool());
+ CHECK(a.type() == b.type()) << "TypeError: mismatched types";
+
+ NodePtr<And> node = make_node<And>();
+ node->type = Bool(a.type().lanes());
+ node->a = std::move(a);
+ node->b = std::move(b);
+ return Expr(node);
+}
+
+Expr Or::make(Expr a, Expr b) {
+ CHECK(a.defined()) << "ValueError: a is undefined";
+ CHECK(b.defined()) << "ValueError: b is undefined";
+ CHECK(a.type().is_bool());
+ CHECK(b.type().is_bool());
+ CHECK(a.type() == b.type()) << "TypeError: mismatched types";
+
+ NodePtr<Or> node = make_node<Or>();
+ node->type = Bool(a.type().lanes());
+ node->a = std::move(a);
+ node->b = std::move(b);
+ return Expr(node);
+}
+
+Expr Not::make(Expr a) {
+ CHECK(a.defined()) << "ValueError: a is undefined";
+ CHECK(a.type().is_bool());
+
+ NodePtr<Not> node = make_node<Not>();
+ node->type = Bool(a.type().lanes());
+ node->a = std::move(a);
+ return Expr(node);
+}
+
+Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
+ CHECK(condition.defined()) << "ValueError: condition is undefined";
+ CHECK(true_value.defined()) << "ValueError: true_value is undefined";
+ CHECK(false_value.defined()) << "ValueError: true_value is undefined";
+ CHECK(condition.type().is_bool());
+ CHECK_EQ(condition.type().lanes(), true_value.type().lanes());
+ CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types";
+
+ NodePtr<Select> node = make_node<Select>();
+ node->type = true_value.type();
+ node->condition = std::move(condition);
+ node->true_value = std::move(true_value);
+ node->false_value = std::move(false_value);
+ return Expr(node);
+}
+
+Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) {
+ CHECK(buffer_var.defined());
+ CHECK(predicate.defined());
+ CHECK(index.defined());
+ CHECK_EQ(type.lanes(), index.type().lanes());
+ CHECK_EQ(type.lanes(), predicate.type().lanes());
+
+ NodePtr<Load> node = make_node<Load>();
+ node->type = type;
+ node->buffer_var = std::move(buffer_var);
+ node->index = std::move(index);
+ node->predicate = std::move(predicate);
+
+ return Expr(node);
+}
+
+Expr Ramp::make(Expr base, Expr stride, int lanes) {
+ CHECK(base.defined());
+ CHECK(stride.defined());
+ CHECK(base.type().is_scalar());
+ CHECK(stride.type().is_scalar());
+ CHECK_GT(lanes, 1);
+ CHECK_EQ(stride.type(), base.type());
+
+ NodePtr<Ramp> node = make_node<Ramp>();
+ node->type = base.type().with_lanes(lanes);
+ node->base = base;
+ node->stride = stride;
+ node->lanes = lanes;
+ return Expr(node);
+}
+
+Expr Broadcast::make(Expr value, int lanes) {
+ CHECK(value.defined());
+ CHECK(value.type().is_scalar());
+ CHECK_GT(lanes, 1);
+
+ NodePtr<Broadcast> node = make_node<Broadcast>();
+ node->type = value.type().with_lanes(lanes);
+ node->value = std::move(value);
+ node->lanes = lanes;
+ return Expr(node);
+}
+
+Expr Let::make(Var var, Expr value, Expr body) {
+ CHECK(value.defined());
+ CHECK(body.defined());
+ CHECK_EQ(value.type(), var.type());
+
+ NodePtr<Let> node = make_node<Let>();
+ node->type = body.type();
+ node->var = std::move(var);
+ node->value = std::move(value);
+ node->body = std::move(body);
+ return Expr(node);
+}
+
+Expr Call::make(DataType type,
+ std::string name,
+ Array<Expr> args,
+ CallType call_type,
+ FunctionRef func,
+ int value_index) {
+ for (size_t i = 0; i < args.size(); ++i) {
+ CHECK(args[i].defined());
+ }
+
+ if (call_type == Halide) {
+ for (size_t i = 0; i < args.size(); ++i) {
+ CHECK(args[i].type().is_int());
+ }
+ }
+
+ NodePtr<Call> node = make_node<Call>();
+ node->type = type;
+ node->name = std::move(name);
+ node->args = std::move(args);
+ node->call_type = call_type;
+ node->func = std::move(func);
+ node->value_index = value_index;
+ return Expr(node);
+}
+
+Expr Shuffle::make(Array<Expr> vectors,
+ Array<Expr> indices) {
+ CHECK_NE(vectors.size(), 0U);
+ CHECK_NE(indices.size(), 0U);
+
+ Type base_type = vectors[0].type().element_of();
+ int total_lanes = 0;
+
+ for (Expr val : vectors) {
+ CHECK(val.type().element_of() == base_type);
+ total_lanes += val.type().lanes();
+ }
+ CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
+
+ NodePtr<Shuffle> node = make_node<Shuffle>();
+ node->type = base_type.with_lanes(static_cast<int>(indices.size()));
+ node->vectors = std::move(vectors);
+ node->indices = std::move(indices);
+ return Expr(node);
+}
+
+Expr Shuffle::make_concat(Array<Expr> vectors) {
+ CHECK_NE(vectors.size(), 0);
+ if (vectors.size() == 1) {
+ return vectors[0];
+ }
+ Array<Expr> indices;
+ int index = 0;
+ for (const Expr& e : vectors) {
+ for (int i = 0; i < e.type().lanes(); ++i) {
+ indices.push_back(IntImm::make(Int(32), index++));
+ }
+ }
+ return make(vectors, indices);
+}
+
+Expr Shuffle::make_extract_element(Expr vector, int index) {
+ return make({vector}, {Integer(index)});
+}
CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
return Expr(n);
}
+Stmt LetStmt::make(Var var, Expr value, Stmt body) {
+ CHECK(value.defined());
+ CHECK(body.defined());
+ CHECK_EQ(value.type(), var.type());
+
+ NodePtr<LetStmt> node = make_node<LetStmt>();
+ node->var = std::move(var);
+ node->value = std::move(value);
+ node->body = std::move(body);
+ return Stmt(node);
+}
+
+Stmt AttrStmt::make(NodeRef node,
+ std::string attr_key,
+ Expr value,
+ Stmt body) {
+ auto n = make_node<AttrStmt>();
+ n->node = node;
+ n->attr_key = std::move(attr_key);
+ n->value = std::move(value);
+ n->body = std::move(body);
+ return Stmt(n);
+}
+
+Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
+ CHECK(condition.defined());
+ CHECK(message.type() == Int(32) ||
+ message.as<StringImm>())
+ << "TypeError: AssertStmt message must be an int or string:"
+ << message << "\n";
+
+ NodePtr<AssertStmt> node = make_node<AssertStmt>();
+ node->condition = std::move(condition);
+ node->message = std::move(message);
+ node->body = std::move(body);
+ return Stmt(node);
+}
+
+Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
+ CHECK(body.defined());
+
+ NodePtr<ProducerConsumer> node = make_node<ProducerConsumer>();
+ node->func = std::move(func);
+ node->is_producer = is_producer;
+ node->body = std::move(body);
+ return Stmt(node);
+}
+
+Stmt For::make(Var loop_var,
+ Expr min,
+ Expr extent,
+ ForType for_type,
+ DeviceAPI device_api,
+ Stmt body) {
+ CHECK(min.defined());
+ CHECK(extent.defined());
+ CHECK(min.type().is_scalar());
+ CHECK(extent.type().is_scalar());
+ CHECK(loop_var.type().is_scalar());
+ CHECK(body.defined());
+
+ NodePtr<For> node = make_node<For>();
+ node->loop_var = std::move(loop_var);
+ node->min = std::move(min);
+ node->extent = std::move(extent);
+ node->for_type = for_type;
+ node->device_api = device_api;
+ node->body = std::move(body);
+ return Stmt(node);
+}
+
+Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
+ CHECK(value.defined());
+ CHECK(index.defined());
+ CHECK(predicate.defined());
+ CHECK_EQ(value.type().lanes(), index.type().lanes());
+ CHECK_EQ(value.type().lanes(), predicate.type().lanes());
+
+ NodePtr<Store> node = make_node<Store>();
+ node->buffer_var = std::move(buffer_var);
+ node->value = std::move(value);
+ node->index = std::move(index);
+ node->predicate = std::move(predicate);
+ return Stmt(node);
+}
+
+Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
+ CHECK(value_index >=0 && value_index < func->num_outputs())
+ << "value index output function return value bound";
+ CHECK(value.defined()) << "Provide of undefined value\n";
+
+ for (size_t i = 0; i < args.size(); ++i) {
+ CHECK(args[i].defined()) << "Provide to undefined location\n";
+ }
+
+ NodePtr<Provide> node = make_node<Provide>();
+ node->func = std::move(func);
+ node->value_index = value_index;
+ node->value = std::move(value);
+ node->args = std::move(args);
+ return Stmt(node);
+}
+
+Stmt Allocate::make(Var buffer_var,
+ DataType type,
+ Array<Expr> extents,
+ Expr condition,
+ Stmt body,
+ Expr new_expr,
+ std::string free_function) {
+ for (size_t i = 0; i < extents.size(); ++i) {
+ CHECK(extents[i].defined());
+ CHECK(extents[i].type().is_scalar());
+ }
+ CHECK(body.defined());
+ CHECK(condition.defined());
+ CHECK(condition.type().is_bool());
+
+ NodePtr<Allocate> node = make_node<Allocate>();
+ node->buffer_var = std::move(buffer_var);
+ node->type = type;
+ node->extents = std::move(extents);
+ node->condition = std::move(condition);
+ node->body = std::move(body);
+ node->new_expr = std::move(new_expr);
+ node->free_function = std::move(free_function);
+ return Stmt(node);
+}
+
+int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
+ int64_t result = 1;
+ for (size_t i = 0; i < extents.size(); ++i) {
+ if (const IntImm *int_size = extents[i].as<IntImm>()) {
+ result *= int_size->value;
+ if (result > std::numeric_limits<int32_t>::max()) {
+ return 0;
+ }
+ } else {
+ return 0;
+ }
+ }
+ return static_cast<int32_t>(result);
+}
+
+Stmt Free::make(Var buffer_var) {
+ NodePtr<Free> node = make_node<Free>();
+ node->buffer_var = buffer_var;
+ return Stmt(node);
+}
+
+Stmt Realize::make(FunctionRef func,
+ int value_index,
+ DataType type,
+ Region bounds,
+ Expr condition,
+ Stmt body) {
+ for (size_t i = 0; i < bounds.size(); ++i) {
+ CHECK(bounds[i]->min.defined());
+ CHECK(bounds[i]->extent.defined());
+ CHECK(bounds[i]->min.type().is_scalar());
+ CHECK(bounds[i]->extent.type().is_scalar());
+ }
+ CHECK(body.defined());
+ CHECK(condition.defined());
+ CHECK(condition.type().is_bool());
+
+ NodePtr<Realize> node = make_node<Realize>();
+ node->func = std::move(func);
+ node->value_index = value_index;
+ node->type = type;
+ node->bounds = std::move(bounds);
+ node->condition = std::move(condition);
+ node->body = std::move(body);
+ return Stmt(node);
+}
+
+Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) {
+ for (size_t i = 0; i < bounds.size(); ++i) {
+ CHECK(bounds[i]->min.defined());
+ CHECK(bounds[i]->extent.defined());
+ CHECK(bounds[i]->min.type().is_scalar());
+ CHECK(bounds[i]->extent.type().is_scalar());
+ }
+
+ NodePtr<Prefetch> node = make_node<Prefetch>();
+ node->func = std::move(func);
+ node->value_index = value_index;
+ node->type = type;
+ node->bounds = std::move(bounds);
+ return Stmt(node);
+}
+
+Stmt Block::make(Stmt first, Stmt rest) {
+ CHECK(first.defined());
+ CHECK(rest.defined());
+ NodePtr<Block> node = make_node<Block>();
+
+ // canonicalize.
+ if (const Block* b = first.as<Block>()) {
+ node->first = b->first;
+ node->rest = Block::make(b->rest, rest);
+ } else {
+ node->first = std::move(first);
+ node->rest = std::move(rest);
+ }
+ return Stmt(node);
+}
+
+Stmt Block::make(const std::vector<Stmt>& stmts) {
+ if (stmts.empty()) {
+ return Stmt();
+ }
+ Stmt result = stmts.back();
+ for (size_t i = stmts.size() - 1; i != 0; --i) {
+ result = Block::make(stmts[i - 1], result);
+ }
+ return result;
+}
+
+Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
+ CHECK(condition.defined());
+ CHECK(then_case.defined());
+ // else_case may be null.
+
+ NodePtr<IfThenElse> node = make_node<IfThenElse>();
+ node->condition = std::move(condition);
+ node->then_case = std::move(then_case);
+ node->else_case = std::move(else_case);
+ return Stmt(node);
+}
+
+Stmt Evaluate::make(Expr value) {
+ CHECK(value.defined());
+
+ NodePtr<Evaluate> node = make_node<Evaluate>();
+ node->value = std::move(value);
+ return Stmt(node);
+}
+
+// Printers
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<UIntImm>([](const UIntImm* op, IRPrinter* p) {
+ p->stream << "(" << op->type << ")" << op->value;
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<FloatImm>([](const FloatImm* op, IRPrinter* p) {
+ auto& stream = p->stream;
+ switch (op->type.bits()) {
+ case 64:
+ stream << op->value;
+ break;
+ case 32:
+ stream << op->value << 'f';
+ break;
+ case 16:
+ stream << op->value << 'h';
+ break;
+ default:
+ LOG(FATAL) << "Unknown float type bits=" << op->type.bits();
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<StringImm>([](const StringImm* op, IRPrinter* p) {
+ auto& stream = p->stream;
+ stream << '"';
+ for (size_t i = 0; i < op->value.size(); ++i) {
+ unsigned char c = op->value[i];
+ if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
+ stream << c;
+ } else {
+ stream << '\\';
+ switch (c) {
+ case '"':
+ stream << '"';
+ break;
+ case '\\':
+ stream << '\\';
+ break;
+ case '\t':
+ stream << 't';
+ break;
+ case '\r':
+ stream << 'r';
+ break;
+ case '\n':
+ stream << 'n';
+ break;
+ default:
+ const char* hex_digits = "0123456789ABCDEF";
+ stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
+ }
+ }
+ }
+ stream << '"';
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Cast>([](const Cast* op, IRPrinter* p) {
+ p->stream << op->type << '(';
+ p->Print(op->value);
+ p->stream << ')';
+ })
+.set_dispatch<Variable>([](const Variable* op, IRPrinter* p) {
+ // omit the type
+ // stream << op->name << "." << op->type;
+ p->stream << op->name_hint;
+ })
+.set_dispatch<Add>([](const Add* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " + ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+.set_dispatch<Sub>([](const Sub* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " - ";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+.set_dispatch<Mul>([](const Mul* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << "*";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+.set_dispatch<Div>([](const Div* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << "/";
+ p->Print(op->b);
+ p->stream << ')';
+ })
+.set_dispatch<Mod>([](const Mod* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " % ";
+ p->Print(op->b);
+ p->stream << ')';
+})
+.set_dispatch<Min>([](const Min* op, IRPrinter* p) {
+ p->stream << "min(";
+ p->Print(op->a);
+ p->stream << ", ";
+ p->Print(op->b);
+ p->stream << ")";
+})
+.set_dispatch<Max>([](const Max* op, IRPrinter* p) {
+ p->stream << "max(";
+ p->Print(op->a);
+ p->stream << ", ";
+ p->Print(op->b);
+ p->stream << ")";
+})
+.set_dispatch<EQ>([](const EQ* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " == ";
+ p->Print(op->b);
+ p->stream << ')';
+})
+.set_dispatch<NE>([](const NE* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " != ";
+ p->Print(op->b);
+ p->stream << ')';
+})
+.set_dispatch<LT>([](const LT* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " < ";
+ p->Print(op->b);
+ p->stream << ')';
+})
+.set_dispatch<LE>([](const LE* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " <= ";
+ p->Print(op->b);
+ p->stream << ')';
+})
+.set_dispatch<GT>([](const GT* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " > ";
+ p->Print(op->b);
+ p->stream << ')';
+})
+.set_dispatch<GE>([](const GE* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " >= ";
+ p->Print(op->b);
+ p->stream << ')';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<And>([](const And* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " && ";
+ p->Print(op->b);
+ p->stream << ')';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Or>([](const Or* op, IRPrinter* p) {
+ p->stream << '(';
+ p->Print(op->a);
+ p->stream << " || ";
+ p->Print(op->b);
+ p->stream << ')';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Not>([](const Not* op, IRPrinter* p) {
+ p->stream << '!';
+ p->Print(op->a);
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Select>([](const Select* op, IRPrinter* p) {
+ p->stream << "select(";
+ p->Print(op->condition);
+ p->stream << ", ";
+ p->Print(op->true_value);
+ p->stream << ", ";
+ p->Print(op->false_value);
+ p->stream << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Load>([](const Load* op, IRPrinter* p) {
+ p->stream << op->buffer_var << "[";
+ p->Print(op->index);
+ p->stream << "]";
+ if (!is_one(op->predicate)) {
+ p->stream << " if ";
+ p->Print(op->predicate);
+ }
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Ramp>([](const Ramp* op, IRPrinter* p) {
+ p->stream << "ramp(";
+ p->Print(op->base);
+ p->stream << ", ";
+ p->Print(op->stride);
+ p->stream << ", " << op->lanes << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Broadcast>([](const Broadcast* op, IRPrinter* p) {
+ p->stream << "x" << op->lanes << "(";
+ p->Print(op->value);
+ p->stream << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Call>([](const Call* op, IRPrinter* p) {
+ p->stream << op->name << "(";
+ for (size_t i = 0; i < op->args.size(); ++i) {
+ p->Print(op->args[i]);
+ if (i < op->args.size() - 1) {
+ p->stream << ", ";
+ }
+ }
+ p->stream << ")";
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Let>([](const Let* op, IRPrinter* p) {
+ p->stream << "(let " << op->var << " = ";
+ p->Print(op->value);
+ p->stream << " in ";
+ p->Print(op->body);
+ p->stream << ")";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<LetStmt>([](const LetStmt* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "let " << op->var << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ p->Print(op->body);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<AttrStmt>([](const AttrStmt* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "// attr [";
+ p->Print(op->node);
+ p->stream << "] "
+ << op->attr_key << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ p->Print(op->body);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<AssertStmt>([](const AssertStmt* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "assert(";
+ p->Print(op->condition);
+ p->stream << ", ";
+ p->Print(op->message);
+ p->stream << ")\n";
+ p->Print(op->body);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ProducerConsumer>([](const ProducerConsumer* op, IRPrinter* p) {
+ if (op->is_producer) {
+ p->PrintIndent();
+ p->stream << "produce " << op->func->func_name() << " {\n";
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
+ p->PrintIndent();
+ p->stream << "}\n";
+ } else {
+ p->Print(op->body);
+ }
+ });
+
+std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
+ switch (type) {
+ case ForType::Serial:
+ out << "for";
+ break;
+ case ForType::Parallel:
+ out << "parallel";
+ break;
+ case ForType::Unrolled:
+ out << "unrolled";
+ break;
+ case ForType::Vectorized:
+ out << "vectorized";
+ break;
+ }
+ return out;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<For>([](const For* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << op->for_type << " (" << op->loop_var << ", ";
+ p->Print(op->min);
+ p->stream << ", ";
+ p->Print(op->extent);
+ p->stream << ") {\n";
+
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
+
+ p->PrintIndent();
+ p->stream << "}\n";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Store>([](const Store* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << op->buffer_var << "[";
+ p->Print(op->index);
+ p->stream << "] = ";
+ p->Print(op->value);
+ if (!is_one(op->predicate)) {
+ p->stream << " if ";
+ p->Print(op->predicate);
+ }
+ p->stream << '\n';
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Provide>([](const Provide* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->args.size(); ++i) {
+ p->Print(op->args[i]);
+ if (i < op->args.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ p->stream << " =";
+ p->Print(op->value);
+ p->stream << '\n';
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Allocate>([](const Allocate* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "allocate " << op->buffer_var << "[" << op->type;
+ for (size_t i = 0; i < op->extents.size(); ++i) {
+ p->stream << " * ";
+ p->Print(op->extents[i]);
+ }
+ p->stream << "]";
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << "\n";
+ p->Print(op->body);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Free>([](const Free* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "free " << op->buffer_var;
+ p->stream << '\n';
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Realize>([](const Realize* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "realize " << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";
+
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
+
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Prefetch>([](const Prefetch* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->stream << "prefetch " << op->func->func_name() << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (op->func->num_outputs() != 1) {
+ p->stream << ".value[" << op->value_index << "]";
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Block>([](const Block* op, IRPrinter* p) {
+ p->Print(op->first);
+ if (op->rest.defined()) p->Print(op->rest);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<IfThenElse>([](const IfThenElse* op, IRPrinter* p) {
+ p->PrintIndent();
+ while (true) {
+ p->stream << "if (" << op->condition << ") {\n";
+ p->indent += 2;
+ p->Print(op->then_case);
+ p->indent -= 2;
+
+ if (!op->else_case.defined()) {
+ break;
+ }
+
+ if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
+ p->PrintIndent();
+ p->stream << "} else ";
+ op = nested_if;
+ } else {
+ p->PrintIndent();
+ p->stream << "} else {\n";
+ p->indent += 2;
+ p->Print(op->else_case);
+ p->indent -= 2;
+ break;
+ }
+ }
+ p->PrintIndent();
+ p->stream << "}\n";
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Evaluate>([](const Evaluate* op, IRPrinter* p) {
+ p->PrintIndent();
+ p->Print(op->value);
+ p->stream << "\n";
+ });
+
+template<typename T>
+void PrintList(const Array<T> &exprs, IRPrinter* p) {
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ p->Print(exprs[i]);
+ if (i < exprs.size() - 1) {
+ p->stream << ", ";
+ }
+ }
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Shuffle>([](const Shuffle* op, IRPrinter* p) {
+ p->stream << "shuffle(";
+ PrintList(op->vectors, p);
+ p->stream << ", ";
+ PrintList(op->indices, p);
+ p->stream << ")";
+ });
+
+// Container printer
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<ArrayNode>([](const ArrayNode* op, IRPrinter* p) {
+ p->stream << '[';
+ for (size_t i = 0 ; i < op->data.size(); ++i) {
+ if (i != 0) {
+ p->stream << ", ";
+ }
+ p->Print(NodeRef(op->data[i]));
+ }
+ p->stream << ']';
+});
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<MapNode>([](const MapNode* op, IRPrinter* p) {
+ p->stream << '{';
+ for (auto it = op->data.begin(); it != op->data.end(); ++it) {
+ if (it != op->data.begin()) {
+ p->stream << ", ";
+ }
+ p->Print(NodeRef(it->first));
+ p->stream << ": ";
+ p->Print(NodeRef(it->second));
+ }
+ p->stream << '}';
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<StrMapNode>([](const StrMapNode* op, IRPrinter* p) {
+ p->stream << '{';
+ for (auto it = op->data.begin(); it != op->data.end(); ++it) {
+ if (it != op->data.begin()) {
+ p->stream << ", ";
+ }
+ p->stream << '\"' << it->first << "\": ";
+ p->Print(NodeRef(it->second));
+ }
+ p->stream << '}';
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Reduce>([](const Reduce* op, IRPrinter* p) {
+ p->stream << "reduce(combiner="
+ << op->combiner;
+ p->stream << ", source=" << op->source;
+ p->stream << ", axis=" << op->axis;
+ p->stream << ", where=" << op->condition;
+ p->stream << ", value_index=" << op->value_index;
+ p->stream << ")";
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<CommReducerNode>([](const CommReducerNode* op, IRPrinter* p) {
+ p->stream << "comm_reducer(result=" << op->result
+ << ", lhs=" << op->lhs
+ << ", rhs=" << op->rhs
+ << ", identity_element=" << op->identity_element
+ << ")";
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
+ p->stream << "?";
+});
+
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
#include <tvm/tensor.h>
#include <tvm/operation.h>
#include <tvm/tensor_intrin.h>
-#include <ir/IR.h>
#include <memory>
namespace tvm {
// Tensor
-
Expr Tensor::operator()(Array<Var> indices) const {
Array<Expr> arr(indices.begin(), indices.end());
return operator()(arr);
}
Expr Tensor::operator()(Array<Expr> indices) const {
- using HalideIR::Internal::Call;
+ using ir::Call;
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
-
auto n = Call::make(
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Implementation of Node API
+ * \file node.cc
+ */
+#include <tvm/node/node.h>
+#include <memory>
+#include <atomic>
+#include <mutex>
+#include <unordered_map>
+
+// TODO(tqchen):
+// Think of re-organize and consolidate with object.
+namespace tvm {
+
+namespace {
+// single manager of operator information.
+struct TypeManager {
+ // mutex to avoid registration from multiple threads.
+ // recursive is needed for trigger(which calls UpdateAttrMap)
+ std::mutex mutex;
+ std::atomic<uint32_t> type_counter{0};
+ std::unordered_map<std::string, uint32_t> key2index;
+ std::vector<std::string> index2key;
+ // get singleton of the
+ static TypeManager* Global() {
+ static TypeManager inst;
+ return &inst;
+ }
+};
+} // namespace
+
+TVM_DLL const bool Node::_DerivedFrom(uint32_t tid) const {
+ static uint32_t tindex = TypeKey2Index(Node::_type_key);
+ return tid == tindex;
+}
+
+// this is slow, usually caller always hold the result in a static variable.
+TVM_DLL uint32_t Node::TypeKey2Index(const char* key) {
+ TypeManager *t = TypeManager::Global();
+ std::lock_guard<std::mutex>(t->mutex);
+ std::string skey = key;
+ auto it = t->key2index.find(skey);
+ if (it != t->key2index.end()) {
+ return it->second;
+ }
+ uint32_t tid = ++(t->type_counter);
+ t->key2index[skey] = tid;
+ t->index2key.push_back(skey);
+ return tid;
+}
+
+TVM_DLL const char* Node::TypeIndex2Key(uint32_t index) {
+ TypeManager *t = TypeManager::Global();
+ std::lock_guard<std::mutex>(t->mutex);
+ CHECK_NE(index, 0);
+ return t->index2key.at(index - 1).c_str();
+}
+} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \brief Compute Op.
* \file compute_op.cc
*/
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const {
CHECK_EQ(stage->op.get(), this);
- HalideIR::Internal::Region bounds;
+ Region bounds;
for (IterVar iv : this->axis) {
bounds.push_back(realize_map.at(iv));
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \brief External computation rule.
* \file extern_op.cc
*/
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
- HalideIR::Internal::Region bounds;
+ Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2019 by Contributors
* \brief Hybrid computation rule.
* \file hybrid_op.cc
*/
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
-#include <ir/Expr.h>
#include <unordered_set>
#include <string>
#include <utility>
Stmt realize_body = body;
for (int k = 0; k < num_outputs(); ++k) {
Tensor t = stage->op.output(k);
- HalideIR::Internal::Region bounds;
+ Region bounds;
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(
Range::make_by_min_extent(
}
const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For::make(target->var, range->min, range->extent,
- for_type, HalideIR::DeviceAPI::None, body);
+ for_type, DeviceAPI::None, body);
}
};
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \brief Utility to make loop nest.
* \file op_util.cc
*/
*/
/*!
- * Copyright (c) 2017 by Contributors
* \brief Scan Operator.
* \file scan_op.cc
*/
for (size_t i = 0; i < update.size(); ++i) {
Tensor t = stage->op.output(i);
CHECK_EQ(static_cast<size_t>(t->value_index), i);
- HalideIR::Internal::Region bounds;
+ Region bounds;
bounds.push_back(tdom);
for (size_t k = 1; k < this->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file inject_prefetch.cc
*/
// Inject prefetch op in HalideIR
using arith::IntSet;
using arith::DomainTouched;
-using HalideIR::Internal::Region;
class PrefetchInjector : public IRMutator {
public:
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
return order_;
}
- int CompareRegion(const HalideIR::Internal::Region& lhs,
- const HalideIR::Internal::Region& rhs) {
+ int CompareRegion(const Region& lhs, const Region& rhs) {
if (order_ != 0) return order_;
if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
for (size_t i = 0; i < lhs.size(); ++i) {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
IRMutator* m = this;
- HalideIR::Internal::Region new_bounds;
+ Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
IRMutator* m = this;
- HalideIR::Internal::Region new_bounds;
+ Region new_bounds;
bool bounds_changed = false;
// Mutate the bounds
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
+#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <vector>
/*! \brief The thread index that access this entry */
Array<IterVar> threads;
/*! \brief The buffer variable, if any */
- VarExpr buffer;
+ Var buffer = NullValue<Var>();
/*! \brief The access data type */
Type dtype;
/*! \brief The touched access range */
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is double buffer write */
- bool double_buffer_write{false};
+ bool double_buffer_write = false;
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
namespace tvm {
namespace ir {
-using HalideIR::Internal::Region;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;
}
// use small alignment for small arrays
- int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
+ int32_t const_size = Allocate::constant_allocation_size(shape);
int align = GetTempAllocaAlignment(op->type, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
for (int i = starts; i >= 0; --i) {
if (i < starts) {
stmt = For::make(
- vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::Host, stmt);
+ vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
Expr address = Call::make(Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
Expr prefetch = Call::make(op->type, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
stmt = Evaluate::make(prefetch);
Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
- stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::Host, stmt);
+ stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
}
return stmt;
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
- auto names = CallFunc<Array<HalideIR::Expr> >("list_params_name", nullptr);
+ auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<ir::StringImm>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
- if (attrs->dtype == HalideIR::Int(32)) {
+ if (attrs->dtype == Int(32)) {
*rv = true;
}
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2018 by Contributors
* \file relay/backend/graph_codegen.cc
* \brief Graph runtime codegen
*/
* \param shape
* \return std::vector<int64_t>
*/
- std::vector<int64_t> _ShapeToJSON(tvm::Array<HalideIR::Expr> shape) {
+ std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
std::vector<int64_t> ret;
for (IndexExpr dim : shape) {
const int64_t* pval = as_const_int(dim);
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- Array<HalideIR::Expr> ret;
+ Array<tvm::Expr> ret;
for (const auto &kv : this->output_.params) {
- HalideIR::Expr name = ir::StringImm::make(kv.first);
+ tvm::Expr name = ir::StringImm::make(kv.first);
ret.push_back(name);
}
*rv = ret;
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
p->stream << ", ty=";
- p->print(node->type_annotation);
+ p->Print(node->type_annotation);
}
p->stream << ")";
});
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;
- Array<HalideIR::Expr> oshape;
+ Array<IndexExpr> oshape;
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2016 by Contributors
* \file graph.cc
* \brief Utilities to get information about schedule graph.
*/
namespace schedule {
// key to specific tensor dimension.
struct TensorDimKey {
- FunctionRef f;
+ ir::FunctionRef f;
int value_index;
int dim;
TensorDimKey() {}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2016 by Contributors
* \file schedule_lang.cc
*/
#include <tvm/schedule.h>
})
.set_dispatch<SplitNode>([](const SplitNode *op, IRPrinter *p) {
p->stream << "split(parent=";
- p->print(op->parent);
+ p->Print(op->parent);
p->stream << ", outer=";
- p->print(op->outer);
+ p->Print(op->outer);
p->stream << ", inner=";
- p->print(op->inner);
+ p->Print(op->inner);
p->stream << ')';
})
.set_dispatch<FuseNode>([](const FuseNode *op, IRPrinter *p) {
p->stream << "split(";
p->stream << "outer=";
- p->print(op->outer);
+ p->Print(op->outer);
p->stream << ", inner=";
- p->print(op->inner);
+ p->Print(op->inner);
p->stream << ", fused=";
- p->print(op->fused);
+ p->Print(op->fused);
p->stream << ')';
})
.set_dispatch<RebaseNode>([](const RebaseNode *op, IRPrinter *p) {
p->stream << "rebase(";
p->stream << "parent=";
- p->print(op->parent);
+ p->Print(op->parent);
p->stream << ", rebased=";
- p->print(op->rebased);
+ p->Print(op->rebased);
p->stream << ')';
})
.set_dispatch<SingletonNode>([](const SingletonNode *op, IRPrinter *p) {
p->stream << "singleton(";
- p->print(op->iter);
+ p->Print(op->iter);
p->stream << ')';
})
.set_dispatch<ScheduleNode>([](const ScheduleNode *op, IRPrinter *p) {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <tvm/expr_operator.h>
namespace {
+using namespace tvm;
using namespace tvm::ir;
-using namespace HalideIR::Internal;
-using namespace HalideIR;
// replace variable to constant
class IRVar2Const : public IRMutator {
public:
- VarExpr var;
+ Var var;
int int_val;
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRVar2Const::vtable_expr();
.set_dispatch<Variable>([](const Variable* op, const Expr &e, IRMutator* m) {
IRVar2Const* vm = static_cast<IRVar2Const*>(m);
if (e.same_as(vm->var)) {
- return IntImm::make(Int(32), vm->int_val);
+ return Expr(IntImm::make(Int(32), vm->int_val));
} else {
return e;
}
} // namespace
TEST(IRMutator, Basic) {
- using namespace HalideIR::Internal;
+ using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
TEST(IRSSA, Convert) {
- using namespace HalideIR::Internal;
using namespace tvm;
+ using namespace tvm::ir;
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
}
TEST(IRSSA, Basic) {
- using namespace HalideIR::Internal;
+ using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
auto z = Evaluate::make(x + y);
#include <tvm/ir_pass.h>
TEST(IRVisitor, CountVar) {
- using namespace HalideIR::Internal;
using namespace tvm;
int n_var = 0;
Var x("x"), y;
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file topi/image/resize.h
* \brief image resize constructors
*/
const Expr max_y, const Expr max_x) {
auto in_y = indices[2];
auto yf = tvm::floor(in_y);
- auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
+ auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
- auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
+ auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > max_y), max_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[3];
auto xf = tvm::floor(in_x);
- auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
+ auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
- auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
+ auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > max_x), max_x, xc);
auto x_lerp = in_x - xf;
out_shape, [&](const Array<Var>& indices) {
auto in_y = indices[1] * y_ratio;
auto yf = tvm::floor(in_y);
- auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
+ auto yc = tvm::cast(Int(32), tvm::ceil(in_y));
- auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
+ auto y0 = tvm::cast(Int(32), tvm::floor(in_y));
auto y1 = tvm::if_then_else((yc > other_y), other_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[2] * x_ratio;
auto xf = tvm::floor(in_x);
- auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
+ auto xc = tvm::cast(Int(32), tvm::ceil(in_x));
- auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
+ auto x0 = tvm::cast(Int(32), tvm::floor(in_x));
auto x1 = tvm::if_then_else((xc > other_x), other_x, xc);
auto x_lerp = in_x - xf;
auto bid = out_index[1 - axis];
len_index.push_back(bid);
Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
- tvm::cast(data->dtype, Expr(mask_value)), data(out_index));
+ tvm::make_const(data->dtype, mask_value), data(out_index));
return ret;
}, name, tag);
return out;