* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
- const std::unordered_map<const Variable*, IntSet>& dom_map);
+ const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Find a symbolic integer set that contains the union over
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(IntSet s,
- const std::unordered_map<const Variable*, IntSet>& dom_map);
+ const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Range r,
- const std::unordered_map<const Variable*, IntSet>& dom_map);
+ const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
*/
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
- const std::unordered_map<const Variable*, IntSet>& dom_map);
+ const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Create a union set of all sets
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
- const std::unordered_map<const Variable*, IntSet>& hint_map,
- const std::unordered_map<const Variable*, IntSet>& relax_map);
+ const std::unordered_map<const VarNode*, IntSet>& hint_map,
+ const std::unordered_map<const VarNode*, IntSet>& relax_map);
/*!
* \brief Infer a regular domain that covers all the calls or provides within the given statement.
} else {
Expr expr = val;
CHECK(expr.defined());
- if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
+ if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
- } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
+ } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
*ptr = val.operator std::string();
} else {
Expr expr = val;
- const ir::StringImm* op = expr.as<ir::StringImm>();
+ const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
}
} else {
Expr expr = val;
CHECK(expr.defined());
- if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
+ if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
- } else if (const ir::FloatImm* op = expr.as<ir::FloatImm>()) {
+ } else if (const ir::FloatImmNode* op = expr.as<ir::FloatImmNode>()) {
*ptr = static_cast<double>(op->value);
- } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
+ } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
* - Let
* - LetStmt
*/
-class Variable : public ExprNode {
+class VarNode : public ExprNode {
public:
/*!
* \brief The hint to the variable name.
}
static constexpr const char* _type_key = "Variable";
- TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};
/*! \brief a named variable in TVM */
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
- const Variable* operator->() const {
+ const VarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
- const Variable* get() const {
- return static_cast<const Variable*>(data_.get());
+ const VarNode* get() const {
+ return static_cast<const VarNode*>(data_.get());
}
/*! \brief type indicate the container type */
- using ContainerType = Variable;
+ using ContainerType = VarNode;
};
// Backward compatibility, will be removed later.
class Integer;
/*! \brief ExprNode: constant integer. */
-class IntImm : public ExprNode {
+class IntImmNode : public ExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
};
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
- const IntImm* operator->() const {
- return static_cast<const IntImm*>(get());
+ const IntImmNode* operator->() const {
+ return static_cast<const IntImmNode*>(get());
}
/*!
* \brief convert to int64_t
return (*this)->value;
}
/*! \brief type indicate the container type */
- using ContainerType = IntImm;
+ using ContainerType = IntImmNode;
};
/*! \brief range over one dimension */
*/
inline const int64_t* as_const_int(const Expr& x) {
if (!x.defined()) return nullptr;
- if (const ir::IntImm* op = x.as<ir::IntImm>()) {
+ if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
return &(op->value);
} else {
return nullptr;
*/
inline const uint64_t* as_const_uint(const Expr& x) {
if (!x.defined()) return nullptr;
- if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
+ if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value);
} else {
return nullptr;
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
- return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
+ return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
} \
TVM_DECLARE_INTRIN_UNARY(exp);
// Implementation details after this
inline bool is_const(const Expr& x) {
- if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
+ if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
return true;
- } else if (const auto* op = x.as<ir::Broadcast>()) {
+ } else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value;
- if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
+ if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
return true;
}
}
}
inline bool is_positive_const(const Expr& a) {
- if (const ir::IntImm* op = a.as<ir::IntImm>()) {
+ if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
- } else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
+ } else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
return op->value > 0;
} else {
return false;
}
inline bool is_negative_const(const Expr& a) {
- if (const ir::IntImm* op = a.as<ir::IntImm>()) {
+ if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value < 0;
} else {
return false;
}
inline bool is_const_int(const Expr& x, int64_t value) {
- if (const auto* op = x.as<ir::IntImm>()) {
+ if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
- } else if (const auto* op = x.as<ir::UIntImm>()) {
+ } else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value);
- } else if (const auto* op = x.as<ir::Broadcast>()) {
+ } else if (const auto* op = x.as<ir::BroadcastNode>()) {
const Expr& val = op->value;
- if (const auto* opv = val.as<ir::IntImm>()) {
+ if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
- } else if (const auto* opv = val.as<ir::UIntImm>()) {
+ } else if (const auto* opv = val.as<ir::UIntImmNode>()) {
return opv->value == static_cast<uint64_t>(value);
}
}
inline bool is_no_op(const Stmt& stmt) {
if (!stmt.defined()) return true;
- if (const auto* op = stmt.as<ir::Evaluate>()) {
+ if (const auto* op = stmt.as<ir::EvaluateNode>()) {
return is_const(op->value);
}
if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
template<typename ValueType>
inline Expr MakeConstScalar(DataType t, ValueType value) {
- if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
- if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
- if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
+ if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
+ if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
+ if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
- return ir::FloatImm::make(t, static_cast<double>(value));
+ return ir::FloatImmNode::make(t, static_cast<double>(value));
LOG(FATAL) << "cannot make const for type " << t;
return Expr();
}
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
- return ir::Broadcast::make(
+ return ir::BroadcastNode::make(
MakeConstScalar(t.element_of(), value), t.lanes());
}
}
namespace tvm {
namespace ir {
-using IntImm = tvm::IntImm;
-using Variable = tvm::Variable;
+using IntImmNode = tvm::IntImmNode;
+using VarNode = tvm::VarNode;
/*! \brief constant unsigned integer. */
-class UIntImm : public ExprNode {
+class UIntImmNode : public ExprNode {
public:
/*! \brief The constant value content. */
uint64_t value;
TVM_DLL static Expr make(DataType t, uint64_t value);
static constexpr const char* _type_key = "UIntImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(UIntImm, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, ExprNode);
};
/*! \brief Floating point constants. */
-class FloatImm : public ExprNode {
+class FloatImmNode : public ExprNode {
public:
/*! \brief The constant value content. */
double value;
TVM_DLL static Expr make(DataType t, double value);
static constexpr const char* _type_key = "FloatImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(FloatImm, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, ExprNode);
};
/*! \brief String constants, only used in asserts. */
-class StringImm : public ExprNode {
+class StringImmNode : public ExprNode {
public:
/*! \brief The constant value content. */
std::string value;
TVM_DLL Expr static make(std::string value);
static constexpr const char* _type_key = "StringImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(StringImm, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, ExprNode);
};
/*!
* \brief Cast value from one data type to another.
* \note The lanes of value should keep fixed.
*/
-class Cast : public ExprNode {
+class CastNode : public ExprNode {
public:
/*! \brief Original data type. */
Expr value;
TVM_DLL static Expr make(DataType t, Expr v);
static constexpr const char* _type_key = "Cast";
- TVM_DECLARE_FINAL_OBJECT_INFO(Cast, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode);
};
/*!
};
/*! \brief a + b */
-class Add : public BinaryOpNode<Add> {
+class AddNode : public BinaryOpNode<AddNode> {
public:
static constexpr const char* _type_key = "Add";
};
/*! \brief a - b */
-class Sub : public BinaryOpNode<Sub> {
+class SubNode : public BinaryOpNode<SubNode> {
public:
static constexpr const char* _type_key = "Sub";
};
/*! \brief a * b */
-class Mul : public BinaryOpNode<Mul> {
+class MulNode : public BinaryOpNode<MulNode> {
public:
static constexpr const char* _type_key = "Mul";
};
* \brief a / b in the C semantics.
* \note For integer division, C standard uses trunc div.
*/
-class Div : public BinaryOpNode<Div> {
+class DivNode : public BinaryOpNode<DivNode> {
public:
static constexpr const char* _type_key = "Div";
};
* \brief a % b in the C semantics.
* \note For integer division, C standard uses trunc div.
*/
-class Mod : public BinaryOpNode<Mod> {
+class ModNode : public BinaryOpNode<ModNode> {
public:
static constexpr const char* _type_key = "Mod";
};
/*! \brief Floor division, floor(a/b) */
-class FloorDiv : public BinaryOpNode<FloorDiv> {
+class FloorDivNode : public BinaryOpNode<FloorDivNode> {
public:
static constexpr const char* _type_key = "FloorDiv";
};
/*! \brief The remainder of the floordiv */
-class FloorMod : public BinaryOpNode<FloorMod> {
+class FloorModNode : public BinaryOpNode<FloorModNode> {
public:
static constexpr const char* _type_key = "FloorMod";
};
/*! \brief min(a, b) */
-class Min : public BinaryOpNode<Min> {
+class MinNode : public BinaryOpNode<MinNode> {
public:
static constexpr const char* _type_key = "Min";
};
/*! \brief max(a, b) */
-class Max : public BinaryOpNode<Max> {
+class MaxNode : public BinaryOpNode<MaxNode> {
public:
static constexpr const char* _type_key = "Max";
};
};
/*! \brief a == b */
-class EQ : public CmpOpNode<EQ> {
+class EQNode : public CmpOpNode<EQNode> {
public:
static constexpr const char* _type_key = "EQ";
};
/*! \brief a != b */
-class NE : public CmpOpNode<NE> {
+class NENode : public CmpOpNode<NENode> {
public:
static constexpr const char* _type_key = "NE";
};
/*! \brief a < b */
-class LT : public CmpOpNode<LT> {
+class LTNode : public CmpOpNode<LTNode> {
public:
static constexpr const char* _type_key = "LT";
};
/*! \brief a <= b */
-struct LE : public CmpOpNode<LE> {
+class LENode : public CmpOpNode<LENode> {
public:
static constexpr const char* _type_key = "LE";
};
/*! \brief a > b */
-class GT : public CmpOpNode<GT> {
+class GTNode : public CmpOpNode<GTNode> {
public:
static constexpr const char* _type_key = "GT";
};
/*! \brief a >= b */
-class GE : public CmpOpNode<GE> {
+class GENode : public CmpOpNode<GENode> {
public:
static constexpr const char* _type_key = "GE";
};
/*! \brief a && b */
-class And : public ExprNode {
+class AndNode : public ExprNode {
public:
/*! \brief The left operand. */
Expr a;
TVM_DLL static Expr make(Expr a, Expr b);
static constexpr const char* _type_key = "And";
- TVM_DECLARE_FINAL_OBJECT_INFO(And, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode);
};
/*! \brief a || b */
-class Or : public ExprNode {
+class OrNode : public ExprNode {
public:
/*! \brief The left operand. */
Expr a;
TVM_DLL static Expr make(Expr a, Expr b);
static constexpr const char* _type_key = "Or";
- TVM_DECLARE_FINAL_OBJECT_INFO(Or, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode);
};
/*! \brief !a */
-class Not : public ExprNode {
+class NotNode : public ExprNode {
public:
/*! \brief The input operand. */
Expr a;
TVM_DLL static Expr make(Expr a);
static constexpr const char* _type_key = "Not";
- TVM_DECLARE_FINAL_OBJECT_INFO(Not, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode);
};
/*!
* Do not use it to guard against out of bound access,
* please use if_then_else instead.
*/
-class Select : public ExprNode {
+class SelectNode : public ExprNode {
public:
/*! \brief The condition */
Expr condition;
TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
static constexpr const char* _type_key = "Select";
- TVM_DECLARE_FINAL_OBJECT_INFO(Select, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode);
};
/*!
*
* \endcode
*/
-class Load : public ExprNode {
+class LoadNode : public ExprNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate);
static constexpr const char* _type_key = "Load";
- TVM_DECLARE_FINAL_OBJECT_INFO(Load, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, ExprNode);
};
/*!
* - ramp(0, 1, 3) = [0, 1, 2]
* - ramp(1, 2, 4) = [1, 3, 5, 7]
*/
-class Ramp : public ExprNode {
+class RampNode : public ExprNode {
public:
/*! \brief The base value. */
Expr base;
TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
static constexpr const char* _type_key = "Ramp";
- TVM_DECLARE_FINAL_OBJECT_INFO(Ramp, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, ExprNode);
};
/*! \brief Create a vector where all the elements are value. */
-class Broadcast : public ExprNode {
+class BroadcastNode : public ExprNode {
public:
/*! \brief The base value. */
Expr value;
TVM_DLL static Expr make(Expr value, int lanes);
static constexpr const char* _type_key = "Broadcast";
- TVM_DECLARE_FINAL_OBJECT_INFO(Broadcast, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, ExprNode);
};
/*!
* \brief Let binding. Bind var to value then evaluate body.
*/
-class Let : public ExprNode {
+class LetNode : public ExprNode {
public:
/*! \brief The variable. */
Var var;
TVM_DLL static Expr make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "Let";
- TVM_DECLARE_FINAL_OBJECT_INFO(Let, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
// Call node, represent a function call or a multi-dimensional array load.
/*!
* \brief Call node.
*/
-class Call : public ExprNode {
+class CallNode : public ExprNode {
public:
/*! \brief Possible types of calls. */
enum CallType : int {
bool is_vectorizable() const;
static constexpr const char* _type_key = "Call";
- TVM_DECLARE_FINAL_OBJECT_INFO(Call, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
// Build-in intrinsics
static constexpr const char* reinterpret = "reinterpret";
* vec = concat(vectors)
* result = (vec[indices[0]], vec[indices[1]] ...)
*/
-class Shuffle : public ExprNode {
+class ShuffleNode : public ExprNode {
public:
/*! \brief the input vectors. */
Array<Expr> vectors;
TVM_DLL static Expr make_extract_element(Expr vector, int index);
static constexpr const char* _type_key = "Shuffle";
- TVM_DECLARE_FINAL_OBJECT_INFO(Shuffle, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, ExprNode);
};
// Reduce operator
}
/*! \brief Reduction operator operator */
-class Reduce : public ExprNode {
+class ReduceNode : public ExprNode {
public:
/*! \brief The commutative combiner */
CommReducer combiner;
}
static constexpr const char* _type_key = "Reduce";
- TVM_DECLARE_FINAL_OBJECT_INFO(Reduce, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, ExprNode);
};
/*! \brief Any shape. */
-class Any : public ExprNode {
+class AnyNode : public ExprNode {
public:
void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */
Var ToVar() const {
- return Variable::make(DataType::Int(32), "any_dim");
+ return VarNode::make(DataType::Int(32), "any_dim");
}
TVM_DLL static Expr make();
static constexpr const char* _type_key = "Any";
- TVM_DECLARE_FINAL_OBJECT_INFO(Any, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, ExprNode);
};
// Statements
/*!
* \brief Let binding, bind var to value, then run body.
*/
-class LetStmt : public StmtNode {
+class LetStmtNode : public StmtNode {
public:
/*! \brief The variable. */
Var var;
TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
- TVM_DECLARE_FINAL_OBJECT_INFO(LetStmt, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
};
/*!
* - Bound of function, variables.
* - Hint which block corresponds to a parallel region.
*/
-class AttrStmt : public StmtNode {
+class AttrStmtNode : public StmtNode {
public:
/*! \brief this is attribute about certain node */
ObjectRef node;
Stmt body);
static constexpr const char* _type_key = "AttrStmt";
- TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmt, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
};
/*!
* \brief Assert condition, if an error occurs, return the error message.
*/
-class AssertStmt : public StmtNode {
+class AssertStmtNode : public StmtNode {
public:
/*! \brief Condition to be checked. */
Expr condition;
TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
- TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmt, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};
// TODO(tvm-team): consider consolidate with AttrStmt.
/*! \brief annotation node of producer/consumer relation. */
-class ProducerConsumer : public StmtNode {
+class ProducerConsumerNode : public StmtNode {
public:
/*! \brief The corresponding tensor. */
FunctionRef func;
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
- TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumer, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
};
/*!
* buffer[index.v2] = value.v2;
*
* \endcode
- * \sa Load
+ * \sa LoadNode
*/
-class Store : public StmtNode {
+class StoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
Expr predicate);
static constexpr const char* _type_key = "Store";
- TVM_DECLARE_FINAL_OBJECT_INFO(Store, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};
/*!
* \brief Store value into mult-dimensional array defined by func.
*/
-class Provide : public StmtNode {
+class ProvideNode : public StmtNode {
public:
/*! \brief The function to be updated. */
FunctionRef func;
Array<Expr> args);
static constexpr const char* _type_key = "Provide";
- TVM_DECLARE_FINAL_OBJECT_INFO(Provide, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
};
/*!
* \brief Allocate a buffer that can be used in body.
*/
-class Allocate : public StmtNode {
+class AllocateNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
const Array<Expr>& extents);
static constexpr const char* _type_key = "Allocate";
- TVM_DECLARE_FINAL_OBJECT_INFO(Allocate, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};
/*! \brief Free the resources in the buffer before the scope ends. */
-class Free : public StmtNode {
+class FreeNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free";
- TVM_DECLARE_FINAL_OBJECT_INFO(Free, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};
/*!
* \brief Annotate the bounds where func need to be written and read in body.
* We will need to allocate space for the corresponding regions.
*/
-class Realize : public StmtNode {
+class RealizeNode : public StmtNode {
public:
/*! \brief The function to be realized. */
FunctionRef func;
Stmt body);
static constexpr const char* _type_key = "Realize";
- TVM_DECLARE_FINAL_OBJECT_INFO(Realize, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
};
/*!
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
- } else if (auto* op = stmt.as<ProducerConsumer>()) {
+ } else if (auto* op = stmt.as<ProducerConsumerNode>()) {
// NOTE: The consumer block annotation was not as useful and can be safely dropped.
if (!op->is_producer) {
operator()(0, op->body);
/*!
* \brief IfThenElse statment.
*/
-class IfThenElse : public StmtNode {
+class IfThenElseNode : public StmtNode {
public:
/*! \brief The condition. */
Expr condition;
TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse";
- TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElse, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
};
/*!
*
* If value do not have side-effect, this node can be safely removed.
*/
-class Evaluate : public StmtNode {
+class EvaluateNode : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
Expr value;
TVM_DLL static Stmt make(Expr v);
static constexpr const char* _type_key = "Evaluate";
- TVM_DECLARE_FINAL_OBJECT_INFO(Evaluate, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};
/*! \brief Additional annotation of for loop. */
* }
* \endcode
*/
-class For : public StmtNode {
+class ForNode : public StmtNode {
public:
/*! \brief The loop variable. */
Var loop_var;
}
static constexpr const char* _type_key = "For";
- TVM_DECLARE_FINAL_OBJECT_INFO(For, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
/*!
* \brief A prefetch hint of func.
*/
-class Prefetch : public StmtNode {
+class PrefetchNode : public StmtNode {
public:
/*! \brief The function to be prefetched. */
FunctionRef func;
Region bounds);
static constexpr const char* _type_key = "Prefetch";
- TVM_DECLARE_FINAL_OBJECT_INFO(Prefetch, StmtNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};
/*!
* \return Expr a expression with dtype.
*/
inline Expr TypeAnnotation(DataType dtype) {
- return ir::Call::make(dtype,
+ return ir::CallNode::make(dtype,
"type_annotation", {},
- ir::Call::PureIntrinsic);
+ ir::CallNode::PureIntrinsic);
}
// overload printing of for type.
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
- virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+ virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
static FType InitVTable() {
FType vtable;
// Set dispatch
- IR_EXPR_FUNCTOR_DISPATCH(Variable);
- IR_EXPR_FUNCTOR_DISPATCH(Load);
- IR_EXPR_FUNCTOR_DISPATCH(Let);
- IR_EXPR_FUNCTOR_DISPATCH(Call);
- IR_EXPR_FUNCTOR_DISPATCH(Add);
- IR_EXPR_FUNCTOR_DISPATCH(Sub);
- IR_EXPR_FUNCTOR_DISPATCH(Mul);
- IR_EXPR_FUNCTOR_DISPATCH(Div);
- IR_EXPR_FUNCTOR_DISPATCH(Mod);
- IR_EXPR_FUNCTOR_DISPATCH(FloorDiv);
- IR_EXPR_FUNCTOR_DISPATCH(FloorMod);
- IR_EXPR_FUNCTOR_DISPATCH(Min);
- IR_EXPR_FUNCTOR_DISPATCH(Max);
- IR_EXPR_FUNCTOR_DISPATCH(EQ);
- IR_EXPR_FUNCTOR_DISPATCH(NE);
- IR_EXPR_FUNCTOR_DISPATCH(LT);
- IR_EXPR_FUNCTOR_DISPATCH(LE);
- IR_EXPR_FUNCTOR_DISPATCH(GT);
- IR_EXPR_FUNCTOR_DISPATCH(GE);
- IR_EXPR_FUNCTOR_DISPATCH(And);
- IR_EXPR_FUNCTOR_DISPATCH(Or);
- IR_EXPR_FUNCTOR_DISPATCH(Reduce);
- IR_EXPR_FUNCTOR_DISPATCH(Cast);
- IR_EXPR_FUNCTOR_DISPATCH(Not);
- IR_EXPR_FUNCTOR_DISPATCH(Select);
- IR_EXPR_FUNCTOR_DISPATCH(Ramp);
- IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
- IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
- IR_EXPR_FUNCTOR_DISPATCH(IntImm);
- IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
- IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
- IR_EXPR_FUNCTOR_DISPATCH(StringImm);
+ IR_EXPR_FUNCTOR_DISPATCH(VarNode);
+ IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
+ IR_EXPR_FUNCTOR_DISPATCH(LetNode);
+ IR_EXPR_FUNCTOR_DISPATCH(CallNode);
+ IR_EXPR_FUNCTOR_DISPATCH(AddNode);
+ IR_EXPR_FUNCTOR_DISPATCH(SubNode);
+ IR_EXPR_FUNCTOR_DISPATCH(MulNode);
+ IR_EXPR_FUNCTOR_DISPATCH(DivNode);
+ IR_EXPR_FUNCTOR_DISPATCH(ModNode);
+ IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode);
+ IR_EXPR_FUNCTOR_DISPATCH(FloorModNode);
+ IR_EXPR_FUNCTOR_DISPATCH(MinNode);
+ IR_EXPR_FUNCTOR_DISPATCH(MaxNode);
+ IR_EXPR_FUNCTOR_DISPATCH(EQNode);
+ IR_EXPR_FUNCTOR_DISPATCH(NENode);
+ IR_EXPR_FUNCTOR_DISPATCH(LTNode);
+ IR_EXPR_FUNCTOR_DISPATCH(LENode);
+ IR_EXPR_FUNCTOR_DISPATCH(GTNode);
+ IR_EXPR_FUNCTOR_DISPATCH(GENode);
+ IR_EXPR_FUNCTOR_DISPATCH(AndNode);
+ IR_EXPR_FUNCTOR_DISPATCH(OrNode);
+ IR_EXPR_FUNCTOR_DISPATCH(ReduceNode);
+ IR_EXPR_FUNCTOR_DISPATCH(CastNode);
+ IR_EXPR_FUNCTOR_DISPATCH(NotNode);
+ IR_EXPR_FUNCTOR_DISPATCH(SelectNode);
+ IR_EXPR_FUNCTOR_DISPATCH(RampNode);
+ IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
+ IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
+ IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
+ IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode);
+ IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
+ IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
return vtable;
}
};
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
- virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
- virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
// initialize the vtable.
static FType InitVTable() {
FType vtable;
- IR_STMT_FUNCTOR_DISPATCH(LetStmt);
- IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
- IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
- IR_STMT_FUNCTOR_DISPATCH(For);
- IR_STMT_FUNCTOR_DISPATCH(Allocate);
- IR_STMT_FUNCTOR_DISPATCH(Store);
- IR_STMT_FUNCTOR_DISPATCH(Free);
- IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
- IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
- IR_STMT_FUNCTOR_DISPATCH(Provide);
- IR_STMT_FUNCTOR_DISPATCH(Realize);
- IR_STMT_FUNCTOR_DISPATCH(Prefetch);
+ IR_STMT_FUNCTOR_DISPATCH(LetStmtNode);
+ IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
+ IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
+ IR_STMT_FUNCTOR_DISPATCH(ForNode);
+ IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
+ IR_STMT_FUNCTOR_DISPATCH(StoreNode);
+ IR_STMT_FUNCTOR_DISPATCH(FreeNode);
+ IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
+ IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode);
+ IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
+ IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
+ IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
- IR_STMT_FUNCTOR_DISPATCH(Evaluate);
+ IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
return vtable;
}
};
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
- void VisitExpr_(const Variable* op) override;
- void VisitExpr_(const Load* op) override;
- void VisitExpr_(const Let* op) override;
- void VisitExpr_(const Call* op) override;
- void VisitExpr_(const Add* op) override;
- void VisitExpr_(const Sub* op) override;
- void VisitExpr_(const Mul* op) override;
- void VisitExpr_(const Div* op) override;
- void VisitExpr_(const Mod* op) override;
- void VisitExpr_(const FloorDiv* op) override;
- void VisitExpr_(const FloorMod* op) override;
- void VisitExpr_(const Min* op) override;
- void VisitExpr_(const Max* op) override;
- void VisitExpr_(const EQ* op) override;
- void VisitExpr_(const NE* op) override;
- void VisitExpr_(const LT* op) override;
- void VisitExpr_(const LE* op) override;
- void VisitExpr_(const GT* op) override;
- void VisitExpr_(const GE* op) override;
- void VisitExpr_(const And* op) override;
- void VisitExpr_(const Or* op) override;
- void VisitExpr_(const Reduce* op) override;
- void VisitExpr_(const Cast* op) override;
- void VisitExpr_(const Not* op) override;
- void VisitExpr_(const Select* op) override;
- void VisitExpr_(const Ramp* op) override;
- void VisitExpr_(const Broadcast* op) override;
- void VisitExpr_(const Shuffle* op) override;
- void VisitExpr_(const IntImm* op) override;
- void VisitExpr_(const UIntImm* op) override;
- void VisitExpr_(const FloatImm* op) override;
- void VisitExpr_(const StringImm* op) override;
+ void VisitExpr_(const VarNode* op) override;
+ void VisitExpr_(const LoadNode* op) override;
+ void VisitExpr_(const LetNode* op) override;
+ void VisitExpr_(const CallNode* op) override;
+ void VisitExpr_(const AddNode* op) override;
+ void VisitExpr_(const SubNode* op) override;
+ void VisitExpr_(const MulNode* op) override;
+ void VisitExpr_(const DivNode* op) override;
+ void VisitExpr_(const ModNode* op) override;
+ void VisitExpr_(const FloorDivNode* op) override;
+ void VisitExpr_(const FloorModNode* op) override;
+ void VisitExpr_(const MinNode* op) override;
+ void VisitExpr_(const MaxNode* op) override;
+ void VisitExpr_(const EQNode* op) override;
+ void VisitExpr_(const NENode* op) override;
+ void VisitExpr_(const LTNode* op) override;
+ void VisitExpr_(const LENode* op) override;
+ void VisitExpr_(const GTNode* op) override;
+ void VisitExpr_(const GENode* op) override;
+ void VisitExpr_(const AndNode* op) override;
+ void VisitExpr_(const OrNode* op) override;
+ void VisitExpr_(const ReduceNode* op) override;
+ void VisitExpr_(const CastNode* op) override;
+ void VisitExpr_(const NotNode* op) override;
+ void VisitExpr_(const SelectNode* op) override;
+ void VisitExpr_(const RampNode* op) override;
+ void VisitExpr_(const BroadcastNode* op) override;
+ void VisitExpr_(const ShuffleNode* op) override;
+ void VisitExpr_(const IntImmNode* op) override;
+ void VisitExpr_(const UIntImmNode* op) override;
+ void VisitExpr_(const FloatImmNode* op) override;
+ void VisitExpr_(const StringImmNode* op) override;
};
/*!
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
- Expr VisitExpr_(const Variable* op) override;
- Expr VisitExpr_(const Load* op) override;
- Expr VisitExpr_(const Let* op) override;
- Expr VisitExpr_(const Call* op) override;
- Expr VisitExpr_(const Add* op) override;
- Expr VisitExpr_(const Sub* op) override;
- Expr VisitExpr_(const Mul* op) override;
- Expr VisitExpr_(const Div* op) override;
- Expr VisitExpr_(const Mod* op) override;
- Expr VisitExpr_(const FloorDiv* op) override;
- Expr VisitExpr_(const FloorMod* op) override;
- Expr VisitExpr_(const Min* op) override;
- Expr VisitExpr_(const Max* op) override;
- Expr VisitExpr_(const EQ* op) override;
- Expr VisitExpr_(const NE* op) override;
- Expr VisitExpr_(const LT* op) override;
- Expr VisitExpr_(const LE* op) override;
- Expr VisitExpr_(const GT* op) override;
- Expr VisitExpr_(const GE* op) override;
- Expr VisitExpr_(const And* op) override;
- Expr VisitExpr_(const Or* op) override;
- Expr VisitExpr_(const Reduce* op) override;
- Expr VisitExpr_(const Cast* op) override;
- Expr VisitExpr_(const Not* op) override;
- Expr VisitExpr_(const Select* op) override;
- Expr VisitExpr_(const Ramp* op) override;
- Expr VisitExpr_(const Broadcast* op) override;
- Expr VisitExpr_(const Shuffle* op) override;
- Expr VisitExpr_(const IntImm* op) override;
- Expr VisitExpr_(const UIntImm* op) override;
- Expr VisitExpr_(const FloatImm* op) override;
- Expr VisitExpr_(const StringImm* op) override;
+ Expr VisitExpr_(const VarNode* op) override;
+ Expr VisitExpr_(const LoadNode* op) override;
+ Expr VisitExpr_(const LetNode* op) override;
+ Expr VisitExpr_(const CallNode* op) override;
+ Expr VisitExpr_(const AddNode* op) override;
+ Expr VisitExpr_(const SubNode* op) override;
+ Expr VisitExpr_(const MulNode* op) override;
+ Expr VisitExpr_(const DivNode* op) override;
+ Expr VisitExpr_(const ModNode* op) override;
+ Expr VisitExpr_(const FloorDivNode* op) override;
+ Expr VisitExpr_(const FloorModNode* op) override;
+ Expr VisitExpr_(const MinNode* op) override;
+ Expr VisitExpr_(const MaxNode* op) override;
+ Expr VisitExpr_(const EQNode* op) override;
+ Expr VisitExpr_(const NENode* op) override;
+ Expr VisitExpr_(const LTNode* op) override;
+ Expr VisitExpr_(const LENode* op) override;
+ Expr VisitExpr_(const GTNode* op) override;
+ Expr VisitExpr_(const GENode* op) override;
+ Expr VisitExpr_(const AndNode* op) override;
+ Expr VisitExpr_(const OrNode* op) override;
+ Expr VisitExpr_(const ReduceNode* op) override;
+ Expr VisitExpr_(const CastNode* op) override;
+ Expr VisitExpr_(const NotNode* op) override;
+ Expr VisitExpr_(const SelectNode* op) override;
+ Expr VisitExpr_(const RampNode* op) override;
+ Expr VisitExpr_(const BroadcastNode* op) override;
+ Expr VisitExpr_(const ShuffleNode* op) override;
+ Expr VisitExpr_(const IntImmNode* op) override;
+ Expr VisitExpr_(const UIntImmNode* op) override;
+ Expr VisitExpr_(const FloatImmNode* op) override;
+ Expr VisitExpr_(const StringImmNode* op) override;
};
/*!
*/
virtual void VisitExpr(const Expr& e) {}
// statement visitor
- void VisitStmt_(const AttrStmt* op) override;
- void VisitStmt_(const IfThenElse* op) override;
- void VisitStmt_(const LetStmt* op) override;
- void VisitStmt_(const For* op) override;
- void VisitStmt_(const Allocate* op) override;
- void VisitStmt_(const Store* op) override;
- void VisitStmt_(const Free* op) override;
- void VisitStmt_(const AssertStmt* op) override;
- void VisitStmt_(const ProducerConsumer* op) override;
- void VisitStmt_(const Provide* op) override;
- void VisitStmt_(const Realize* op) override;
- void VisitStmt_(const Prefetch* op) override;
+ void VisitStmt_(const AttrStmtNode* op) override;
+ void VisitStmt_(const IfThenElseNode* op) override;
+ void VisitStmt_(const LetStmtNode* op) override;
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const AllocateNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitStmt_(const FreeNode* op) override;
+ void VisitStmt_(const AssertStmtNode* op) override;
+ void VisitStmt_(const ProducerConsumerNode* op) override;
+ void VisitStmt_(const ProvideNode* op) override;
+ void VisitStmt_(const RealizeNode* op) override;
+ void VisitStmt_(const PrefetchNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
- void VisitStmt_(const Evaluate* op) override;
+ void VisitStmt_(const EvaluateNode* op) override;
};
/*!
return e;
}
// statement visitor
- Stmt VisitStmt_(const AttrStmt* op) override;
- Stmt VisitStmt_(const IfThenElse* op) override;
- Stmt VisitStmt_(const LetStmt* op) override;
- Stmt VisitStmt_(const For* op) override;
- Stmt VisitStmt_(const Allocate* op) override;
- Stmt VisitStmt_(const Store* op) override;
- Stmt VisitStmt_(const Free* op) override;
- Stmt VisitStmt_(const AssertStmt* op) override;
- Stmt VisitStmt_(const ProducerConsumer* op) override;
- Stmt VisitStmt_(const Provide* op) override;
- Stmt VisitStmt_(const Realize* op) override;
- Stmt VisitStmt_(const Prefetch* op) override;
+ Stmt VisitStmt_(const AttrStmtNode* op) override;
+ Stmt VisitStmt_(const IfThenElseNode* op) override;
+ Stmt VisitStmt_(const LetStmtNode* op) override;
+ Stmt VisitStmt_(const ForNode* op) override;
+ Stmt VisitStmt_(const AllocateNode* op) override;
+ Stmt VisitStmt_(const StoreNode* op) override;
+ Stmt VisitStmt_(const FreeNode* op) override;
+ Stmt VisitStmt_(const AssertStmtNode* op) override;
+ Stmt VisitStmt_(const ProducerConsumerNode* op) override;
+ Stmt VisitStmt_(const ProvideNode* op) override;
+ Stmt VisitStmt_(const RealizeNode* op) override;
+ Stmt VisitStmt_(const PrefetchNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
- Stmt VisitStmt_(const Evaluate* op) override;
+ Stmt VisitStmt_(const EvaluateNode* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
* \param vset The variable set.
* \return Whether e uses vset.
*/
-bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
+bool ExprUseVar(const Expr& e, const std::unordered_set<const VarNode*>& vset);
/*!
* \brief Convert a IR node to be SSA form.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
- const std::unordered_map<const Variable*, Expr>& value_map);
+ const std::unordered_map<const VarNode*, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \return The converted expression.
*/
Expr Substitute(Expr expr,
- const std::unordered_map<const Variable*, Expr>& value_map);
+ const std::unordered_map<const VarNode*, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
virtual void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
* \brief Gather the bound from output tensor.
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
const Stage& stage,
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(
const Stage& stage,
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
void PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(
const Operation& self,
namespace tvm {
namespace relay {
-using Any = tvm::ir::Any;
+using Any = tvm::ir::AnyNode;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) {
- return Variable::make(t, s);
+ return VarNode::make(t, s);
});
TVM_REGISTER_GLOBAL("make.abs")
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body) {
- return For::make(loop_var,
+ return ForNode::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
.set_body([](TVMArgs args, TVMRetValue *ret) {
DataType t = args[0];
if (args.size() == 3) {
- *ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
+ *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
} else {
- *ret = Load::make(t, args[1], args[2], args[3]);
+ *ret = LoadNode::make(t, args[1], args[2], args[3]);
}
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr value = args[1];
if (args.size() == 3) {
- *ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+ *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else {
- *ret = Store::make(args[0], value, args[2], args[3]);
+ *ret = StoreNode::make(args[0], value, args[2], args[3]);
}
});
TVM_REGISTER_GLOBAL("make.Realize")
-.set_body_typed(Realize::make);
+.set_body_typed(RealizeNode::make);
TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed([](
Array<Expr> args, int call_type,
FunctionRef func, int value_index
) {
- return Call::make(type,
+ return CallNode::make(type,
name,
args,
- static_cast<Call::CallType>(call_type),
+ static_cast<CallNode::CallType>(call_type),
func,
value_index);
});
.set_body_typed(CommReducerNode::make);
// make from two arguments
-#define REGISTER_MAKE(Node) \
- TVM_REGISTER_GLOBAL("make."#Node) \
- .set_body_typed(Node::make); \
+#define REGISTER_MAKE(NodeName)                                  \
+  TVM_REGISTER_GLOBAL("make."#NodeName)                          \
+  .set_body_typed(NodeName ## Node::make);
+
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
.set_body_typed([](
VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){
- return Allocate::make(buffer_var, type, extents, condition, body);
+ return AllocateNode::make(buffer_var, type, extents, condition, body);
});
// operator overloading, smarter than make
});
TVM_REGISTER_GLOBAL("_str")
-.set_body_typed(ir::StringImm::make);
+.set_body_typed(ir::StringImmNode::make);
TVM_REGISTER_GLOBAL("_Array")
auto* n = static_cast<const StrMapNode*>(ptr);
auto rkvs = make_object<ArrayNode>();
for (const auto& kv : n->data) {
- rkvs->data.push_back(ir::StringImm::make(kv.first));
+ rkvs->data.push_back(ir::StringImmNode::make(kv.first));
rkvs->data.push_back(kv.second);
}
*ret = Array<ObjectRef>(rkvs);
}
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
- if (const auto* ptr = expr.as<ir::IntImm>()) {
+ if (const auto* ptr = expr.as<ir::IntImmNode>()) {
return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
}
bool Analyzer::CanProve(const Expr& expr) {
- if (const auto* ptr = expr.as<ir::UIntImm>()) {
+ if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
- if (const auto* ptr = res.as<ir::UIntImm>()) {
+ if (const auto* ptr = res.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
- if (const auto* ptr = res.as<ir::UIntImm>()) {
+ if (const auto* ptr = res.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
return false;
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Expr target, Expr expr,
- const std::unordered_map<const Variable*, IntSet>& hint_map,
- const std::unordered_map<const Variable*, IntSet>& relax_map)
+ const std::unordered_map<const VarNode*, IntSet>& hint_map,
+ const std::unordered_map<const VarNode*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
void Deduce();
}
}
- void VisitExpr_(const LT* op) final {
+ void VisitExpr_(const LTNode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
- void VisitExpr_(const LE* op) final {
+ void VisitExpr_(const LENode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
- void VisitExpr_(const GT* op) final {
+ void VisitExpr_(const GTNode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
- void VisitExpr_(const GE* op) final {
+ void VisitExpr_(const GENode* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}
- void VisitExpr_(const Add* op) final {
+ void VisitExpr_(const AddNode* op) final {
bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a;
this->VisitExpr(left ? op->a : op->b);
}
- void VisitExpr_(const Sub* op) final {
+ void VisitExpr_(const SubNode* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result_ += op->b;
this->VisitExpr(left ? op->a : op->b);
}
- void VisitExpr_(const Mul* op) final {
+ void VisitExpr_(const MulNode* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;
CompareOp ReverseOp(CompareOp comp_op);
Expr target_;
Expr expr_;
- const std::unordered_map<const Variable*, IntSet>& hint_map_;
- const std::unordered_map<const Variable*, IntSet>& relax_map_;
+ const std::unordered_map<const VarNode*, IntSet>& hint_map_;
+ const std::unordered_map<const VarNode*, IntSet>& relax_map_;
ExprIntSetMap expr_map_;
std::vector<const Object*> path_;
size_t iter_{0};
void BoundDeducer::Transform() {
// We will ensure to set expr_ such that it contains target_
- if (const LT* op = expr_.as<LT>()) {
+ if (const LTNode* op = expr_.as<LTNode>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
comp_op = kGreater;
expr_ = op->a;
result_ = op->b - 1;
}
- } else if (const LE* op = expr_.as<LE>()) {
+ } else if (const LENode* op = expr_.as<LENode>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
comp_op = kGreater;
expr_ = op->a;
result_ = op->b;
}
- } else if (const GT* op = expr_.as<GT>()) {
+ } else if (const GTNode* op = expr_.as<GTNode>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
comp_op = kLess;
expr_ = op->a;
result_ = op->b + 1;
}
- } else if (const GE* op = expr_.as<GE>()) {
+ } else if (const GENode* op = expr_.as<GENode>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
comp_op = kLess;
expr_ = op->a;
result_ = op->b;
}
- } else if (const EQ* op = expr_.as<EQ>()) {
+ } else if (const EQNode* op = expr_.as<EQNode>()) {
comp_op = kEqual;
if (GetPath(target_, op->a).empty()) {
// if the b == a -> a == b
}
IntSet DeduceBound(Expr v, Expr e,
- const std::unordered_map<const Variable*, IntSet>& hint_map,
- const std::unordered_map<const Variable*, IntSet>& relax_map) {
+ const std::unordered_map<const VarNode*, IntSet>& hint_map,
+ const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success_) return IntSet::nothing();
IntSet DeduceBound(Expr v, Expr e,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
- std::unordered_map<const Variable*, IntSet> hmap;
+ std::unordered_map<const VarNode*, IntSet> hmap;
for (auto kv : hint_map) {
hmap[kv.first.get()] = kv.second;
}
- std::unordered_map<const Variable*, IntSet> rmap;
+ std::unordered_map<const VarNode*, IntSet> rmap;
for (auto kv : relax_map) {
rmap[kv.first.get()] = kv.second;
}
}
using Rewriter::VisitExpr_;
- Expr VisitExpr_(const Add* op) final;
- Expr VisitExpr_(const Sub* op) final;
- Expr VisitExpr_(const Mul* op) final;
- Expr VisitExpr_(const Div* op) final;
- Expr VisitExpr_(const Mod* op) final;
- Expr VisitExpr_(const FloorDiv* op) final;
- Expr VisitExpr_(const FloorMod* op) final;
- Expr VisitExpr_(const Reduce* op) final;
+ Expr VisitExpr_(const AddNode* op) final;
+ Expr VisitExpr_(const SubNode* op) final;
+ Expr VisitExpr_(const MulNode* op) final;
+ Expr VisitExpr_(const DivNode* op) final;
+ Expr VisitExpr_(const ModNode* op) final;
+ Expr VisitExpr_(const FloorDivNode* op) final;
+ Expr VisitExpr_(const FloorModNode* op) final;
+ Expr VisitExpr_(const ReduceNode* op) final;
private:
/*!
}
ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
n->dtype = expr.dtype();
- if (const auto* op = expr.as<IntImm>()) {
+ if (const auto* op = expr.as<IntImmNode>()) {
n->base = op->value;
return SumExpr(n);
} else {
}
}
// Simplify the combiner used in reduce.
- Expr SimplifyReduceCombiner(const Reduce* op);
+ Expr SimplifyReduceCombiner(const ReduceNode* op);
};
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Add* op) {
+VisitExpr_(const AddNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<Add>(a, b);
+ Expr const_res = TryConstFold<AddNode>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
- if (const auto* op = b.as<IntImm>()) {
+ if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
}
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Sub* op) {
+VisitExpr_(const SubNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<Sub>(a, b);
+ Expr const_res = TryConstFold<SubNode>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
SumExpr ret = ToSumExpr(std::move(a));
- if (const auto* op = b.as<IntImm>()) {
+ if (const auto* op = b.as<IntImmNode>()) {
ret.CopyOnWrite()->AddToSelf(-op->value);
} else if (const auto* op = b.as<SumExprNode>()) {
ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Mul* op) {
+VisitExpr_(const MulNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<Mul>(a, b);
+ Expr const_res = TryConstFold<MulNode>(a, b);
if (const_res.defined()) return const_res;
// x * c
- if (a.as<IntImm>()) {
+ if (a.as<IntImmNode>()) {
std::swap(a, b);
}
- if (const auto* bconst = b.as<IntImm>()) {
+ if (const auto* bconst = b.as<IntImmNode>()) {
if (a.as<SumExprNode>()) {
SumExpr ret = Downcast<SumExpr>(std::move(a));
ret.CopyOnWrite()->MulToSelf(bconst->value);
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
- return Mul::make(a, b);
+ return MulNode::make(a, b);
}
}
}
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Div* op) {
+VisitExpr_(const DivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<Div>(a, b);
+ Expr const_res = TryConstFold<DivNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
- if (const auto* pconst = temp.as<IntImm>()) {
+ if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
- return Div::make(a, b);
+ return DivNode::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const FloorDiv* op) {
+VisitExpr_(const FloorDivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<FloorDiv>(a, b);
+ Expr const_res = TryConstFold<FloorDivNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
// continue simplification.
lhs.CopyOnWrite()->DivideBy(cval);
Expr temp = Normalize(extra);
- if (const auto* pconst = temp.as<IntImm>()) {
+ if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
- return FloorDiv::make(a, b);
+ return FloorDivNode::make(a, b);
}
}
}
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Mod* op) {
+VisitExpr_(const ModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<Mod>(a, b);
+ Expr const_res = TryConstFold<ModNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
Expr temp = Normalize(extra);
- if (temp.as<IntImm>()) {
+ if (temp.as<IntImmNode>()) {
return truncmod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
- return Mod::make(a, b);
+ return ModNode::make(a, b);
}
}
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const FloorMod* op) {
+VisitExpr_(const FloorModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
Expr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<FloorMod>(a, b);
+ Expr const_res = TryConstFold<FloorModNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra);
Expr temp = Normalize(extra);
- if (temp.as<IntImm>()) {
+ if (temp.as<IntImmNode>()) {
return floormod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (op->a.same_as(a) && op->b.same_as(b)) {
return GetRef<Expr>(op);
} else {
- return FloorMod::make(a, b);
+ return FloorModNode::make(a, b);
}
}
// Simplify reduce expression.
Expr CanonicalSimplifier::Impl::
-SimplifyReduceCombiner(const Reduce* op) {
+SimplifyReduceCombiner(const ReduceNode* op) {
// First simplify the results
Array<Expr> simplified_result;
for (const auto& res : op->combiner->result) {
CommReducer new_combiner =
CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
- return Reduce::make(
+ return ReduceNode::make(
new_combiner, new_source, op->axis, op->condition, new_value_index);
}
Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Reduce* op) {
+VisitExpr_(const ReduceNode* op) {
// Recursively call simplification when necessary.
Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
- op = ret.as<Reduce>();
+ op = ret.as<ReduceNode>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
if (op->axis.empty()) {
// `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
// instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
return this->VisitExpr(
- Select::make(op->condition,
+ SelectNode::make(op->condition,
op->source[op->value_index],
op->combiner->identity_element[op->value_index]));
}
}
template<>
-inline Expr Compute<ir::Add>(Expr a, Expr b) {
+inline Expr Compute<ir::AddNode>(Expr a, Expr b) {
return a + b;
}
template<>
-inline Expr Compute<ir::Sub>(Expr a, Expr b) {
+inline Expr Compute<ir::SubNode>(Expr a, Expr b) {
return a - b;
}
template<>
-inline Expr Compute<ir::Mul>(Expr a, Expr b) {
+inline Expr Compute<ir::MulNode>(Expr a, Expr b) {
return a * b;
}
template<>
-inline Expr Compute<ir::Div>(Expr a, Expr b) {
+inline Expr Compute<ir::DivNode>(Expr a, Expr b) {
return truncdiv(a, b);
}
template<>
-inline Expr Compute<ir::Mod>(Expr a, Expr b) {
+inline Expr Compute<ir::ModNode>(Expr a, Expr b) {
return truncmod(a, b);
}
template<>
-inline Expr Compute<ir::Max>(Expr a, Expr b) {
+inline Expr Compute<ir::MaxNode>(Expr a, Expr b) {
return max(a, b);
}
template<>
-inline Expr Compute<ir::Min>(Expr a, Expr b) {
+inline Expr Compute<ir::MinNode>(Expr a, Expr b) {
return min(a, b);
}
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
- using ir::IntImm; \
- using ir::UIntImm; \
- using ir::FloatImm; \
- const IntImm* pa = a.as<IntImm>(); \
- const IntImm* pb = b.as<IntImm>(); \
- const FloatImm* fa = a.as<FloatImm>(); \
- const FloatImm* fb = b.as<FloatImm>(); \
+ using ir::IntImmNode; \
+ using ir::UIntImmNode; \
+ using ir::FloatImmNode; \
+ const IntImmNode* pa = a.as<IntImmNode>(); \
+ const IntImmNode* pb = b.as<IntImmNode>(); \
+ const FloatImmNode* fa = a.as<FloatImmNode>(); \
+ const FloatImmNode* fb = b.as<FloatImmNode>(); \
BODY;
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
- using ir::IntImm; \
- using ir::UIntImm; \
- const IntImm* pa = a.as<IntImm>(); \
- const IntImm* pb = b.as<IntImm>(); \
+ using ir::IntImmNode; \
+ using ir::UIntImmNode; \
+ const IntImmNode* pa = a.as<IntImmNode>(); \
+ const IntImmNode* pb = b.as<IntImmNode>(); \
const DataType& ta = a.dtype(); \
const DataType& tb = b.dtype(); \
if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
// specialization of constant folders.
template<>
-inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::AddNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
+ if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
- if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
+ if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
}
template<>
-inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::SubNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
+ if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a;
- if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
+ if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return a;
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::MulNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
+ if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value);
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
- if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
+ if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value);
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
}
template<>
-inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::DivNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
// due to division and mod having different truncation modes
// NOTE: this assumes trunc div.
CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImm::make(rtype, pa->value / pb->value);
+ return IntImmNode::make(rtype, pa->value / pb->value);
}
if (pa) {
if (pa->value == 0) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
- return FloatImm::make(rtype, fa->value / fb->value);
+ return FloatImmNode::make(rtype, fa->value / fb->value);
}
if (fa && fa->value == 0) return a;
if (fb) {
}
template<>
-inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::ModNode>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
- return IntImm::make(rtype, pa->value % pb->value);
+ return IntImmNode::make(rtype, pa->value % pb->value);
}
if (pa) {
if (pa->value == 0) return a;
}
template<>
-inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::FloorDivNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
+ return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
- return FloatImm::make(rtype, std::floor(fa->value / fb->value));
+ return FloatImmNode::make(rtype, std::floor(fa->value / fb->value));
}
if (fa && fa->value == 0) return a;
if (fb) {
}
template<>
-inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::FloorModNode>(Expr a, Expr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
- return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
+ return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
}
template<>
-inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::MinNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
- if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
+ if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value));
+ if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return Expr();
}
template<>
-inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::MaxNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
- if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
+ if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value));
+ if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return Expr();
}
template<>
-inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::GTNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value > fb->value);
+ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
+ if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::GENode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value >= fb->value);
+ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
+ if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::LTNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value < fb->value);
+ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
+ if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::LENode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value <= fb->value);
+ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
+ if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::EQNode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value == pb->value);
- if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value == fb->value);
+ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
+ if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::NE>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::NENode>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value != pb->value);
- if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value != fb->value);
+ if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
+ if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
});
return Expr();
}
template<>
-inline Expr TryConstFold<ir::And>(Expr a, Expr b) {
- using ir::UIntImm;
- const UIntImm* pa = a.as<UIntImm>();
- const UIntImm* pb = b.as<UIntImm>();
+inline Expr TryConstFold<ir::AndNode>(Expr a, Expr b) {
+ using ir::UIntImmNode;
+ const UIntImmNode* pa = a.as<UIntImmNode>();
+ const UIntImmNode* pb = b.as<UIntImmNode>();
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
}
template<>
-inline Expr TryConstFold<ir::Or>(Expr a, Expr b) {
- using ir::UIntImm;
- const UIntImm* pa = a.as<UIntImm>();
- const UIntImm* pb = b.as<UIntImm>();
+inline Expr TryConstFold<ir::OrNode>(Expr a, Expr b) {
+ using ir::UIntImmNode;
+ const UIntImmNode* pa = a.as<UIntImmNode>();
+ const UIntImmNode* pb = b.as<UIntImmNode>();
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
}
template<>
-inline Expr TryConstFold<ir::Not>(Expr a) {
- using ir::UIntImm;
- const UIntImm* pa = a.as<UIntImm>();
+inline Expr TryConstFold<ir::NotNode>(Expr a) {
+ using ir::UIntImmNode;
+ const UIntImmNode* pa = a.as<UIntImmNode>();
if (pa) {
- return UIntImm::make(DataType::UInt(1), !(pa->value));
+ return UIntImmNode::make(DataType::UInt(1), !(pa->value));
}
return Expr();
}
return res;
}
- Entry VisitExpr_(const Cast* op) final {
+ Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
return Intersect(a, b);
}
- Entry VisitExpr_(const IntImm* op) final {
+ Entry VisitExpr_(const IntImmNode* op) final {
return MakeBound(op->value, op->value);
}
- Entry VisitExpr_(const UIntImm* op) final {
+ Entry VisitExpr_(const UIntImmNode* op) final {
if (op->value <= static_cast<uint64_t>(kPosInf)) {
return MakeBound(op->value, op->value);
} else {
}
}
- Entry VisitExpr_(const Add* op) final {
+ Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
return ret;
}
- Entry VisitExpr_(const Sub* op) final {
+ Entry VisitExpr_(const SubNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
return ret;
}
- Entry VisitExpr_(const Mul* op) final {
+ Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return BinaryOpBoundry(a, b, InfAwareMul);
}
- Entry VisitExpr_(const Div* op) final {
+ Entry VisitExpr_(const DivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "divide by zero";
return BinaryOpBoundry(a, b, InfAwareDiv);
}
- Entry VisitExpr_(const Mod* op) final {
+ Entry VisitExpr_(const ModNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
}
}
- Entry VisitExpr_(const FloorDiv* op) final {
+ Entry VisitExpr_(const FloorDivNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
CHECK(!b.is_const(0)) << "floordiv by zero";
return BinaryOpBoundry(a, b, InfAwareFloorDiv);
}
- Entry VisitExpr_(const FloorMod* op) final {
+ Entry VisitExpr_(const FloorModNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
if (b.min_value > 0) {
}
}
- Entry VisitExpr_(const Min* op) final {
+ Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
return ret;
}
- Entry VisitExpr_(const Max* op) final {
+ Entry VisitExpr_(const MaxNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
return ret;
}
- Entry VisitExpr_(const Select* op) final {
+ Entry VisitExpr_(const SelectNode* op) final {
Entry a = VisitExpr(op->true_value);
Entry b = VisitExpr(op->false_value);
return Union(a, b);
}
- Entry VisitExpr_(const Call* op) final {
+ Entry VisitExpr_(const CallNode* op) final {
// only specially handle >> and &, which can be
// used for index calculation.
- if (op->is_intrinsic(Call::shift_right)) {
+ if (op->is_intrinsic(CallNode::shift_right)) {
return VisitRightShift(op);
- } else if (op->is_intrinsic(Call::bitwise_and)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_and)) {
return VisitBitwiseAnd(op);
} else {
return Everything(op->dtype);
}
}
- Entry VisitExpr_(const Variable* op) final {
+ Entry VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
}
}
- Entry VisitRightShift(const Call* op) {
+ Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
return BinaryOpBoundry(a, b, InfAwareRightShift);
}
- Entry VisitBitwiseAnd(const Call* op) {
+ Entry VisitBitwiseAnd(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
// handle positive index case.
return kNegInf;
}
if (y == kPosInf || y == kNegInf) return y;
- if (WillOverflow<Add>(x, y, kNegInf, kPosInf)) {
+ if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) {
if (x > 0) return kPosInf;
return kNegInf;
}
* \return the result.
*/
static int64_t InfAwareMul(int64_t x, int64_t y) {
- if (!WillOverflow<Mul>(x, y, kNegInf, kPosInf)) return x * y;
+ if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y;
if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf;
return kNegInf;
}
return true;
}
- LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
return ret;
}
- LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
return ret;
}
- LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
ret.coeff = MulCombine(a.base, b.coeff);
return ret;
}
- LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final {
LinearEqEntry ret;
if (op == var_.get()) {
ret.coeff = make_const(op->dtype, 1);
base = std::move(ret.base);
}
- std::unordered_set<const Variable*> vset;
+ std::unordered_set<const VarNode*> vset;
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
// Detect clip condition as min max value
bool DetectClipBound(
const Expr& cond,
- std::unordered_map<const Variable*, IntervalEntry>* bmap) {
+ std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
- if (const Variable* v = n.as<Variable>()) {
+ if (const VarNode* v = n.as<VarNode>()) {
if (bmap->count(v)) {
if (flag == 0) {
var = Downcast<Var>(n);
if (flag != 1) return false;
// canonical form: exp >= 0
Expr canonical;
- if (const LT* op = cond.as<LT>()) {
+ if (const LTNode* op = cond.as<LTNode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a - make_const(op->a.dtype(), 1);
- } else if (const LE* op = cond.as<LE>()) {
+ } else if (const LENode* op = cond.as<LENode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a;
- } else if (const GT* op = cond.as<GT>()) {
+ } else if (const GTNode* op = cond.as<GTNode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b - make_const(op->a.dtype(), 1);
- } else if (const GE* op = cond.as<GE>()) {
+ } else if (const GENode* op = cond.as<GENode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->a - op->b;
} else {
if (is_const_int(ret.coeff, 1)) {
// var + shift >=0 -> var >= -shift
if (p.min_value.defined()) {
- p.min_value = ir::Max::make(p.min_value, -ret.base);
+ p.min_value = ir::MaxNode::make(p.min_value, -ret.base);
} else {
p.min_value = -ret.base;
}
if (is_const_int(ret.coeff, -1)) {
// -var + shift >=0 -> var <= shift
if (p.max_value.defined()) {
- p.max_value = ir::Min::make(p.max_value, ret.base);
+ p.max_value = ir::MinNode::make(p.max_value, ret.base);
} else {
p.max_value = ret.base;
}
// e must be connected by and.
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
std::vector<Expr> splits;
- SplitCommExpr<ir::And>(e, &splits);
- std::unordered_map<const Variable*, IntervalEntry> rmap;
+ SplitCommExpr<ir::AndNode>(e, &splits);
+ std::unordered_map<const VarNode*, IntervalEntry> rmap;
for (Var v : vars) {
rmap[v.get()] = IntervalEntry();
}
return ret;
}
- void VisitStmt_(const For *op) final {
- const Variable* var = op->loop_var.get();
+ void VisitStmt_(const ForNode *op) final {
+ const VarNode* var = op->loop_var.get();
dom_map_[var] = IntSet::range(
Range::make_by_min_extent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
- void VisitStmt_(const LetStmt* op) final {
+ void VisitStmt_(const LetStmtNode* op) final {
dom_map_[op->var.get()] =
arith::EvalSet(op->value, dom_map_);
StmtExprVisitor::VisitStmt_(op);
}
/* TODO: Thread extent unit test not generated. */
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
- const Variable* var = thread_axis->var.get();
+ const VarNode* var = thread_axis->var.get();
dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
StmtExprVisitor::VisitStmt_(op);
dom_map_.erase(var);
}
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
if (consider_calls_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
StmtExprVisitor::VisitExpr_(op);
}
- void VisitStmt_(const Provide* op) final {
+ void VisitStmt_(const ProvideNode* op) final {
if (consider_provides_ && tensor_->op.same_as(op->func)
&& tensor_->value_index == op->value_index) {
Touch(op->args);
const Tensor &tensor_;
bool consider_calls_, consider_provides_;
std::vector<std::vector<IntSet> > bounds_;
- std::unordered_map<const Variable*, IntSet> dom_map_;
+ std::unordered_map<const VarNode*, IntSet> dom_map_;
};
Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) {
}
template<>
-inline bool WillOverflow<ir::Add>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+inline bool WillOverflow<ir::AddNode>(int64_t x,
+ int64_t y,
+ int64_t min_value,
+ int64_t max_value) {
if ((y > 0) && (x > max_value - y)) return true;
if ((y < 0) && (x < min_value - y)) return true;
return false;
}
template<>
-inline bool WillOverflow<ir::Sub>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+inline bool WillOverflow<ir::SubNode>(int64_t x,
+ int64_t y,
+ int64_t min_value,
+ int64_t max_value) {
if ((y > 0) && (x < min_value + y)) return true;
if ((y < 0) && (x > max_value + y)) return true;
return false;
}
template<>
-inline bool WillOverflow<ir::Mul>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+inline bool WillOverflow<ir::MulNode>(int64_t x,
+ int64_t y,
+ int64_t min_value,
+ int64_t max_value) {
if (y == 0) return false;
if (y > 0) {
if (x < min_value / y) return true;
}
template<>
-inline bool WillOverflow<ir::Mod>(int64_t x,
- int64_t y,
- int64_t min_value,
- int64_t max_value) {
+inline bool WillOverflow<ir::ModNode>(int64_t x,
+ int64_t y,
+ int64_t min_value,
+ int64_t max_value) {
return y == 0;
}
static const bool value = true; \
};
-TVM_DECLARE_LOGICAL_OP(And);
-TVM_DECLARE_LOGICAL_OP(Or);
-TVM_DECLARE_LOGICAL_OP(EQ);
-TVM_DECLARE_LOGICAL_OP(NE);
-TVM_DECLARE_LOGICAL_OP(GE);
-TVM_DECLARE_LOGICAL_OP(GT);
-TVM_DECLARE_LOGICAL_OP(LE);
-TVM_DECLARE_LOGICAL_OP(LT);
-TVM_DECLARE_LOGICAL_OP(Not);
+TVM_DECLARE_LOGICAL_OP(AndNode);
+TVM_DECLARE_LOGICAL_OP(OrNode);
+TVM_DECLARE_LOGICAL_OP(EQNode);
+TVM_DECLARE_LOGICAL_OP(NENode);
+TVM_DECLARE_LOGICAL_OP(GENode);
+TVM_DECLARE_LOGICAL_OP(GTNode);
+TVM_DECLARE_LOGICAL_OP(LENode);
+TVM_DECLARE_LOGICAL_OP(LTNode);
+TVM_DECLARE_LOGICAL_OP(NotNode);
/*!
* \brief Combine two interval set under arithmetic operations.
}
template<>
-inline IntervalSet Combine<ir::Add>(Analyzer* analyer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::AddNode>(Analyzer* analyer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value + b->min_value);
}
}
template<>
-inline IntervalSet Combine<ir::Sub>(Analyzer* analyer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::SubNode>(Analyzer* analyer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value - b->min_value);
}
template<>
-inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value * b->min_value);
}
Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
- using ir::Select;
+ using ir::SelectNode;
Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
Expr e1 = a->min_value * b->min_value;
Expr e2 = a->max_value * b->min_value;
- return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
+ return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Mul";
}
template<>
-inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(a->min_value / b->min_value);
}
Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
- using ir::Select;
+ using ir::SelectNode;
Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
Expr e1 = a->min_value / b->min_value;
Expr e2 = a->max_value / b->min_value;
- return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
+ return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Div";
}
template<>
-inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
}
template<>
-inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
}
Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
- using ir::Select;
+ using ir::SelectNode;
Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
Expr e1 = floordiv(a->min_value, b->min_value);
Expr e2 = floordiv(a->max_value, b->min_value);
- return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
+ return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
}
}
DLOG(WARNING) << "Return Everything in CombineInterval Div";
}
template<>
-inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::FloorModNode>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
}
}
template<>
-inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::MaxNode>(Analyzer* analzyer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
}
}
template<>
-inline IntervalSet Combine<ir::Min>(Analyzer* analzyer,
- IntervalSet a,
- IntervalSet b) {
+inline IntervalSet Combine<ir::MinNode>(Analyzer* analzyer,
+ IntervalSet a,
+ IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
}
return IntervalSet(min_set->min_value, max_set->max_value);
}
- IntervalSet VisitExpr_(const IntImm* op) final {
+ IntervalSet VisitExpr_(const IntImmNode* op) final {
return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
- IntervalSet VisitExpr_(const UIntImm* op) final {
+ IntervalSet VisitExpr_(const UIntImmNode* op) final {
return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
- IntervalSet VisitExpr_(const Variable* op) final {
+ IntervalSet VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
}
}
- IntervalSet VisitExpr_(const Add* op) final {
+ IntervalSet VisitExpr_(const AddNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Sub* op) final {
+ IntervalSet VisitExpr_(const SubNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Mul* op) final {
+ IntervalSet VisitExpr_(const MulNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Div* op) final {
+ IntervalSet VisitExpr_(const DivNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Mod* op) final {
+ IntervalSet VisitExpr_(const ModNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const FloorDiv* op) final {
+ IntervalSet VisitExpr_(const FloorDivNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const FloorMod* op) final {
+ IntervalSet VisitExpr_(const FloorModNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Min* op) final {
+ IntervalSet VisitExpr_(const MinNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Max* op) final {
+ IntervalSet VisitExpr_(const MaxNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const EQ* op) final {
+ IntervalSet VisitExpr_(const EQNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const NE* op) final {
+ IntervalSet VisitExpr_(const NENode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const LT* op) final {
+ IntervalSet VisitExpr_(const LTNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const LE* op) final {
+ IntervalSet VisitExpr_(const LENode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const GT* op) final {
+ IntervalSet VisitExpr_(const GTNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const GE* op) final {
+ IntervalSet VisitExpr_(const GENode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const And* op) final {
+ IntervalSet VisitExpr_(const AndNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Or* op) final {
+ IntervalSet VisitExpr_(const OrNode* op) final {
return VisitBinaryExpr_(op);
}
- IntervalSet VisitExpr_(const Ramp* op) final {
+ IntervalSet VisitExpr_(const RampNode* op) final {
CHECK(eval_vec_);
IntervalSet base = Eval(op->base);
PVar<Integer> stride;
DataType t = op->base.dtype();
int64_t vstride = stride.Eval()->value;
if (vstride> 0) {
- return Combine<Add>(
+ return Combine<AddNode>(
analyzer_,
base,
IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
} else {
- return Combine<Add>(
+ return Combine<AddNode>(
analyzer_,
base,
IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
return IntervalSet::Everything();
}
- IntervalSet VisitExpr_(const Broadcast* op) final {
+ IntervalSet VisitExpr_(const BroadcastNode* op) final {
CHECK(eval_vec_);
return VisitExpr(op->value);
}
- IntervalSet VisitExpr_(const Select* op) final {
+ IntervalSet VisitExpr_(const SelectNode* op) final {
IntervalSet true_set = this->Eval(op->true_value);
IntervalSet false_set = this->Eval(op->false_value);
return Union(analyzer_, false_set, true_set);
}
Map<Var, IntSet> ConvertDomMap(
- const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Map<Var, IntSet> dmap;
for (auto kv : dom_map) {
dmap.Set(GetRef<Var>(kv.first), kv.second);
}
IntSet EvalSet(Expr e,
- const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ const std::unordered_map<const VarNode*, IntSet>& dom_map) {
return EvalSet(e, ConvertDomMap(dom_map));
}
}
IntSet EvalSet(Range r,
- const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ const std::unordered_map<const VarNode*, IntSet>& dom_map) {
return EvalSet(r, ConvertDomMap(dom_map));
}
IntSet EvalSet(IntSet s,
- const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
IntervalSetEvaluator m(&ana, dmap);
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
- const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
SubExprIntervalSetEvaluator m(&ana, dmap);
using namespace ir;
Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const For* op) {
+VisitStmt_(const ForNode* op) {
analyzer_->Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op);
}
Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const LetStmt* op) {
+VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const IfThenElse* op) {
+VisitStmt_(const IfThenElseNode* op) {
Expr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case;
{
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
- analyzer_->rewrite_simplify(Not::make(condition)));
+ analyzer_->rewrite_simplify(NotNode::make(condition)));
else_case = this->VisitStmt(op->else_case);
}
if (is_one(condition)) return then_case;
if (else_case.defined()) {
return else_case;
}
- return Evaluate::make(0);
+ return EvaluateNode::make(0);
}
if (condition.same_as(op->condition) &&
}
Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const AttrStmt* op) {
+VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
}
Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const AssertStmt* op) {
+VisitStmt_(const AssertStmtNode* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
}
Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Call* op) {
+VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
Expr cond = this->VisitExpr(op->args[0]);
}
{
With<ConstraintContext> constraint(analyzer_,
- analyzer_->rewrite_simplify(Not::make(cond)));
+ analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = this->VisitExpr(op->args[2]);
}
if (is_zero(cond)) {
false_value.same_as(op->args[2])) {
return GetRef<Expr>(op);
} else {
- return Call::make(op->dtype, op->name,
+ return CallNode::make(op->dtype, op->name,
{cond, true_value, false_value},
op->call_type);
}
}
Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Let* op) {
+VisitExpr_(const LetNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return Let::make(op->var, value, body);
+ return LetNode::make(op->var, value, body);
}
}
Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Select* op) {
+VisitExpr_(const SelectNode* op) {
Expr cond = this->VisitExpr(op->condition);
Expr true_value, false_value;
{
}
{
With<ConstraintContext> constraint(analyzer_,
- analyzer_->rewrite_simplify(Not::make(cond)));
+ analyzer_->rewrite_simplify(NotNode::make(cond)));
false_value = VisitExpr(op->false_value);
}
if (is_zero(cond)) {
false_value.same_as(op->false_value)) {
return GetRef<Expr>(op);
} else {
- return Select::make(cond, true_value, false_value);
+ return SelectNode::make(cond, true_value, false_value);
}
}
Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Reduce* op) {
+VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_->Bind(iv->var, iv->dom);
using StmtExprMutator::VisitExpr_;
// override functions that need to populate the context information.
- Stmt VisitStmt_(const ir::For* op) override;
- Stmt VisitStmt_(const ir::LetStmt* op) override;
- Stmt VisitStmt_(const ir::IfThenElse* op) override;
- Stmt VisitStmt_(const ir::AttrStmt* op) override;
- Stmt VisitStmt_(const ir::AssertStmt* op) override;
- Expr VisitExpr_(const ir::Let* op) override;
- Expr VisitExpr_(const ir::Select* op) override;
- Expr VisitExpr_(const ir::Call* op) override;
- Expr VisitExpr_(const ir::Reduce* op) override;
+ Stmt VisitStmt_(const ir::ForNode* op) override;
+ Stmt VisitStmt_(const ir::LetStmtNode* op) override;
+ Stmt VisitStmt_(const ir::IfThenElseNode* op) override;
+ Stmt VisitStmt_(const ir::AttrStmtNode* op) override;
+ Stmt VisitStmt_(const ir::AssertStmtNode* op) override;
+ Expr VisitExpr_(const ir::LetNode* op) override;
+ Expr VisitExpr_(const ir::SelectNode* op) override;
+ Expr VisitExpr_(const ir::CallNode* op) override;
+ Expr VisitExpr_(const ir::ReduceNode* op) override;
protected:
/*! \brief internal analyzer field. */
return analyzer_.Simplify(expr);
}
- void VisitStmt_(const For* op) {
+ void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return StmtExprVisitor::VisitStmt_(op);
}
- void VisitStmt_(const AttrStmt* op) {
+ void VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
}
}
- void VisitExpr_(const Reduce* op) {
+ void VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
return Everything();
}
- Entry VisitExpr_(const Cast* op) final {
+ Entry VisitExpr_(const CastNode* op) final {
return VisitExpr(op->value);
}
- Entry VisitExpr_(const IntImm* op) final {
+ Entry VisitExpr_(const IntImmNode* op) final {
return Entry(0, op->value);
}
- Entry VisitExpr_(const UIntImm* op) final {
+ Entry VisitExpr_(const UIntImmNode* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) {
return Entry(0, static_cast<int>(op->value));
} else {
}
}
- Entry VisitExpr_(const Add* op) final {
+ Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base);
}
- Entry VisitExpr_(const Sub* op) final {
+ Entry VisitExpr_(const SubNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base);
}
- Entry VisitExpr_(const Mul* op) final {
+ Entry VisitExpr_(const MulNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
// Simplification rule, x, y, z are in Z
return Everything();
}
- Entry VisitExpr_(const Div* op) final {
+ Entry VisitExpr_(const DivNode* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, false);
return Everything();
}
- Entry VisitExpr_(const FloorDiv* op) final {
+ Entry VisitExpr_(const FloorDivNode* op) final {
Entry b = VisitExpr(op->b);
if (b.is_const()) {
return DivByConst(op->a, b.base, true);
return Everything();
}
- Entry VisitExpr_(const Min* op) final {
+ Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return Union(a, b);
}
- Entry VisitExpr_(const Max* op) final {
+ Entry VisitExpr_(const MaxNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
return Union(a, b);
}
- Entry VisitExpr_(const Select* op) final {
+ Entry VisitExpr_(const SelectNode* op) final {
Entry a = VisitExpr(op->true_value);
Entry b = VisitExpr(op->false_value);
return Union(a, b);
}
- Entry VisitExpr_(const Call* op) final {
+ Entry VisitExpr_(const CallNode* op) final {
// only special handle >> which can be
// used for index calculation.
- if (op->is_intrinsic(Call::shift_right)) {
+ if (op->is_intrinsic(CallNode::shift_right)) {
return VisitRightShift(op);
} else {
return Everything();
}
}
- Entry VisitExpr_(const Variable* op) final {
+ Entry VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
}
}
- Entry VisitRightShift(const Call* op) {
+ Entry VisitRightShift(const CallNode* op) {
Entry b = VisitExpr(op->args[1]);
// a c x / c -> a x
if (b.is_const()) {
void InitMatch_() const {}
bool Match_(const ObjectRef& node) const {
- if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
+ if (const ir::IntImmNode* ptr = node.as<ir::IntImmNode>()) {
return ptr->value == value_;
} else {
return false;
// raise ambiguity error for operator overload of / and %
-TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
-TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
+TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a));
+TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a));
// arithmetic expressions
-TVM_PATTERN_BINARY_OP(operator+, ir::Add);
-TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
-TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
-TVM_PATTERN_BINARY_OP(min, ir::Min);
-TVM_PATTERN_BINARY_OP(max, ir::Max);
-TVM_PATTERN_BINARY_OP(div, ir::Div);
-TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
-TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
-TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
-TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
+TVM_PATTERN_BINARY_OP(operator+, ir::AddNode);
+TVM_PATTERN_BINARY_OP(operator-, ir::SubNode);
+TVM_PATTERN_BINARY_OP(operator*, ir::MulNode);
+TVM_PATTERN_BINARY_OP(min, ir::MinNode);
+TVM_PATTERN_BINARY_OP(max, ir::MaxNode);
+TVM_PATTERN_BINARY_OP(div, ir::DivNode);
+TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode);
+TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode);
+TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode);
+TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode);
// logical expressions
-TVM_PATTERN_BINARY_OP(operator>, ir::GT);
-TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
-TVM_PATTERN_BINARY_OP(operator<, ir::LT);
-TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
-TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
-TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
-TVM_PATTERN_BINARY_OP(operator&&, ir::And);
-TVM_PATTERN_BINARY_OP(operator||, ir::Or);
+TVM_PATTERN_BINARY_OP(operator>, ir::GTNode);
+TVM_PATTERN_BINARY_OP(operator>=, ir::GENode);
+TVM_PATTERN_BINARY_OP(operator<, ir::LTNode);
+TVM_PATTERN_BINARY_OP(operator<=, ir::LENode);
+TVM_PATTERN_BINARY_OP(operator==, ir::EQNode);
+TVM_PATTERN_BINARY_OP(operator!=, ir::NENode);
+TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode);
+TVM_PATTERN_BINARY_OP(operator||, ir::OrNode);
/*!
* \brief Pattern not expression.
}
bool Match_(const ObjectRef& node) const {
- if (const ir::Not* ptr = node.as<ir::Not>()) {
+ if (const ir::NotNode* ptr = node.as<ir::NotNode>()) {
if (!value_.Match_(ptr->a)) return false;
return true;
} else {
}
Expr Eval() const {
- return ir::Not::make(value_.Eval());
+ return ir::NotNode::make(value_.Eval());
}
private:
}
bool Match_(const ObjectRef& node) const {
- if (const ir::Select* ptr = node.as<ir::Select>()) {
+ if (const ir::SelectNode* ptr = node.as<ir::SelectNode>()) {
if (!condition_.Match_(ptr->condition)) return false;
if (!true_value_.Match_(ptr->true_value)) return false;
if (!false_value_.Match_(ptr->false_value)) return false;
}
Expr Eval() const {
- return ir::Select::make(
+ return ir::SelectNode::make(
condition_.Eval(), true_value_.Eval(), false_value_.Eval());
}
}
bool Match_(const ObjectRef& node) const {
- if (const ir::Cast* ptr = node.as<ir::Cast>()) {
+ if (const ir::CastNode* ptr = node.as<ir::CastNode>()) {
if (!dtype_.Match_(ptr->dtype)) return false;
if (!value_.Match_(ptr->value)) return false;
return true;
}
Expr Eval() const {
- return ir::Cast::make(dtype_.Eval(), value_.Eval());
+ return ir::CastNode::make(dtype_.Eval(), value_.Eval());
}
private:
}
bool Match_(const ObjectRef& node) const {
- if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
+ if (const ir::RampNode* ptr = node.as<ir::RampNode>()) {
if (!base_.Match_(ptr->base)) return false;
if (!stride_.Match_(ptr->stride)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
}
Expr Eval() const {
- return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
+ return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
}
private:
}
bool Match_(const ObjectRef& node) const {
- if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
+ if (const ir::BroadcastNode* ptr = node.as<ir::BroadcastNode>()) {
if (!value_.Match_(ptr->value)) return false;
if (!lanes_.Match_(ptr->lanes)) return false;
return true;
}
Expr Eval() const {
- return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
+ return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
}
private:
};
struct PCallExprMatchFunctor {
- const ir::Call* call_;
+ const ir::CallNode* call_;
bool matched_{true};
- explicit PCallExprMatchFunctor(const ir::Call* call)
+ explicit PCallExprMatchFunctor(const ir::CallNode* call)
: call_(call) {}
template<typename T>
}
bool Match_(const ObjectRef& node) const {
- if (const ir::Call* ptr = node.as<ir::Call>()) {
+ if (const ir::CallNode* ptr = node.as<ir::CallNode>()) {
if (ptr->args.size() != sizeof...(TArgs)) return false;
if (ptr->name != Op::kName) return false;
detail::PCallExprMatchFunctor fmatch(ptr);
};
// arithemetic intrinsics
-#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
- struct OpName { \
- static Expr Eval(Array<Expr> args) { \
- return ir::Call::make(args[0].dtype(), kName, args, \
- ir::Call::PureIntrinsic); \
- } \
- static constexpr const char* kName = IntrinStr; \
- }; \
- template<typename TA, typename TB> \
- inline PCallExpr<OpName, TA, TB> \
- FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
- return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
+#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
+ struct OpName { \
+ static Expr Eval(Array<Expr> args) { \
+ return ir::CallNode::make(args[0].dtype(), kName, args, \
+ ir::CallNode::PureIntrinsic); \
+ } \
+ static constexpr const char* kName = IntrinStr; \
+ }; \
+ template<typename TA, typename TB> \
+ inline PCallExpr<OpName, TA, TB> \
+ FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
+ return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
}
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
// unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
- struct OpName { \
- static Expr Eval(Array<Expr> args) { \
- return ir::Call::make(args[0].dtype(), kName, args, \
- ir::Call::PureIntrinsic); \
- } \
- static constexpr const char* kName = IntrinStr; \
- }; \
- template<typename TA> \
- inline PCallExpr<OpName, TA> \
- FuncName(const Pattern<TA>& a) { \
- return PCallExpr<OpName, TA>(a.derived()); \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
+ struct OpName { \
+ static Expr Eval(Array<Expr> args) { \
+ return ir::CallNode::make(args[0].dtype(), kName, args, \
+ ir::CallNode::PureIntrinsic); \
+ } \
+ static constexpr const char* kName = IntrinStr; \
+ }; \
+ template<typename TA> \
+ inline PCallExpr<OpName, TA> \
+ FuncName(const Pattern<TA>& a) { \
+ return PCallExpr<OpName, TA>(a.derived()); \
}
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
// if_then_else
struct PIfThenElseOp {
static Expr Eval(Array<Expr> args) {
- return ir::Call::make(
+ return ir::CallNode::make(
args[1].dtype(), kName, args,
- ir::Call::PureIntrinsic);
+ ir::CallNode::PureIntrinsic);
}
static constexpr const char* kName = "tvm_if_then_else";
};
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
TryCompare(const Expr& x, int64_t val) {
Expr diff = this->VisitExpr(x);
- if (const auto* ptr = diff.as<IntImm>()) {
+ if (const auto* ptr = diff.as<IntImmNode>()) {
if (ptr->value == val) {
return kEQ;
} else if (ptr->value > val) {
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Add* op) {
+VisitExpr_(const AddNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Add>();
- Expr const_res = TryConstFold<Add>(op->a, op->b);
+ op = ret.as<AddNode>();
+ Expr const_res = TryConstFold<AddNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1, b2, s1, s2;
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Sub* op) {
+VisitExpr_(const SubNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Sub>();
- Expr const_res = TryConstFold<Sub>(op->a, op->b);
+ op = ret.as<SubNode>();
+ Expr const_res = TryConstFold<SubNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1, b2, s1, s2;
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Mul* op) {
+VisitExpr_(const MulNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Mul>();
- Expr const_res = TryConstFold<Mul>(op->a, op->b);
+ op = ret.as<MulNode>();
+ Expr const_res = TryConstFold<MulNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1, b2, s1, s2;
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Div* op) {
+VisitExpr_(const DivNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Div>();
- Expr const_res = TryConstFold<Div>(op->a, op->b);
+ op = ret.as<DivNode>();
+ Expr const_res = TryConstFold<DivNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1;
PVar<int> lanes;
// x / 2.0 = x * 0.5
- if (const FloatImm* ptr = op->b.as<FloatImm>()) {
+ if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
CHECK(op->dtype.is_float());
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
}
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Mod* op) {
+VisitExpr_(const ModNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Mod>();
- Expr const_res = TryConstFold<Mod>(op->a, op->b);
+ op = ret.as<ModNode>();
+ Expr const_res = TryConstFold<ModNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const FloorDiv* op) {
+VisitExpr_(const FloorDivNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<FloorDiv>();
- Expr const_res = TryConstFold<FloorDiv>(op->a, op->b);
+ op = ret.as<FloorDivNode>();
+ Expr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y, z, b1;
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const FloorMod* op) {
+VisitExpr_(const FloorModNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<FloorMod>();
- Expr const_res = TryConstFold<FloorMod>(op->a, op->b);
+ op = ret.as<FloorModNode>();
+ Expr const_res = TryConstFold<FloorModNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Min* op) {
+VisitExpr_(const MinNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Min>();
- Expr const_res = TryConstFold<Min>(op->a, op->b);
+ op = ret.as<MinNode>();
+ Expr const_res = TryConstFold<MinNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Max* op) {
+VisitExpr_(const MaxNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Max>();
- Expr const_res = TryConstFold<Max>(op->a, op->b);
+ op = ret.as<MaxNode>();
+ Expr const_res = TryConstFold<MaxNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const EQ* op) {
+VisitExpr_(const EQNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<EQ>();
- Expr const_res = TryConstFold<EQ>(op->a, op->b);
+ op = ret.as<EQNode>();
+ Expr const_res = TryConstFold<EQNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const NE* op) {
- return this->VisitExpr(Not::make(op->a == op->b));
+VisitExpr_(const NENode* op) {
+ return this->VisitExpr(NotNode::make(op->a == op->b));
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const LE* op) {
- return this->VisitExpr(Not::make(op->b < op->a));
+VisitExpr_(const LENode* op) {
+ return this->VisitExpr(NotNode::make(op->b < op->a));
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const GT* op) {
+VisitExpr_(const GTNode* op) {
return this->VisitExpr(op->b < op->a);
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const GE* op) {
- return this->VisitExpr(Not::make(op->a < op->b));
+VisitExpr_(const GENode* op) {
+ return this->VisitExpr(NotNode::make(op->a < op->b));
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const LT* op) {
+VisitExpr_(const LTNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<LT>();
- Expr const_res = TryConstFold<LT>(op->a, op->b);
+ op = ret.as<LTNode>();
+ Expr const_res = TryConstFold<LTNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Not* op) {
+VisitExpr_(const NotNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Not>();
- Expr const_res = TryConstFold<Not>(op->a);
+ op = ret.as<NotNode>();
+ Expr const_res = TryConstFold<NotNode>(op->a);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
PVar<Expr> x, y;
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const And* op) {
+VisitExpr_(const AndNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<And>();
- Expr const_res = TryConstFold<And>(op->a, op->b);
+ op = ret.as<AndNode>();
+ Expr const_res = TryConstFold<AndNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Or* op) {
+VisitExpr_(const OrNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Or>();
- Expr const_res = TryConstFold<Or>(op->a, op->b);
+ op = ret.as<OrNode>();
+ Expr const_res = TryConstFold<OrNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Select* op) {
+VisitExpr_(const SelectNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Select>();
+ op = ret.as<SelectNode>();
if (op == nullptr) return ret;
// Pattern var to match any expression
PVar<Expr> x, y;
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Call* op) {
+VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Call>();
+ op = ret.as<CallNode>();
if (op == nullptr) return ret;
- if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
+ if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) {
return op->args[0];
- } else if (op->is_intrinsic(Call::shift_right)) {
- if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
+ } else if (op->is_intrinsic(CallNode::shift_right)) {
+ if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
// the operator overload will eagerly constant fold.
return op->args[0] >> op->args[1];
}
- } else if (op->is_intrinsic(Call::bitwise_and)) {
- if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
+ } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+ if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
// the operator overload will eagerly constant fold.
return op->args[0] & op->args[1];
}
}
- if (op->is_intrinsic(Call::likely)) {
+ if (op->is_intrinsic(CallNode::likely)) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (Equal(constraint, op->args[0])) {
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Variable* op) {
+VisitExpr_(const VarNode* op) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
if (it != var_map_.end()) {
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Cast* op) {
+VisitExpr_(const CastNode* op) {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<Cast>();
+ op = ret.as<CastNode>();
return cast(op->dtype, op->value);
}
Expr RewriteSimplifier::Impl::
-VisitExpr_(const Let* op) {
+VisitExpr_(const LetNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return Let::make(op->var, value, body);
+ return LetNode::make(op->var, value, body);
}
}
: IRMutatorWithAnalyzer(parent) {}
void Update(const Var& var, const Expr& info, bool override_info);
- Expr VisitExpr_(const Add* op) override;
- Expr VisitExpr_(const Sub* op) override;
- Expr VisitExpr_(const Mul* op) override;
- Expr VisitExpr_(const Div* op) override;
- Expr VisitExpr_(const Mod* op) override;
- Expr VisitExpr_(const FloorDiv* op) override;
- Expr VisitExpr_(const FloorMod* op) override;
- Expr VisitExpr_(const Min* op) override;
- Expr VisitExpr_(const Max* op) override;
- Expr VisitExpr_(const EQ* op) override;
- Expr VisitExpr_(const NE* op) override;
- Expr VisitExpr_(const LT* op) override;
- Expr VisitExpr_(const LE* op) override;
- Expr VisitExpr_(const GT* op) override;
- Expr VisitExpr_(const GE* op) override;
- Expr VisitExpr_(const And* op) override;
- Expr VisitExpr_(const Or* op) override;
- Expr VisitExpr_(const Not* op) override;
- Expr VisitExpr_(const Select* op) override;
- Expr VisitExpr_(const Call* op) override;
- Expr VisitExpr_(const Variable* op) override;
- Expr VisitExpr_(const Cast* op) override;
- Expr VisitExpr_(const Let* op) override;
+ Expr VisitExpr_(const AddNode* op) override;
+ Expr VisitExpr_(const SubNode* op) override;
+ Expr VisitExpr_(const MulNode* op) override;
+ Expr VisitExpr_(const DivNode* op) override;
+ Expr VisitExpr_(const ModNode* op) override;
+ Expr VisitExpr_(const FloorDivNode* op) override;
+ Expr VisitExpr_(const FloorModNode* op) override;
+ Expr VisitExpr_(const MinNode* op) override;
+ Expr VisitExpr_(const MaxNode* op) override;
+ Expr VisitExpr_(const EQNode* op) override;
+ Expr VisitExpr_(const NENode* op) override;
+ Expr VisitExpr_(const LTNode* op) override;
+ Expr VisitExpr_(const LENode* op) override;
+ Expr VisitExpr_(const GTNode* op) override;
+ Expr VisitExpr_(const GENode* op) override;
+ Expr VisitExpr_(const AndNode* op) override;
+ Expr VisitExpr_(const OrNode* op) override;
+ Expr VisitExpr_(const NotNode* op) override;
+ Expr VisitExpr_(const SelectNode* op) override;
+ Expr VisitExpr_(const CallNode* op) override;
+ Expr VisitExpr_(const VarNode* op) override;
+ Expr VisitExpr_(const CastNode* op) override;
+ Expr VisitExpr_(const LetNode* op) override;
std::function<void()> EnterConstraint(const Expr& constraint);
return operator()(std::move(stmt));
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return Parent::VisitStmt_(op);
}
- Stmt VisitStmt_(const LetStmt* op) {
+ Stmt VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
}
// eliminate useless stores
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = Parent::VisitStmt_(op);
- op = stmt.as<Store>();
- if (const Load* load = op->value.as<Load>()) {
+ op = stmt.as<StoreNode>();
+ if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
Equal(load->index, op->index)) {
- return Evaluate::make(0);
+ return EvaluateNode::make(0);
}
}
return GetRef<Stmt>(op);
namespace autotvm {
// for loop
-void FeatureVisitor::VisitStmt_(const For* op) {
- const auto *extent = op->extent.as<IntImm>();
+void FeatureVisitor::VisitStmt_(const ForNode* op) {
+ const auto *extent = op->extent.as<IntImmNode>();
int64_t loop_extent = -1;
if (extent != nullptr)
loop_extent = extent->value;
}
// parallel axis, virtual thread
-void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
+void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
- const auto *extent = op->value.as<IntImm>();
+ const auto *extent = op->value.as<IntImmNode>();
CHECK(extent);
std::string name = var.get()->name_hint;
}
// memory access
-void FeatureVisitor::VisitExpr_(const Load* op) {
+void FeatureVisitor::VisitExpr_(const LoadNode* op) {
EnterMem_(op->buffer_var, op->index);
StmtExprVisitor::VisitExpr_(op);
ExitMem_();
}
-void FeatureVisitor::VisitStmt_(const Store* op) {
+void FeatureVisitor::VisitStmt_(const StoreNode* op) {
EnterMem_(op->buffer_var, op->index);
StmtExprVisitor::VisitStmt_(op);
ExitMem_();
class FeatureVisitor : public StmtExprVisitor {
public:
// for loop
- void VisitStmt_(const For* op) final;
- void VisitStmt_(const AttrStmt* op) final;
+ void VisitStmt_(const ForNode* op) final;
+ void VisitStmt_(const AttrStmtNode* op) final;
// memory access
- void VisitExpr_(const Load* op) final;
- void VisitStmt_(const Store* op) final;
+ void VisitExpr_(const LoadNode* op) final;
+ void VisitStmt_(const StoreNode* op) final;
using StmtExprVisitor::VisitStmt_;
using StmtExprVisitor::VisitExpr_;
this->VisitExpr(expr);
}
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
// TODO(lmzheng): handle more index types (multiple occurrence)
if (pattern_map.count(op) == 0) {
pattern_map[op] = TouchPattern();
}
}
- void VisitExpr_(const Mul* op) final {
- if (op->a.as<Variable>()) {
- if (const auto stride = op->b.as<IntImm>()) {
+ void VisitExpr_(const MulNode* op) final {
+ if (op->a.as<VarNode>()) {
+ if (const auto stride = op->b.as<IntImmNode>()) {
next_stride_ = stride->value;
}
}
ExprVisitor::VisitExpr_(op);
}
- std::unordered_map<const Variable*, TouchPattern> pattern_map;
+ std::unordered_map<const VarNode*, TouchPattern> pattern_map;
private:
int64_t next_stride_ = 1;
feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
Array<Expr> attr{std::string("_attr_"),
- FloatImm::make(DataType::Float(32), trans(fea.length)),
- IntImm::make(DataType::Int(32), fea.nest_level),
- FloatImm::make(DataType::Float(32), trans(fea.topdown_product)),
- FloatImm::make(DataType::Float(32), trans(fea.bottomup_product)),
+ FloatImmNode::make(DataType::Float(32), trans(fea.length)),
+ IntImmNode::make(DataType::Int(32), fea.nest_level),
+ FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
+ FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
};
// one hot annotation
for (int i = 0; i < kNum; i++) {
// arithmetic
feature_row.push_back(Array<Expr>{std::string("_arith_"),
- FloatImm::make(DataType::Float(32), trans(fea.add_ct)),
- FloatImm::make(DataType::Float(32), trans(fea.mul_ct)),
- FloatImm::make(DataType::Float(32), trans(fea.div_ct)),
+ FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
+ FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
+ FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
});
// touch map
std::sort(bufs.begin(), bufs.end());
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
- feature_row.push_back(Array<Expr>{k,
- FloatImm::make(DataType::Float(32), trans(v.stride)),
- FloatImm::make(DataType::Float(32), trans(v.mod)),
- FloatImm::make(DataType::Float(32), trans(v.count)),
- FloatImm::make(DataType::Float(32), trans(v.reuse)),
- FloatImm::make(DataType::Float(32), trans(v.thread_count)),
- FloatImm::make(DataType::Float(32), trans(v.thread_reuse)),
- });
+ feature_row.push_back(
+ Array<Expr>{k,
+ FloatImmNode::make(DataType::Float(32), trans(v.stride)),
+ FloatImmNode::make(DataType::Float(32), trans(v.mod)),
+ FloatImmNode::make(DataType::Float(32), trans(v.count)),
+ FloatImmNode::make(DataType::Float(32), trans(v.reuse)),
+ FloatImmNode::make(DataType::Float(32), trans(v.thread_count)),
+ FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)),
+ });
}
ret_feature->push_back(feature_row);
}
// arithmetic stats
- void VisitExpr_(const Add* op) final {
+ void VisitExpr_(const AddNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
- void VisitExpr_(const Sub* op) final {
+ void VisitExpr_(const SubNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].add_ct++;
FeatureVisitor::VisitExpr_(op);
}
- void VisitExpr_(const Mul* op) final {
+ void VisitExpr_(const MulNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].mul_ct++;
FeatureVisitor::VisitExpr_(op);
}
- void VisitExpr_(const Div* op) final {
+ void VisitExpr_(const DivNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
}
- void VisitExpr_(const Mod* op) final {
+ void VisitExpr_(const ModNode* op) final {
if (op->dtype.is_float())
itervar_map[itervar_stack_.back()].div_ct++;
FeatureVisitor::VisitExpr_(op);
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
- t->options_array.push_back(ir::StringImm::make(item));
+ t->options_array.push_back(ir::StringImmNode::make(item));
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
- t->libs_array.push_back(ir::StringImm::make(lib_item));
+ t->libs_array.push_back(ir::StringImmNode::make(lib_item));
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
- t->keys_array.push_back(ir::StringImm::make(t->device_name));
+ t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
- t->keys_array.push_back(ir::StringImm::make(key_item));
+ t->keys_array.push_back(ir::StringImmNode::make(key_item));
}
}
}
if (t->device_name.length() > 0) {
- t->keys_array.push_back(ir::StringImm::make(t->device_name));
+ t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
- t->keys_array.push_back(ir::StringImm::make("cpu"));
+ t->keys_array.push_back(ir::StringImmNode::make("cpu"));
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
- t->keys_array.push_back(ir::StringImm::make("cuda"));
- t->keys_array.push_back(ir::StringImm::make("gpu"));
+ t->keys_array.push_back(ir::StringImmNode::make("cuda"));
+ t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
} else {
t->device_type = kDLROCM;
}
- t->keys_array.push_back(ir::StringImm::make(target_name));
- t->keys_array.push_back(ir::StringImm::make("gpu"));
+ t->keys_array.push_back(ir::StringImmNode::make(target_name));
+ t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
} else {
t->device_type = kDLVulkan;
}
- t->keys_array.push_back(ir::StringImm::make(target_name));
- t->keys_array.push_back(ir::StringImm::make("gpu"));
+ t->keys_array.push_back(ir::StringImmNode::make(target_name));
+ t->keys_array.push_back(ir::StringImmNode::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
- t->keys_array.push_back(ir::StringImm::make("sdaccel"));
- t->keys_array.push_back(ir::StringImm::make("hls"));
+ t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
+ t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
- t->keys_array.push_back(ir::StringImm::make("aocl"));
- t->keys_array.push_back(ir::StringImm::make("hls"));
+ t->keys_array.push_back(ir::StringImmNode::make("aocl"));
+ t->keys_array.push_back(ir::StringImmNode::make("hls"));
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
- t->keys_array.push_back(ir::StringImm::make("opengl"));
+ t->keys_array.push_back(ir::StringImmNode::make("opengl"));
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
- result.push_back(expr.as<ir::StringImm>()->value);
+ result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
- result.push_back(expr.as<ir::StringImm>()->value);
+ result.push_back(expr.as<ir::StringImmNode>()->value);
}
return result;
}
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
- result.insert(expr.as<ir::StringImm>()->value);
+ result.insert(expr.as<ir::StringImmNode>()->value);
}
return result;
}
bool has_any = false;
if (!compact) {
for (const auto& it : shape) {
- if (it.as<Variable>()) {
+ if (it.as<VarNode>()) {
has_any = true;
break;
}
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
- tags_vector.push_back(tag.as<tvm::ir::StringImm>()->value);
+ tags_vector.push_back(tag.as<tvm::ir::StringImmNode>()->value);
}
generic_func
// Print a reference expression to a buffer.
std::string CodeGenC::GetBufferRef(
- DataType t, const Variable* buffer, Expr index) {
+ DataType t, const VarNode* buffer, Expr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
}
-bool CodeGenC::HandleTypeMatch(const Variable* buf_var, DataType t) const {
+bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) return false;
return it->second == t;
}
-void CodeGenC::RegisterHandleType(const Variable* buf_var, DataType t) {
+void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
}
std::string CodeGenC::GetVecLoad(
- DataType t, const Variable* buffer, Expr base) {
+ DataType t, const VarNode* buffer, Expr base) {
return GetBufferRef(t, buffer, base);
}
-void CodeGenC::PrintVecStore(const Variable* buffer,
+void CodeGenC::PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value) {
std::string ref = GetBufferRef(t, buffer, base);
LOG(FATAL) << "not implemented";
}
-void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
+void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*)
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
}
-inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->dtype == DataType::Int(32)) {
std::ostringstream temp;
temp << op->value;
}
}
-inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->dtype == DataType::UInt(32)) {
std::ostringstream temp;
temp << op->value << "U";
}
}
-inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
switch (op->dtype.bits()) {
case 64: case 32: {
std::ostringstream temp;
}
}
-void CodeGenC::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
-void CodeGenC::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
-void CodeGenC::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
-void CodeGenC::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
os << "\"" << op->value << "\"";
}
}
}
-inline void PrintBinaryIntrinsic(const Call* op,
+inline void PrintBinaryIntrinsic(const CallNode* op,
const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
}
}
-void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*)
std::stringstream value;
this->PrintExpr(op->value, value);
os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
}
-void CodeGenC::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
}
-void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
-void CodeGenC::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
-void CodeGenC::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
-void CodeGenC::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, this);
}
-void CodeGenC::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
-void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
-void CodeGenC::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
}
-void CodeGenC::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
-void CodeGenC::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
-void CodeGenC::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
-void CodeGenC::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
-void CodeGenC::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
-void CodeGenC::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
-void CodeGenC::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
-void CodeGenC::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
-void CodeGenC::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*)
os << '!';
PrintExpr(op->a, os);
}
-void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
- if (op->call_type == Call::Extern ||
- op->call_type == Call::PureExtern) {
+void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
+ if (op->call_type == CallNode::Extern ||
+ op->call_type == CallNode::PureExtern) {
os << op->name << "(";
for (size_t i = 0; i < op->args.size(); i++) {
this->PrintExpr(op->args[i], os);
}
}
os << ")";
- } else if (op->is_intrinsic(Call::bitwise_and)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_and)) {
PrintBinaryIntrinsic(op, " & ", os, this);
- } else if (op->is_intrinsic(Call::bitwise_xor)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
PrintBinaryIntrinsic(op, " ^ ", os, this);
- } else if (op->is_intrinsic(Call::bitwise_or)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_or)) {
PrintBinaryIntrinsic(op, " | ", os, this);
- } else if (op->is_intrinsic(Call::bitwise_not)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
this->PrintExpr(op->args[0], os);
os << ')';
- } else if (op->is_intrinsic(Call::shift_left)) {
+ } else if (op->is_intrinsic(CallNode::shift_left)) {
PrintBinaryIntrinsic(op, " << ", os, this);
- } else if (op->is_intrinsic(Call::shift_right)) {
+ } else if (op->is_intrinsic(CallNode::shift_right)) {
PrintBinaryIntrinsic(op, " >> ", os, this);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
os << "(";
PrintExpr(op->args[2], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const Load *l = op->args[0].as<Load>();
+ const LoadNode *l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
os << "((";
this->PrintType(l->dtype.element_of(), os);
CHECK_EQ(op->args.size(), 3U);
os << GetStructRef(
op->dtype, op->args[0], op->args[1],
- op->args[2].as<IntImm>()->value);
+ op->args[2].as<IntImmNode>()->value);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
os << "(";
this->PrintExpr(op->args[0], os);
os << " == NULL)";
- } else if (op->is_intrinsic(Call::reinterpret)) {
+ } else if (op->is_intrinsic(CallNode::reinterpret)) {
// generate (*( TYPE *)(&(ARG)))
os << "(*(";
this->PrintType(op->dtype, os);
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
- } else if (op->is_intrinsic(Call::isnan)) {
+ } else if (op->is_intrinsic(CallNode::isnan)) {
os << "(";
this->PrintExpr(op->args[0], os);
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else {
- if (op->call_type == Call::Intrinsic ||
- op->call_type == Call::PureIntrinsic) {
+ if (op->call_type == CallNode::Intrinsic ||
+ op->call_type == CallNode::PureIntrinsic) {
LOG(FATAL) << "Unresolved intrinsic " << op->name
<< " with return type " << op->dtype;
} else {
}
}
-void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
int lanes = op->dtype.lanes();
  // declare type.
if (op->dtype.lanes() == 1) {
}
}
-void CodeGenC::VisitStmt_(const Store* op) {
+void CodeGenC::VisitStmt_(const StoreNode* op) {
DataType t = op->value.dtype();
if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value);
}
}
-void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
std::string value = PrintExpr(op->value);
CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
os << PrintExpr(op->body);
}
-void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
CHECK_EQ(op->base.dtype(), DataType::Int(32));
os << "((int" << op->lanes << ")(";
os << "))";
}
-void CodeGenC::VisitExpr_(const Shuffle* op, std::ostream& os) {
+void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
LOG(FATAL) << "Shuffle: not supported ";
}
-void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
-void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
os << "(";
PrintExpr(op->condition, os);
os << " ? ";
os << ")";
}
-void CodeGenC::VisitStmt_(const LetStmt* op) {
+void CodeGenC::VisitStmt_(const LetStmtNode* op) {
std::string value = PrintExpr(op->value);
if (print_ssa_form_) {
CHECK(!var_idmap_.count(op->var.get()));
PrintStmt(op->body);
}
-void CodeGenC::VisitStmt_(const Allocate* op) {
+void CodeGenC::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
- const Variable* buffer = op->buffer_var.as<Variable>();
+ const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream);
stream << ' ';
this->PrintStmt(op->body);
}
-void CodeGenC::VisitStmt_(const AttrStmt* op) {
+void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
}
}
} else if (op->attr_key == ir::attr::storage_scope) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
- alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+ alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
} else if (op->attr_key == ir::attr::volatile_scope) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
volatile_buf_.insert(v);
}
this->PrintStmt(op->body);
}
-void CodeGenC::VisitStmt_(const AssertStmt* op) {
+void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
- if (const auto* str = op->message.as<StringImm>()) {
+ if (const auto* str = op->message.as<StringImmNode>()) {
// GLOG style check
stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
} else {
this->PrintStmt(op->body);
}
-void CodeGenC::VisitStmt_(const For* op) {
+void CodeGenC::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
stream << "}\n";
}
-void CodeGenC::VisitStmt_(const IfThenElse* op) {
+void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
}
}
-void CodeGenC::VisitStmt_(const Evaluate* op) {
+void CodeGenC::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
- const Call* call = op->value.as<Call>();
+ const CallNode* call = op->value.as<CallNode>();
if (call) {
if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
this->PrintStorageSync(call); return;
call->args[3].dtype(),
call->args[0],
call->args[1],
- call->args[2].as<IntImm>()->value);
+ call->args[2].as<IntImmNode>()->value);
this->PrintIndent();
this->stream << ref << " = " << value << ";\n";
return;
}
}
-void CodeGenC::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
PrintStmt(op->body);
}
*/
virtual void InitFuncState(LoweredFunc f);
// expression
- void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Shuffle* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
  // statement
- void VisitStmt_(const LetStmt* op) override;
- void VisitStmt_(const Store* op) override;
- void VisitStmt_(const For* op) override;
- void VisitStmt_(const IfThenElse* op) override;
- void VisitStmt_(const Allocate* op) override;
- void VisitStmt_(const AttrStmt* op) override;
- void VisitStmt_(const AssertStmt* op) override;
- void VisitStmt_(const Evaluate* op) override;
+ void VisitStmt_(const LetStmtNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const IfThenElseNode* op) override;
+ void VisitStmt_(const AllocateNode* op) override;
+ void VisitStmt_(const AttrStmtNode* op) override;
+ void VisitStmt_(const AssertStmtNode* op) override;
+ void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
- void VisitStmt_(const ProducerConsumer* op) override;
+ void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
 * Print Type representation of type t.
* \param t The type representation.
*/
virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
- virtual void PrintStorageSync(const Call* op); // NOLINT(*)
+ virtual void PrintStorageSync(const CallNode* op); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, DataType op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
- virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base);
+ virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base);
// print vector store
- virtual void PrintVecStore(const Variable* buffer,
+ virtual void PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value); // NOLINT(*)
// print load of single element
DataType t, const Expr& buffer, const Expr& index, int kind);
// print reference to a buffer as type t in index.
virtual std::string GetBufferRef(
- DataType t, const Variable* buffer, Expr index);
+ DataType t, const VarNode* buffer, Expr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
- bool HandleTypeMatch(const Variable* buf_var, DataType t) const;
+ bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
- void RegisterHandleType(const Variable* buf_var, DataType t);
+ void RegisterHandleType(const VarNode* buf_var, DataType t);
// override
void PrintSSAAssign(
const std::string& target, const std::string& src, DataType t) final;
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
- std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
+ std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
- std::unordered_map<const Variable*, DataType> handle_data_type_;
+ std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
/*! \brief set of volatile buf access */
- std::unordered_set<const Variable*> volatile_buf_;
+ std::unordered_set<const VarNode*> volatile_buf_;
};
} // namespace codegen
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
-void CodeGenCHost::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
PrintType(op->dtype, os);
this->stream << "}\n";
}
-void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
std::string stack_name = GetUniqueName("stack");
- const std::string& type = op->args[0].as<StringImm>()->value;
- const IntImm* num = op->args[1].as<IntImm>();
+ const std::string& type = op->args[0].as<StringImmNode>()->value;
+ const IntImmNode* num = op->args[1].as<IntImmNode>();
CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
size_t unit = sizeof(TVMValue);
this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
- const StringImm* s = op->args[0].as<StringImm>();
+ const StringImmNode* s = op->args[0].as<StringImmNode>();
CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
- int64_t begin = op->args[3].as<IntImm>()->value;
- int64_t end = op->args[4].as<IntImm>()->value;
+ int64_t begin = op->args[3].as<IntImmNode>()->value;
+ int64_t end = op->args[4].as<IntImmNode>()->value;
int64_t num_args = end - begin;
CHECK_GE(num_args, 0);
std::string func_name = s->value;
}
}
-void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*)
+void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
PrintIndent();
- stream << "TVMAPISetLastError(\"" << op->message.as<StringImm>()->value << "\");\n";
+ stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n";
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
this->PrintStmt(op->body);
}
-void CodeGenCHost::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
-void CodeGenCHost::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// overload visitor functions
- void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const Call *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
- void VisitExpr_(const Min *op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const Max *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*)
- void VisitStmt_(const AssertStmt *op) final; // NOLINT(*)
+ void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
private:
std::string module_name_;
return CodeGenC::Finish();
}
-void CodeGenCUDA::VisitStmt_(const ir::For* op) {
+void CodeGenCUDA::VisitStmt_(const ir::ForNode* op) {
CHECK(is_const_int(op->min, 0));
if (op->for_type == ir::ForType::Unrolled) {
PrintIndent();
}
}
-void CodeGenCUDA::PrintStorageSync(const Call* op) {
- const std::string& sync = op->args[0].as<StringImm>()->value;
+void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
+ const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared") {
}
}
-void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
+void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
- if (const StringImm *str = op->args[7].as<StringImm>()) {
+ if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
}
}
-void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
+void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::fragment_shape) {
- const Variable* buffer = op->node.as<Variable>();
- const StringImm* shape_str = op->value.as<StringImm>();
+ const VarNode* buffer = op->node.as<VarNode>();
+ const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) {
- const Variable* buffer = op->node.as<Variable>();
- const StringImm* layout_str = op->value.as<StringImm>();
+ const VarNode* buffer = op->node.as<VarNode>();
+ const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
}
CodeGenC::VisitStmt_(op);
}
-void CodeGenCUDA::VisitStmt_(const Allocate* op) {
+void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
- const Variable* buffer = op->buffer_var.as<Variable>();
+ const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
this->PrintStmt(op->body);
}
-void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
+void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) {
if (is_const(op->value)) return;
- const Call* call = op->value.as<Call>();
+ const CallNode* call = op->value.as<CallNode>();
if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
PrintIndent();
stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
}
}
-void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
+void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
os << "((make_int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
os << "))";
}
-void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) {
// make_int8x4
const int64_t *p = as_const_int(op->value);
os << ')';
}
-void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
+void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
os << ')';
}
-inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
+inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
switch (op->dtype.bits()) {
case 64: case 32: {
std::ostringstream temp;
}
-void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
- const Variable* variable, std::ostream &os) {
+ const VarNode* variable, std::ostream &os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
}
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
- const Variable* variable, int32_t size) {
+ const VarNode* variable, int32_t size) {
std::string shape_str = fragment_shapes[variable];
size_t m, n, k;
size_t last_pos = 0, pos = 0;
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
- void VisitStmt_(const ir::For* op) final;
- void PrintStorageSync(const Call* op) final;
+ void VisitStmt_(const ir::ForNode* op) final;
+ void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, DataType t,
const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor
- void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const FloatImm *op, std::ostream& os) final;
- void VisitExpr_(const Call *op, std::ostream& os) final;
- void VisitStmt_(const Evaluate *op) final;
- void VisitStmt_(const Allocate *op) final;
- void VisitStmt_(const AttrStmt *op) final;
+ void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
+ void VisitExpr_(const CallNode *op, std::ostream& os) final;
+ void VisitStmt_(const EvaluateNode *op) final;
+ void VisitStmt_(const AllocateNode *op) final;
+ void VisitStmt_(const AttrStmtNode *op) final;
private:
// Whether global barrier is needed.
// whether need mma.h
bool need_mma_h_{false};
- std::unordered_map<const Variable*, std::string> fragment_shapes;
- std::unordered_map<const Variable*, std::string> fragment_layouts;
- friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
+ std::unordered_map<const VarNode*, std::string> fragment_shapes;
+ std::unordered_map<const VarNode*, std::string> fragment_layouts;
+ friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
void PrintWmmaScope(
- const std::string& scope, DataType t, const Variable* variable, std::ostream& os);
+ const std::string& scope, DataType t, const VarNode* variable, std::ostream& os);
int32_t GetWmmaFragmentSize(
- const std::string &scope, const Variable* variable, int32_t size);
+ const std::string &scope, const VarNode* variable, int32_t size);
};
} // namespace codegen
LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
}
-void CodeGenMetal::PrintStorageSync(const Call* op) {
- const std::string& sync = op->args[0].as<StringImm>()->value;
+void CodeGenMetal::PrintStorageSync(const CallNode* op) {
+ const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
}
}
-void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
PrintType(op->dtype, os);
os << "(";
os << ')';
}
-void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
- if (op->is_intrinsic(Call::reinterpret)) {
+void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
+ if (op->is_intrinsic(CallNode::reinterpret)) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
this->PrintType(op->dtype, os);
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
void PrintArgUnionDecl();
void InitFuncState(LoweredFunc f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
- void PrintStorageSync(const Call* op) final; // NOLINT(*)
+ void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// print load of single element
void PrintVecElemStore(
const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
- void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
- void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
private:
int thread_index_bits_{32};
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}
-void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
+void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
Expr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
PrintExpr(base, os);
}
std::string CodeGenOpenCL::GetVecLoad(
- DataType t, const Variable* buffer, Expr base) {
+ DataType t, const VarNode* buffer, Expr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
return os.str();
}
-void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
+void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value) {
this->PrintIndent();
stream << ");\n";
}
-void CodeGenOpenCL::PrintStorageSync(const Call* op) {
- const std::string& sync = op->args[0].as<StringImm>()->value;
+void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
+ const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
return os.str();
}
-void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << "((";
PrintType(op->dtype, os);
os << "))";
}
-void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
CodeGenC::VisitExpr_(op, os);
}
-void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
/* Return type of ternary expression is not always same as its sub-expressions,
* add a cast */
os << "(";
CodeGenC::VisitExpr_(op, os);
}
-void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
os << "-";
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
- void PrintStorageSync(const Call* op) final; // NOLINT(*)
+ void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
- std::string GetVecLoad(DataType t, const Variable* buffer,
+ std::string GetVecLoad(DataType t, const VarNode* buffer,
Expr base) final;
- void PrintVecStore(const Variable* buffer,
+ void PrintVecStore(const VarNode* buffer,
DataType t, Expr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
- void PrintVecAddr(const Variable* buffer, DataType t,
+ void PrintVecAddr(const VarNode* buffer, DataType t,
Expr base, std::ostream& os); // NOLINT(*)
std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
// overload visitor
- void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
private:
// whether enable fp16 and fp64 extension
this->stream << "}\n";
}
-void CodeGenOpenGL::VisitStmt_(const Store* op) {
+void CodeGenOpenGL::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Store statement not supported in OpenGL."
<< " Texture store should be a Call statement.";
}
// texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
-std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
+std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
std::ostringstream os;
os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
PrintExpr(index, os);
// Print a reference expression to a buffer.
// Format: texelFetch(buffer, index, 0).r
std::string CodeGenOpenGL::GetBufferRef(
- DataType t, const Variable* buffer, Expr index) {
+ DataType t, const VarNode* buffer, Expr index) {
CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
// Codegen for immediate values
-void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
CodeGenC::VisitExpr_(op, os);
}
-void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
CodeGenC::VisitExpr_(op, os);
}
-void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
CodeGenC::VisitExpr_(op, os);
}
-void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const StringImmNode*, std::ostream& os) {
LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
}
-void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
- auto call = op->value.as<Call>();
- if (call == nullptr || call->name != Call::glsl_texture_store) {
+void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
+ auto call = op->value.as<CallNode>();
+ if (call == nullptr || call->name != CallNode::glsl_texture_store) {
// Fallback to normal logic.
CodeGenC::VisitStmt_(op);
}
CHECK_EQ(call->args.size(), 2);
- auto buffer = call->args[0].as<Variable>();
+ auto buffer = call->args[0].as<VarNode>();
auto value = call->args[1];
// Doesn't support store to vector.
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final;
- void VisitStmt_(const Store* op) final;
- std::string TexelFetch(const Variable* buffer, Expr index);
- std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final;
+ void VisitStmt_(const StoreNode* op) final;
+ std::string TexelFetch(const VarNode* buffer, Expr index);
+ std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// Codegen for immediate values
- void VisitExpr_(const IntImm* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const UIntImm* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const FloatImm* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const StringImm* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
+ void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*)
// Match glsl_texture_store Call.
- void VisitStmt_(const Evaluate* op) final; // NOLINT(*)
+ void VisitStmt_(const EvaluateNode* op) final; // NOLINT(*)
private:
- const Variable* output_{nullptr};
- std::unordered_set<const Variable*> inputs_;
- const Variable* output_iter_var_{nullptr};
+ const VarNode* output_{nullptr};
+ std::unordered_set<const VarNode*> inputs_;
+ const VarNode* output_iter_var_{nullptr};
std::unordered_map<std::string, runtime::OpenGLShader> shaders_;
std::string thread_extent_var_;
};
return e.vid;
}
-std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
+std::string CodeGenSourceBase::AllocVarID(const VarNode* v) {
CHECK(!var_idmap_.count(v))
<< "Need input to be in SSA form dup " << v->name_hint;
std::string key = v->name_hint;
return vid;
}
-std::string CodeGenSourceBase::GetVarID(const Variable* v) const {
+std::string CodeGenSourceBase::GetVarID(const VarNode* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \param v The variable.
* \return the variable name.
*/
- std::string AllocVarID(const Variable* v);
+ std::string AllocVarID(const VarNode* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the variable name.
*/
- std::string GetVarID(const Variable* v) const;
+ std::string GetVarID(const VarNode* v) const;
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
/*! \brief the stream to be printed */
std::ostringstream stream;
/*! \brief name of each variable */
- std::unordered_map<const Variable*, std::string> var_idmap_;
+ std::unordered_map<const VarNode*, std::string> var_idmap_;
private:
/*! \brief assignment map of ssa */
os << ')';
}
-void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*)
+void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::min";
if (op->dtype.is_float()) {
switch (op->dtype.bits()) {
PrintBinaryExpr(op, opstr, os, this);
}
-void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*)
+void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*)
const char *opstr = "std::max";
if (op->dtype.is_float()) {
switch (op->dtype.bits()) {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
void PrintType(DataType t, std::ostream& os);
void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f);
- void VisitExpr_(const Min *op, std::ostream& os);
- void VisitExpr_(const Max *op, std::ostream& os);
+ void VisitExpr_(const MinNode *op, std::ostream& os);
+ void VisitExpr_(const MaxNode *op, std::ostream& os);
};
} // namespace codegen
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
- const Call* call = e.as<Call>();
+ const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
- const Call* call = e.as<Call>();
+ const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
auto one = make_const(call->args[0].dtype(), 1);
template<typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
- const Call* call = e.as<Call>();
+ const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
std::string name = T()(call->dtype, call->name);
if (name.length() != 0) {
- *rv = Call::make(
- call->dtype, name, call->args, Call::PureExtern);
+ *rv = CallNode::make(
+ call->dtype, name, call->args, CallNode::PureExtern);
} else {
*rv = e;
}
function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
}
- void VisitStmt_(const Allocate* op) final {
+ void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
return builder_->CreateCall(f, {});
}
- llvm::Value* CreateStorageSync(const Call* op) final {
- const std::string& sync = op->args[0].as<StringImm>()->value;
+ llvm::Value* CreateStorageSync(const CallNode* op) final {
+ const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
return nullptr;
} else if (sync == "shared") {
Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
for (auto &bitcode : bitcode_files) {
- std::string path = bitcode.as<StringImm>()->value;
+ std::string path = bitcode.as<StringImmNode>()->value;
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
if (mlib.get() == nullptr) {
native_vector_bits_ = 16 * 8;
CodeGenCPU::InitTarget(tm);
}
- llvm::Value* CreateIntrinsic(const Call* op) override;
+ llvm::Value* CreateIntrinsic(const CallNode* op) override;
private:
- Expr ARMPopcount(const Call* op);
+ Expr ARMPopcount(const CallNode* op);
};
-llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
+llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
- op->args[0].as<UIntImm>()->value);
+ op->args[0].as<UIntImmNode>()->value);
if (id == ::llvm::Intrinsic::ctpop) {
Expr e = ARMPopcount(op);
- return CodeGenCPU::CreateIntrinsic(e.as<Call>());
+ return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
}
}
return CodeGenCPU::CreateIntrinsic(op);
}
-Expr CodeGenARM::ARMPopcount(const Call *call) {
+Expr CodeGenARM::ARMPopcount(const CallNode *call) {
using namespace ir;
const Expr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<Expr> vcnt_args;
- vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
- vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+ vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
+ vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt_args.push_back(e);
- return ir::Call::make(call->dtype, "llvm_intrin", vcnt_args, Call::PureIntrinsic);
+ return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
}
// Popcount lowering rule:
// Interpret input as vector of 8bit values
Expr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit
- const Call* c0 = input8.as<Call>();
+ const CallNode* c0 = input8.as<CallNode>();
CHECK(c0 != nullptr);
Array<Expr> vcnt8_args;
- vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
- vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+ vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
+ vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
- Expr vcnt8 = ir::Call::make(uint8_type, "llvm_intrin", vcnt8_args, Call::PureIntrinsic);
+ Expr vcnt8 = ir::CallNode::make(
+ uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit
Array<Expr> vcnt16_args;
- vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
- vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+ vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
+ vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
- Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
+ Expr vcnt16 = ir::CallNode::make(
+ uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) {
return vcnt16;
}
// Accumulation 16->32bit
Array<Expr> vcnt32_args;
- vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
- vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+ vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
+ vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
- Expr vcnt32 = ir::Call::make(uint32_type, "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
+ Expr vcnt32 = ir::CallNode::make(
+ uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) {
return vcnt32;
}
// Accumulation 32->64bit
Array<Expr> vcnt64_args;
- vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
- vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+ vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
+ vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
- return ir::Call::make(call->dtype, "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
+ return ir::CallNode::make(
+ call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
}
}
-llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) {
+llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
return end_block;
}
-void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
+void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
// There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separately(though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
llvm::Function* fcompute =
llvm::Function::Create(ftype,
llvm::Function::PrivateLinkage,
- op->value.as<StringImm>()->value,
+ op->value.as<StringImmNode>()->value,
module_.get());
BasicBlock* compute_call_end = CheckCallSuccess(
builder_->CreateCall(fcompute, arg_values));
// setup compute fuinction.
- std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+ std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
size_t idx = 0;
for (auto it = fcompute->arg_begin();
it != fcompute->arg_end(); ++it, ++idx) {
void CodeGenCPU::UnpackClosureData(llvm::Value* cdata,
const Array<Var>& vfields,
- std::unordered_map<const Variable*, llvm::Value*>* vmap) {
+ std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
llvm::Value* penv = &(*it++);
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
- std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+ std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
// setup parallel env
ParallelEnv par_env;
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
// setup new variable map, swap it with current var context.
- std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+ std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
CHECK(parallel_env_.penv == nullptr);
std::swap(function_, f);
llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
- std::string func_name = args[0].as<StringImm>()->value;
+ std::string func_name = args[0].as<StringImmNode>()->value;
llvm::Value *handle = GetPackedFuncHandle(func_name);
// call the function
int64_t nargs = end - begin;
return end_block;
}
-llvm::Value *CodeGenCPU::CreateCallPacked(const Call *op) {
+llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) {
CHECK_EQ(op->args.size(), 5U);
llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr;
MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype,
- op->args[3].as<IntImm>()->value,
- op->args[4].as<IntImm>()->value);
+ op->args[3].as<IntImmNode>()->value,
+ op->args[4].as<IntImmNode>()->value);
return rvalue;
}
-llvm::Value *CodeGenCPU::CreateCallTracePacked(const Call *op) {
+llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
using llvm::BasicBlock;
CHECK_EQ(op->args.size(), 6U);
llvm::Value *rvalue = nullptr;
llvm::Value *ret_tcode = nullptr;
BasicBlock *end_block = MakeCallPacked(
- op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImm>()->value,
- op->args[4].as<IntImm>()->value);
+ op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
+ op->args[4].as<IntImmNode>()->value);
// Get traced value.
llvm::Value *traced_value = MakeValue(op->args[5]);
// The update_block handles case when we need to update the return value.
}
}
-llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
+llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
return CreateCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) {
return ConstInt32(-1);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
- int kind = op->args[2].as<IntImm>()->value;
+ int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* ref = this->CreateStructRefPtr(
op->dtype, MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
}
} else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
- int kind = op->args[2].as<IntImm>()->value;
+ int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(
op->args[3].dtype(), MakeValue(op->args[0]),
return ConstInt32(0);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
- const std::string& type = op->args[0].as<StringImm>()->value;
+ const std::string& type = op->args[0].as<StringImmNode>()->value;
return WithFunctionEntry([&]() -> llvm::AllocaInst* {
const int64_t* pval = as_const_int(op->args[1]);
CHECK(pval) << "require stack alloca to contain constant value";
}
}
-void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
+void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
- if (op->message.as<StringImm>()) {
- os << ", " << op->message.as<StringImm>()->value;
+ if (op->message.as<StringImmNode>()) {
+ os << ", " << op->message.as<StringImmNode>()->value;
}
llvm::Value* msg = GetConstString(os.str());
BasicBlock* fail_block = BasicBlock::Create(
CodeGenLLVM::VisitStmt_(op);
}
-void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
+void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::coproc_uop_scope) {
- this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
+ this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
} else if (op->attr_key == ir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (attr::IsPragmaKey(op->attr_key)) {
RuntimeTVMParallelBarrier(),
{MakeValue(parallel_env_.task_id), parallel_env_.penv});
} else if (op->attr_key == ir::attr::pragma_import_llvm) {
- const StringImm* value = op->value.as<StringImm>();
+ const StringImmNode* value = op->value.as<StringImmNode>();
CHECK(value != nullptr);
this->HandleImport(value->value);
this->VisitStmt(op->body);
}
}
-void CodeGenCPU::VisitStmt_(const For* op) {
+void CodeGenCPU::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
if (op->for_type == ForType::Serial ||
op->for_type == ForType::Unrolled) {
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
- For::make(
+ ForNode::make(
op->loop_var, op->min, op->extent,
op->for_type, op->device_api, op->body), 0);
} else {
op->body);
} else {
Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
- Expr begin = Min::make(task_id * step, op->extent);
- Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
+ Expr begin = MinNode::make(task_id * step, op->extent);
+ Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
void AddFunction(const LoweredFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override;
- void VisitStmt_(const AssertStmt* op) override;
- void VisitStmt_(const AttrStmt* op) override;
- void VisitStmt_(const For* op) override;
- llvm::Value* CreateIntrinsic(const Call* op) override;
- llvm::Value* CreateCallExtern(const Call* op) override;
+ void VisitStmt_(const AssertStmtNode* op) override;
+ void VisitStmt_(const AttrStmtNode* op) override;
+ void VisitStmt_(const ForNode* op) override;
+ llvm::Value* CreateIntrinsic(const CallNode* op) override;
+ llvm::Value* CreateCallExtern(const CallNode* op) override;
protected:
void AddStartupFunction() final;
llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value*cdata,
const Array<Var>& fields,
- std::unordered_map<const Variable*, llvm::Value*>* vmap);
+ std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end);
// create call into tvm packed function.
- llvm::Value* CreateCallPacked(const Call* op);
+ llvm::Value* CreateCallPacked(const CallNode* op);
// Create trace call into tvm packed function.
- llvm::Value* CreateCallTracePacked(const Call *op);
+ llvm::Value* CreateCallTracePacked(const CallNode *op);
// Create static initialization
void CreateStaticInit(const std::string& init_fname, const Stmt& body);
// Create parallel launch
void CreateParallelLaunch(const Stmt& body, int num_task);
// Create a new compute scope.
- void CreateComputeScope(const AttrStmt* op);
+ void CreateComputeScope(const AttrStmtNode* op);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
return nullptr;
}
-llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
+llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) {
LOG(FATAL) << "not implemented";
return nullptr;
}
// This trick comes from Halide's CodeGen_LLVM
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
- const Variable* buffer,
+ const VarNode* buffer,
Expr index,
DataType type) {
if (alias_var_set_.count(buffer) != 0) {
// create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks.
if (index.defined()) {
- const Ramp* ramp = index.as<Ramp>();
+ const RampNode* ramp = index.as<RampNode>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
}
void CodeGenLLVM::GetAlignment(DataType t,
- const Variable* buf_var,
+ const VarNode* buf_var,
const Expr& index,
int* p_alignment,
int* p_native_bits) {
return builder_->CreateInBoundsGEP(buffer, index);
}
-llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
+llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
return it->second;
}
-llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
+llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) {
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = 0; i < op->args.size(); ++i) {
return call;
}
-llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
+llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
- op->args[0].as<UIntImm>()->value);
+ op->args[0].as<UIntImmNode>()->value);
const uint64_t *num_signature = as_const_uint(op->args[1]);
CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
<< "but " << op->args[1] << " got!\n";
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, sig_type);
return builder_->CreateCall(f, arg_value);
- } else if (op->is_intrinsic(Call::bitwise_and)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_and)) {
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(Call::bitwise_or)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_or)) {
return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(Call::bitwise_not)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_not)) {
return builder_->CreateNot(MakeValue(op->args[0]));
- } else if (op->is_intrinsic(Call::bitwise_xor)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(Call::shift_left)) {
+ } else if (op->is_intrinsic(CallNode::shift_left)) {
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
- } else if (op->is_intrinsic(Call::shift_right)) {
+ } else if (op->is_intrinsic(CallNode::shift_right)) {
if (op->args[0].dtype().is_int()) {
return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else {
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const Load *l = op->args[0].as<Load>();
+ const LoadNode *l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
- const Ramp *r = l->index.as<Ramp>();
+ const RampNode *r = l->index.as<RampNode>();
llvm::Value* ptr;
unsigned addrspace;
if (!r) {
ptr->getType())->getAddressSpace();
}
return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
- } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
+ } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) {
return llvm::Constant::getNullValue(t_void_p_);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
return builder_->CreateIsNull(MakeValue(op->args[0]));
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
- } else if (op->is_intrinsic(Call::reinterpret)) {
+ } else if (op->is_intrinsic(CallNode::reinterpret)) {
llvm::Type * target = LLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
- } else if (op->is_intrinsic(Call::isnan)) {
+ } else if (op->is_intrinsic(CallNode::isnan)) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
void CodeGenLLVM::Scalarize(const Expr& e,
std::function<void(int i, llvm::Value* v)> f) {
- if (const Ramp* ramp = e.as<Ramp>()) {
+ if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
Expr offset = ramp->base + (ramp->stride * i);
f(i, MakeValue(offset));
// Visitors
-llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) {
return GetVarValue(op);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) {
return llvm::ConstantInt::get(LLVMType(op->dtype), op->value);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
return llvm::ConstantFP::get(LLVMType(op->dtype), op->value);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
return GetConstString(op->value);
}
#define DEFINE_CODEGEN_BINARY_OP(Op) \
llvm::Value* CodeGenLLVM::Create ## Op( \
- DataType t, llvm::Value* a, llvm::Value *b) { \
+ DataType t, llvm::Value* a, llvm::Value *b) { \
if (t.is_int()) { \
if (t.bits() >= 32) { \
return builder_->CreateNSW ## Op (a, b); \
return builder_->CreateF ## Op (a, b); \
} \
} \
- llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \
- return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
+ llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \
+ return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_BINARY_OP(Add);
#define DEFINE_CODEGEN_CMP_OP(Op) \
llvm::Value* CodeGenLLVM::Create ## Op( \
- DataType t, llvm::Value* a, llvm::Value* b) { \
+ DataType t, llvm::Value* a, llvm::Value* b) { \
if (t.is_int()) { \
return builder_->CreateICmpS ## Op (a, b); \
} else if (t.is_uint()) { \
return builder_->CreateFCmpO ## Op (a, b); \
} \
} \
- llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \
+ llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \
return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
}
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
-llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
}
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
}
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
}
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
}
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
return builder_->CreateNot(MakeValue(op->a));
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
return builder_->CreateSelect(
MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
DataType t = op->dtype;
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
llvm::Value* buffer = MakeValue(op->buffer_var);
// vector load
unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
buffer->getType())->getAddressSpace();
- if (const Ramp* ramp = op->index.as<Ramp>()) {
+ if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
return ret;
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
- if (op->call_type == Call::Intrinsic ||
- op->call_type == Call::PureIntrinsic) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+ if (op->call_type == CallNode::Intrinsic ||
+ op->call_type == CallNode::PureIntrinsic) {
return CreateIntrinsic(op);
- } else if (op->call_type == Call::Extern ||
- op->call_type == Call::PureExtern) {
+ } else if (op->call_type == CallNode::Extern ||
+ op->call_type == CallNode::PureExtern) {
return CreateCallExtern(op);
} else {
LOG(FATAL) << "Unknown call type " <<
}
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->dtype));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
return vec;
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
std::vector<llvm::Value *> vecs(op->vectors.size());
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
return res;
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
return CreateBroadcast(MakeValue(op->value), op->lanes);
}
-void CodeGenLLVM::VisitStmt_(const Store* op) {
+void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
CHECK(is_one(op->predicate));
DataType t = op->value.dtype();
bool is_volatile = volatile_buf_.count(op->buffer_var.get());
// vector store
unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
buffer->getType())->getAddressSpace();
- if (const Ramp* ramp = op->index.as<Ramp>()) {
+ if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
int alignment, native_bits;
GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
this->Scalarize(op->index, f);
}
-void CodeGenLLVM::VisitStmt_(const For* op) {
+void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
if (op->for_type == ForType::Unrolled) {
}
-void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
+void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
BasicBlock* then_block = BasicBlock::Create(
}
-void CodeGenLLVM::VisitStmt_(const Allocate* op) {
+void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
this->VisitStmt(op->body);
}
-void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
+void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
}
}
} else if (op->attr_key == ir::attr::storage_scope) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
alloc_storage_info_[v].scope =
- runtime::StorageScope::make(op->value.as<StringImm>()->value);
+ runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == ir::attr::storage_alignment) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
alloc_storage_info_[v].alignment =
- static_cast<int>(op->value.as<IntImm>()->value);
+ static_cast<int>(op->value.as<IntImmNode>()->value);
} else if (op->attr_key == ir::attr::volatile_scope) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
volatile_buf_.insert(v);
}
this->VisitStmt(op->body);
}
-void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
+void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body);
}
-void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
+void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
CHECK(!var_map_.count(op->var.get()));
if (op->var.dtype().is_handle()) {
if (!is_restricted_) {
}
}
-void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
+void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
MakeValue(op->value);
}
-void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) {
this->VisitStmt(op->body);
}
} // namespace codegen
return llvm::ConstantInt::getSigned(t_int32_, value);
}
// override codegen
- llvm::Value* VisitExpr_(const Variable* op) override;
- llvm::Value* VisitExpr_(const Cast* op) override;
- llvm::Value* VisitExpr_(const IntImm* op) override;
- llvm::Value* VisitExpr_(const UIntImm* op) override;
- llvm::Value* VisitExpr_(const FloatImm* op) override;
- llvm::Value* VisitExpr_(const StringImm* op) override;
- llvm::Value* VisitExpr_(const Add* op) override;
- llvm::Value* VisitExpr_(const Sub* op) override;
- llvm::Value* VisitExpr_(const Mul* op) override;
- llvm::Value* VisitExpr_(const Div* op) override;
- llvm::Value* VisitExpr_(const Mod* op) override;
- llvm::Value* VisitExpr_(const Min* op) override;
- llvm::Value* VisitExpr_(const Max* op) override;
- llvm::Value* VisitExpr_(const LT* op) override;
- llvm::Value* VisitExpr_(const LE* op) override;
- llvm::Value* VisitExpr_(const GT* op) override;
- llvm::Value* VisitExpr_(const GE* op) override;
- llvm::Value* VisitExpr_(const EQ* op) override;
- llvm::Value* VisitExpr_(const NE* op) override;
- llvm::Value* VisitExpr_(const And* op) override;
- llvm::Value* VisitExpr_(const Or* op) override;
- llvm::Value* VisitExpr_(const Not* op) override;
- llvm::Value* VisitExpr_(const Select* op) override;
- llvm::Value* VisitExpr_(const Let* op) override;
- llvm::Value* VisitExpr_(const Load* op) override;
- llvm::Value* VisitExpr_(const Call* op) override;
- llvm::Value* VisitExpr_(const Ramp* op) override;
- llvm::Value* VisitExpr_(const Shuffle* op) override;
- llvm::Value* VisitExpr_(const Broadcast* op) override;
+ llvm::Value* VisitExpr_(const VarNode* op) override;
+ llvm::Value* VisitExpr_(const CastNode* op) override;
+ llvm::Value* VisitExpr_(const IntImmNode* op) override;
+ llvm::Value* VisitExpr_(const UIntImmNode* op) override;
+ llvm::Value* VisitExpr_(const FloatImmNode* op) override;
+ llvm::Value* VisitExpr_(const StringImmNode* op) override;
+ llvm::Value* VisitExpr_(const AddNode* op) override;
+ llvm::Value* VisitExpr_(const SubNode* op) override;
+ llvm::Value* VisitExpr_(const MulNode* op) override;
+ llvm::Value* VisitExpr_(const DivNode* op) override;
+ llvm::Value* VisitExpr_(const ModNode* op) override;
+ llvm::Value* VisitExpr_(const MinNode* op) override;
+ llvm::Value* VisitExpr_(const MaxNode* op) override;
+ llvm::Value* VisitExpr_(const LTNode* op) override;
+ llvm::Value* VisitExpr_(const LENode* op) override;
+ llvm::Value* VisitExpr_(const GTNode* op) override;
+ llvm::Value* VisitExpr_(const GENode* op) override;
+ llvm::Value* VisitExpr_(const EQNode* op) override;
+ llvm::Value* VisitExpr_(const NENode* op) override;
+ llvm::Value* VisitExpr_(const AndNode* op) override;
+ llvm::Value* VisitExpr_(const OrNode* op) override;
+ llvm::Value* VisitExpr_(const NotNode* op) override;
+ llvm::Value* VisitExpr_(const SelectNode* op) override;
+ llvm::Value* VisitExpr_(const LetNode* op) override;
+ llvm::Value* VisitExpr_(const LoadNode* op) override;
+ llvm::Value* VisitExpr_(const CallNode* op) override;
+ llvm::Value* VisitExpr_(const RampNode* op) override;
+ llvm::Value* VisitExpr_(const ShuffleNode* op) override;
+ llvm::Value* VisitExpr_(const BroadcastNode* op) override;
// stmt
- void VisitStmt_(const Store* op) override;
- void VisitStmt_(const For* op) override;
- void VisitStmt_(const IfThenElse* op) override;
- void VisitStmt_(const Allocate* op) override;
- void VisitStmt_(const AttrStmt* op) override;
- void VisitStmt_(const AssertStmt* op) override;
- void VisitStmt_(const LetStmt* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const IfThenElseNode* op) override;
+ void VisitStmt_(const AllocateNode* op) override;
+ void VisitStmt_(const AttrStmtNode* op) override;
+ void VisitStmt_(const AssertStmtNode* op) override;
+ void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
- void VisitStmt_(const Evaluate* op) override;
- void VisitStmt_(const ProducerConsumer* op) override;
+ void VisitStmt_(const EvaluateNode* op) override;
+ void VisitStmt_(const ProducerConsumerNode* op) override;
protected:
/*! \brief The storage information */
return res;
}
// create intrinstic given call
- virtual llvm::Value* CreateIntrinsic(const Call* op);
+ virtual llvm::Value* CreateIntrinsic(const CallNode* op);
// create extern function call
- virtual llvm::Value* CreateCallExtern(const Call* op);
+ virtual llvm::Value* CreateCallExtern(const CallNode* op);
// Get the corresponding thread index
virtual llvm::Value* GetThreadIndex(const IterVar& iv);
// Get the corresponding thread index
- virtual llvm::Value* CreateStorageSync(const Call* op);
+ virtual llvm::Value* CreateStorageSync(const CallNode* op);
// apply optimization on the module.
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e.
void InitFuncState();
// Get alignment given index.
void GetAlignment(
- DataType t, const Variable* buf_var, const Expr& index,
+ DataType t, const VarNode* buf_var, const Expr& index,
int* p_alignment, int* p_native_bits);
// Get constant string
llvm::Value* GetConstString(const std::string& str);
// do a scalarize call with f
llvm::Value* CreateScalarizedCall(
- const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
+ const CallNode* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// handle module import
void HandleImport(const std::string& code);
// cast operatpr
llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value);
// comparison op
- llvm::Value* GetVarValue(const Variable* v) const;
+ llvm::Value* GetVarValue(const VarNode* v) const;
llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* stride,
const VarExpr& loop_var, const Stmt& body);
// add alias information.
- void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, DataType type);
+ void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type);
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
/*! \brief native vector bits of current targetx*/
int native_vector_bits_{0};
/*! \brief the storage scope of allocation */
- std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
+ std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
// The definition of local variable.
- std::unordered_map<const Variable*, llvm::Value*> var_map_;
+ std::unordered_map<const VarNode*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// Whether current function is restricted
// The analyzer information
std::unique_ptr<arith::Analyzer> analyzer_;
// set of var that are not restricted(can alias)
- std::unordered_set<const Variable*> alias_var_set_;
+ std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
- std::unordered_set<const Variable*> volatile_buf_;
+ std::unordered_set<const VarNode*> volatile_buf_;
/*! \brief Helper struct for debug infos. */
struct DebugInfo {
std::unique_ptr<llvm::DIBuilder> di_builder_;
llvm::ValueAsMetadata::get(ConstInt32(1)) }));
}
- void VisitStmt_(const Allocate* op) final {
+ void VisitStmt_(const AllocateNode* op) final {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
return builder_->CreateCall(f, {});
}
- llvm::Value* CreateStorageSync(const Call* op) final {
- const std::string& sync = op->args[0].as<StringImm>()->value;
+ llvm::Value* CreateStorageSync(const CallNode* op) final {
+ const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// TODO(tqchen) warp sync in CUDA9
return nullptr;
class CodeGenX86_64 final : public CodeGenCPU {
public:
- llvm::Value* VisitExpr_(const Cast* op) override;
+ llvm::Value* VisitExpr_(const CastNode* op) override;
private:
llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
const std::vector<llvm::Value*>& args);
};
-llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
+llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
// LLVM does not automatically generate the correct instruction sequences for
// half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
// vcvtph2ps), so we explicitly generate them ourselves.
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
LLVMType(DataType::Float(32, from.lanes())),
{
- MakeValue(ir::Call::make(
- DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
- ir::Call::PureIntrinsic)),
+ MakeValue(ir::CallNode::make(
+ DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
+ ir::CallNode::PureIntrinsic)),
MakeValue(
- ir::Broadcast::make(ir::FloatImm::make(DataType::Float(32), 0), from.lanes())),
- /*mask=*/MakeValue(ir::IntImm::make(DataType::Int(16), -1)),
- /*rounding-mode=*/MakeValue(ir::IntImm::make(DataType::Int(32), 4)),
+ ir::BroadcastNode::make(
+ ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
+ /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)),
+ /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)),
});
}
if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
- {MakeValue(ir::Call::make(
- DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
- ir::Call::PureIntrinsic))});
+ {MakeValue(ir::CallNode::make(
+ DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
+ ir::CallNode::PureIntrinsic))});
}
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
- const ir::Call* call = e.as<ir::Call>();
+ const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
const Expr& x = call->args[0];
Expr one = make_const(x.dtype(), 1);
Expr two = make_const(x.dtype(), 2);
Expr neg_two = make_const(x.dtype(), -2);
- Expr exp_neg2x = ir::Call::make(
- x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic);
- Expr exp_pos2x = ir::Call::make(
- x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic);
+ Expr exp_neg2x = ir::CallNode::make(
+ x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic);
+ Expr exp_pos2x = ir::CallNode::make(
+ x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic);
Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
- *rv = ir::Select::make(
+ *rv = ir::SelectNode::make(
x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
- const ir::Call* call = e.as<ir::Call>();
+ const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
- cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
+ cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
+ cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
- *rv = ir::Call::make(
- call->dtype, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
+ *rv = ir::CallNode::make(
+ call->dtype, "llvm_intrin", cargs, ir::CallNode::PureIntrinsic);
}
template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
- const ir::Call* call = e.as<ir::Call>();
+ const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
- cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
+ cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
+ cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
- *rv = ir::Call::make(
- call->dtype, "llvm_intrin", cargs, ir::Call::Intrinsic);
+ *rv = ir::CallNode::make(
+ call->dtype, "llvm_intrin", cargs, ir::CallNode::Intrinsic);
}
} // namespace codegen
inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
using namespace ir;
- const Call* call = e.as<Call>();
+ const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
std::ostringstream intrinsic_name;
intrinsic_name << "__nv_" << call->name;
if (call->dtype.bits() == 32) intrinsic_name << "f";
- *rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
- Call::PureExtern);
+ *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
+ CallNode::PureExtern);
}
namespace llvm {
inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
Expr e = args[0];
using namespace ir;
- const Call* call = e.as<Call>();
+ const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
std::ostringstream intrinsic_name;
intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
- *rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
- Call::PureExtern);
+ *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
+ CallNode::PureExtern);
}
namespace llvm {
return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
}
-spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
- const std::string& sync = op->args[0].as<StringImm>()->value;
+spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
+ const std::string& sync = op->args[0].as<StringImmNode>()->value;
spirv::Value value;
if (sync == "warp") {
return value;
return value;
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) {
auto it = var_map_.find(op);
CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
return it->second;
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const IntImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) {
return builder_->UIntImm(builder_->GetSType(op->dtype), op->value);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) {
LOG(FATAL) << "StringImm is not supported in Device code";
return spirv::Value();
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) {
return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) {
return builder_->Add(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Sub* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) {
return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Mul* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) {
return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Div* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) {
return builder_->Div(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Mod* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) {
return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Min* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->Select(builder_->LT(a, b), a, b);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Max* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->Select(builder_->GT(a, b), a, b);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const LT* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LTNode* op) {
return builder_->LT(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const LE* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LENode* op) {
return builder_->LE(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const GT* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const GTNode* op) {
return builder_->GT(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const GE* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const GENode* op) {
return builder_->GE(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const EQ* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const EQNode* op) {
return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const NE* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const NENode* op) {
return builder_->NE(MakeValue(op->a), MakeValue(op->b));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const And* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const AndNode* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Or* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const OrNode* op) {
spirv::Value a = MakeValue(op->a);
spirv::Value b = MakeValue(op->b);
return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Not* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) {
spirv::Value a = MakeValue(op->a);
return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
return builder_->Select(MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic("spirv_glsl450")) {
CHECK_GE(op->args.size(), 2U);
- uint32_t inst_id = op->args[0].as<UIntImm>()->value;
+ uint32_t inst_id = op->args[0].as<UIntImmNode>()->value;
std::vector<spirv::Value> values;
for (size_t i = 1; i < op->args.size(); ++i) {
values.push_back(MakeValue(op->args[i]));
}
return builder_->CallGLSL450(
builder_->GetSType(op->dtype), inst_id, values);
- } else if (op->is_intrinsic(Call::bitwise_and)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
- } else if (op->is_intrinsic(Call::bitwise_xor)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
- } else if (op->is_intrinsic(Call::bitwise_or)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_or)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
- } else if (op->is_intrinsic(Call::bitwise_not)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
spirv::Value a = MakeValue(op->args[0]);
return builder_->MakeValue(spv::OpNot, a.stype, a);
- } else if (op->is_intrinsic(Call::shift_left)) {
+ } else if (op->is_intrinsic(CallNode::shift_left)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
- } else if (op->is_intrinsic(Call::shift_right)) {
+ } else if (op->is_intrinsic(CallNode::shift_right)) {
CHECK_EQ(op->args.size(), 2U);
spirv::Value a = MakeValue(op->args[0]);
spirv::Value b = MakeValue(op->args[1]);
} else {
return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
}
- } else if (op->is_intrinsic(Call::reinterpret)) {
+ } else if (op->is_intrinsic(CallNode::reinterpret)) {
return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
} else {
- if (op->call_type == Call::Intrinsic ||
- op->call_type == Call::PureIntrinsic) {
+ if (op->call_type == CallNode::Intrinsic ||
+ op->call_type == CallNode::PureIntrinsic) {
LOG(FATAL) << "Unresolved intrinsic " << op->name
<< " with return type " << op->dtype;
- } else if (op->call_type == Call::Extern ||
- op->call_type == Call::PureExtern) {
+ } else if (op->call_type == CallNode::Extern ||
+ op->call_type == CallNode::PureExtern) {
LOG(FATAL) << "Unresolved extern " << op->name
<< " with return type " << op->dtype;
} else {
}
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
std::vector<spirv::Value> values;
spirv::Value base = MakeValue(op->base);
for (int i = 0; i < op->lanes; ++i) {
return builder_->Concat(values);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Broadcast* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
std::vector<spirv::Value> values;
spirv::Value v = MakeValue(op->value);
for (int i = 0; i < op->lanes; i++) {
return builder_->Concat(values);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
CHECK(is_one(op->predicate));
auto it = storage_info_.find(op->buffer_var.get());
CHECK(it != storage_info_.end());
this->Scalarize(op->index, f);
return builder_->Concat(values);
} else {
- if (const Ramp* ramp = op->index.as<Ramp>()) {
+ if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->dtype.lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
void CodeGenSPIRV::Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f) {
- if (const Ramp* ramp = e.as<Ramp>()) {
+ if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
Expr offset = ramp->base + ramp->stride * i;
f(i, MakeValue(offset));
}
}
-void CodeGenSPIRV::VisitStmt_(const Store* op) {
+void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
CHECK(is_one(op->predicate));
auto it = storage_info_.find(op->buffer_var.get());
CHECK(it != storage_info_.end());
};
this->Scalarize(op->index, f);
} else {
- if (const Ramp* ramp = op->index.as<Ramp>()) {
+ if (const RampNode* ramp = op->index.as<RampNode>()) {
if (is_one(ramp->stride)) {
CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
arith::ModularSet me = analyzer_->modular_set(ramp->base);
}
}
-void CodeGenSPIRV::VisitStmt_(const For* op) {
+void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
spirv::Value init_value = MakeValue(op->min);
builder_->StartLabel(merge_label);
}
-void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) {
+void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
spirv::Value cond = MakeValue(op->condition);
spirv::Label then_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
builder_->StartLabel(merge_label);
}
-void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
+void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
CHECK(!op->new_expr.defined());
CHECK(!op->dtype.is_handle());
this->VisitStmt(op->body);
}
-void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
+void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
}
}
} else if (op->attr_key == ir::attr::storage_scope) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
storage_info_[v].scope =
- runtime::StorageScope::make(op->value.as<StringImm>()->value);
+ runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
} else if (op->attr_key == ir::attr::volatile_scope) {
- const Variable* v = op->node.as<Variable>();
+ const VarNode* v = op->node.as<VarNode>();
CHECK(v);
storage_info_[v].is_volatile = true;
}
this->VisitStmt(op->body);
}
-void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
+void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body);
}
-void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
+void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
CHECK(!var_map_.count(op->var.get()));
CHECK(!op->var.dtype().is_handle());
var_map_[op->var.get()] = MakeValue(op->value);
}
}
-void CodeGenSPIRV::VisitStmt_(const Evaluate* op) {
+void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
MakeValue(op->value);
}
-void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenSPIRV::VisitStmt_(const ProducerConsumerNode* op) {
this->VisitStmt(op->body);
}
return VisitExpr(e);
}
// override codegen
- spirv::Value VisitExpr_(const Variable* op) override;
- spirv::Value VisitExpr_(const Cast* op) override;
- spirv::Value VisitExpr_(const IntImm* op) override;
- spirv::Value VisitExpr_(const UIntImm* op) override;
- spirv::Value VisitExpr_(const FloatImm* op) override;
- spirv::Value VisitExpr_(const StringImm* op) override;
- spirv::Value VisitExpr_(const Add* op) override;
- spirv::Value VisitExpr_(const Sub* op) override;
- spirv::Value VisitExpr_(const Mul* op) override;
- spirv::Value VisitExpr_(const Div* op) override;
- spirv::Value VisitExpr_(const Mod* op) override;
- spirv::Value VisitExpr_(const Min* op) override;
- spirv::Value VisitExpr_(const Max* op) override;
- spirv::Value VisitExpr_(const LT* op) override;
- spirv::Value VisitExpr_(const LE* op) override;
- spirv::Value VisitExpr_(const GT* op) override;
- spirv::Value VisitExpr_(const GE* op) override;
- spirv::Value VisitExpr_(const EQ* op) override;
- spirv::Value VisitExpr_(const NE* op) override;
- spirv::Value VisitExpr_(const And* op) override;
- spirv::Value VisitExpr_(const Or* op) override;
- spirv::Value VisitExpr_(const Not* op) override;
- spirv::Value VisitExpr_(const Select* op) override;
- spirv::Value VisitExpr_(const Let* op) override;
- spirv::Value VisitExpr_(const Call* op) override;
- spirv::Value VisitExpr_(const Ramp* op) override;
- spirv::Value VisitExpr_(const Broadcast* op) override;
- spirv::Value VisitExpr_(const Load* op) override;
+ spirv::Value VisitExpr_(const VarNode* op) override;
+ spirv::Value VisitExpr_(const CastNode* op) override;
+ spirv::Value VisitExpr_(const IntImmNode* op) override;
+ spirv::Value VisitExpr_(const UIntImmNode* op) override;
+ spirv::Value VisitExpr_(const FloatImmNode* op) override;
+ spirv::Value VisitExpr_(const StringImmNode* op) override;
+ spirv::Value VisitExpr_(const AddNode* op) override;
+ spirv::Value VisitExpr_(const SubNode* op) override;
+ spirv::Value VisitExpr_(const MulNode* op) override;
+ spirv::Value VisitExpr_(const DivNode* op) override;
+ spirv::Value VisitExpr_(const ModNode* op) override;
+ spirv::Value VisitExpr_(const MinNode* op) override;
+ spirv::Value VisitExpr_(const MaxNode* op) override;
+ spirv::Value VisitExpr_(const LTNode* op) override;
+ spirv::Value VisitExpr_(const LENode* op) override;
+ spirv::Value VisitExpr_(const GTNode* op) override;
+ spirv::Value VisitExpr_(const GENode* op) override;
+ spirv::Value VisitExpr_(const EQNode* op) override;
+ spirv::Value VisitExpr_(const NENode* op) override;
+ spirv::Value VisitExpr_(const AndNode* op) override;
+ spirv::Value VisitExpr_(const OrNode* op) override;
+ spirv::Value VisitExpr_(const NotNode* op) override;
+ spirv::Value VisitExpr_(const SelectNode* op) override;
+ spirv::Value VisitExpr_(const LetNode* op) override;
+ spirv::Value VisitExpr_(const CallNode* op) override;
+ spirv::Value VisitExpr_(const RampNode* op) override;
+ spirv::Value VisitExpr_(const BroadcastNode* op) override;
+ spirv::Value VisitExpr_(const LoadNode* op) override;
// stmt
- void VisitStmt_(const Store* op) override;
- void VisitStmt_(const For* op) override;
- void VisitStmt_(const IfThenElse* op) override;
- void VisitStmt_(const Allocate* op) override;
- void VisitStmt_(const AttrStmt* op) override;
- void VisitStmt_(const AssertStmt* op) override;
- void VisitStmt_(const LetStmt* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const IfThenElseNode* op) override;
+ void VisitStmt_(const AllocateNode* op) override;
+ void VisitStmt_(const AttrStmtNode* op) override;
+ void VisitStmt_(const AssertStmtNode* op) override;
+ void VisitStmt_(const LetStmtNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
- void VisitStmt_(const Evaluate* op) override;
- void VisitStmt_(const ProducerConsumer* op) override;
+ void VisitStmt_(const EvaluateNode* op) override;
+ void VisitStmt_(const ProducerConsumerNode* op) override;
protected:
/*! \brief The storage information */
void InitFuncState();
// Get the thread index
spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
- spirv::Value CreateStorageSync(const Call* op);
+ spirv::Value CreateStorageSync(const CallNode* op);
void Scalarize(const Expr& e,
std::function<void(int i, spirv::Value v)> f);
// The builder
// Likely branch
uint32_t weight_likely_branch_{128};
// the storage scope of allocation
- std::unordered_map<const Variable*, StorageInfo> storage_info_;
+ std::unordered_map<const VarNode*, StorageInfo> storage_info_;
// The definition of local variable.
- std::unordered_map<const Variable*, spirv::Value> var_map_;
+ std::unordered_map<const VarNode*, spirv::Value> var_map_;
// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;
};
template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
Expr e = targs[0];
- const ir::Call* call = e.as<ir::Call>();
+ const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
Array<Expr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
+ cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
for (Expr arg : call->args) {
cargs.push_back(arg);
}
- *rv = ir::Call::make(
- call->dtype, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
+ *rv = ir::CallNode::make(
+ call->dtype, "spirv_glsl450", cargs, ir::CallNode::PureIntrinsic);
}
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
return sid;
}
-int CodeGenStackVM::AllocVarID(const Variable* v) {
+int CodeGenStackVM::AllocVarID(const VarNode* v) {
CHECK(!var_idmap_.count(v));
int vid = static_cast<int>(vm_.heap_size);
CHECK_EQ(vm_.heap_size, var_idmap_.size());
return vid;
}
-int CodeGenStackVM::GetVarID(const Variable* v) const {
+int CodeGenStackVM::GetVarID(const VarNode* v) const {
auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end())
<< "Find undefined Variable " << v->name_hint;
return it->second;
}
-void CodeGenStackVM::VisitExpr_(const Load* op) {
+void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetLoad(op->dtype);
- if (const IntImm* index = op->index.as<IntImm>()) {
+ if (const IntImmNode* index = op->index.as<IntImmNode>()) {
this->PushOp(code, index->value);
} else {
this->Push(op->index);
}
}
-void CodeGenStackVM::VisitStmt_(const Store* op) {
+void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
- if (const IntImm* index = op->index.as<IntImm>()) {
+ if (const IntImmNode* index = op->index.as<IntImmNode>()) {
this->Push(op->value);
this->PushOp(code, index->value);
} else {
}
}
-void CodeGenStackVM::VisitStmt_(const Allocate* op) {
+void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
CHECK(!is_zero(op->condition));
int vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
}
}
-void CodeGenStackVM::VisitExpr_(const Call* op) {
+void CodeGenStackVM::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const Load *l = op->args[0].as<Load>();
+ const LoadNode *l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
this->Push(l->index);
this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
- } else if (op->is_intrinsic(Call::reinterpret)) {
+ } else if (op->is_intrinsic(CallNode::reinterpret)) {
this->Push(op->args[0]);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U);
- int kind = op->args[2].as<IntImm>()->value;
+ int kind = op->args[2].as<IntImmNode>()->value;
this->Push(op->args[0]);
- const IntImm* index = op->args[1].as<IntImm>();
+ const IntImmNode* index = op->args[1].as<IntImmNode>();
CHECK(index != nullptr);
StackVM::Code code;
code.op_code = StackVM::TVM_STRUCT_GET;
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
CHECK_GE(op->args.size(), 5U);
- const StringImm* s = op->args[0].as<StringImm>();
+ const StringImmNode* s = op->args[0].as<StringImmNode>();
CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
this->Push(op->args[1]);
this->Push(op->args[2]);
- int begin = op->args[3].as<IntImm>()->value;
- int end = op->args[4].as<IntImm>()->value;
+ int begin = op->args[3].as<IntImmNode>()->value;
+ int end = op->args[4].as<IntImmNode>()->value;
// find the fuction id.
const std::string& func_name = s->value;
auto it = extern_fun_idmap_.find(func_name);
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
- const std::string& type = op->args[0].as<StringImm>()->value;
- const IntImm* num = op->args[1].as<IntImm>();
+ const std::string& type = op->args[0].as<StringImmNode>()->value;
+ const IntImmNode* num = op->args[1].as<IntImmNode>();
CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
// static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
}
}
-void CodeGenStackVM::VisitExpr_(const StringImm* op) {
+void CodeGenStackVM::VisitExpr_(const StringImmNode* op) {
int sid = this->GetStrID(op->value);
this->PushOp(StackVM::PUSH_I64, sid);
}
-void CodeGenStackVM::VisitExpr_(const IntImm* op) {
+void CodeGenStackVM::VisitExpr_(const IntImmNode* op) {
CHECK(op->value >= std::numeric_limits<int>::min() &&
op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
-void CodeGenStackVM::VisitExpr_(const UIntImm* op) {
+void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) {
CHECK(op->value <= std::numeric_limits<int>::max())
<< "Int constant exceed bound";
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
-void CodeGenStackVM::VisitExpr_(const FloatImm* op) {
+void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
LOG(FATAL) << "Float Imm is not supported";
}
-void CodeGenStackVM::VisitExpr_(const Variable* op) {
+void CodeGenStackVM::VisitExpr_(const VarNode* op) {
int vid = this->GetVarID(op);
this->PushOp(StackVM::LOAD_HEAP, vid);
}
-void CodeGenStackVM::VisitExpr_(const Cast* op) {
+void CodeGenStackVM::VisitExpr_(const CastNode* op) {
this->Push(op->value);
PushCast(op->dtype, op->value.dtype());
}
-void CodeGenStackVM::VisitExpr_(const Add* op) {
+void CodeGenStackVM::VisitExpr_(const AddNode* op) {
PushBinary(StackVM::ADD_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const Sub* op) {
+void CodeGenStackVM::VisitExpr_(const SubNode* op) {
PushBinary(StackVM::SUB_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const Mul* op) {
+void CodeGenStackVM::VisitExpr_(const MulNode* op) {
PushBinary(StackVM::MUL_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const Div* op) {
+void CodeGenStackVM::VisitExpr_(const DivNode* op) {
PushBinary(StackVM::DIV_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const Mod* op) {
+void CodeGenStackVM::VisitExpr_(const ModNode* op) {
PushBinary(StackVM::MOD_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const Min* op) {
+void CodeGenStackVM::VisitExpr_(const MinNode* op) {
this->Push(op->a);
this->Push(op->b);
this->PushOp(StackVM::PUSH_VALUE, -1);
this->PushOp(StackVM::SELECT);
}
-void CodeGenStackVM::VisitExpr_(const Max* op) {
+void CodeGenStackVM::VisitExpr_(const MaxNode* op) {
this->Push(op->a);
this->Push(op->b);
this->PushOp(StackVM::PUSH_VALUE, 0);
this->PushOp(StackVM::SELECT);
}
-void CodeGenStackVM::VisitExpr_(const EQ* op) {
+void CodeGenStackVM::VisitExpr_(const EQNode* op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const LE* op) {
+void CodeGenStackVM::VisitExpr_(const LENode* op) {
PushBinary(StackVM::LE_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const NE* op) {
+void CodeGenStackVM::VisitExpr_(const NENode* op) {
PushBinary(StackVM::EQ_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
-void CodeGenStackVM::VisitExpr_(const LT* op) {
+void CodeGenStackVM::VisitExpr_(const LTNode* op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
}
-void CodeGenStackVM::VisitExpr_(const GE* op) {
+void CodeGenStackVM::VisitExpr_(const GENode* op) {
PushBinary(StackVM::LT_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
-void CodeGenStackVM::VisitExpr_(const GT* op) {
+void CodeGenStackVM::VisitExpr_(const GTNode* op) {
PushBinary(StackVM::LE_I64, op->a, op->b);
this->PushOp(StackVM::NOT);
}
-void CodeGenStackVM::VisitExpr_(const And* op) {
+void CodeGenStackVM::VisitExpr_(const AndNode* op) {
this->Push(op->a);
int64_t pc_jump = this->GetPC();
int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->SetOperand(opr_index, diff);
}
-void CodeGenStackVM::VisitExpr_(const Or* op) {
+void CodeGenStackVM::VisitExpr_(const OrNode* op) {
this->Push(op->a);
int64_t pc_jump = this->GetPC();
int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
this->SetOperand(opr_index, diff);
}
-void CodeGenStackVM::VisitExpr_(const Not* op) {
+void CodeGenStackVM::VisitExpr_(const NotNode* op) {
this->Push(op->a);
this->PushOp(StackVM::NOT);
}
-void CodeGenStackVM::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenStackVM::VisitStmt_(const ProducerConsumerNode* op) {
this->Push(op->body);
}
-void CodeGenStackVM::VisitStmt_(const For* op) {
+void CodeGenStackVM::VisitStmt_(const ForNode* op) {
CHECK(is_zero(op->min));
int vid = this->AllocVarID(op->loop_var.get());
this->PushOp(StackVM::PUSH_I64, 0);
}
}
-void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
+void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
if (is_const(ev->value)) return;
- const Call* op = ev->value.as<Call>();
+ const CallNode* op = ev->value.as<CallNode>();
if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
this->Push(op->args[0]);
this->Push(op->args[3]);
- const IntImm* index = op->args[1].as<IntImm>();
+ const IntImmNode* index = op->args[1].as<IntImmNode>();
CHECK(index != nullptr);
StackVM::Code code;
code.op_code = StackVM::TVM_STRUCT_SET;
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
- code.v_int = op->args[2].as<IntImm>()->value;
+ code.v_int = op->args[2].as<IntImmNode>()->value;
vm_.code.push_back(code);
} else {
this->Push(ev->value);
}
}
-void CodeGenStackVM::VisitStmt_(const IfThenElse* op) {
+void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
this->Push(op->condition);
int64_t label_ejump = this->GetPC();
int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
}
}
-void CodeGenStackVM::VisitStmt_(const LetStmt* op) {
+void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
this->Push(op->body);
}
-void CodeGenStackVM::VisitExpr_(const Ramp* op) {
+void CodeGenStackVM::VisitExpr_(const RampNode* op) {
LOG(FATAL) << "Ramp is not supported";
}
-void CodeGenStackVM::VisitExpr_(const Broadcast* op) {
+void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) {
LOG(FATAL) << "Broadcast is not supported";
}
-void CodeGenStackVM::VisitExpr_(const Select* op) {
+void CodeGenStackVM::VisitExpr_(const SelectNode* op) {
this->Push(op->true_value);
this->Push(op->false_value);
this->Push(op->condition);
this->PushOp(StackVM::SELECT);
}
-void CodeGenStackVM::VisitStmt_(const AssertStmt* op) {
- if (const auto* str = op->message.as<StringImm>()) {
+void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) {
+ if (const auto* str = op->message.as<StringImmNode>()) {
int sid = this->GetStrID(str->value);
this->Push(op->condition);
this->PushOp(StackVM::ASSERT, sid);
this->Push(op->body);
}
-void CodeGenStackVM::VisitStmt_(const AttrStmt* op) {
+void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) {
this->Push(op->body);
}
-void CodeGenStackVM::VisitExpr_(const Let* op) {
+void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->value);
int64_t vid = this->AllocVarID(op->var.get());
this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
* \param v The variable.
* \return the heap index of the var.
*/
- int AllocVarID(const Variable* v);
+ int AllocVarID(const VarNode* v);
/*!
* \brief Get a variable name.
* \param v The variable.
* \return the heap index of the var.
*/
- int GetVarID(const Variable* v) const;
+ int GetVarID(const VarNode* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
const Expr& a,
void PushCast(DataType dst, DataType src);
// overloadable functions
// expression
- void VisitExpr_(const Variable* op) final;
- void VisitExpr_(const Load* op) final;
- void VisitExpr_(const Let* op) final;
- void VisitExpr_(const Call* op) final;
- void VisitExpr_(const Add* op) final;
- void VisitExpr_(const Sub* op) final;
- void VisitExpr_(const Mul* op) final;
- void VisitExpr_(const Div* op) final;
- void VisitExpr_(const Mod* op) final;
- void VisitExpr_(const Min* op) final;
- void VisitExpr_(const Max* op) final;
- void VisitExpr_(const EQ* op) final;
- void VisitExpr_(const NE* op) final;
- void VisitExpr_(const LT* op) final;
- void VisitExpr_(const LE* op) final;
- void VisitExpr_(const GT* op) final;
- void VisitExpr_(const GE* op) final;
- void VisitExpr_(const And* op) final;
- void VisitExpr_(const Or* op) final;
- void VisitExpr_(const Cast* op) final;
- void VisitExpr_(const Not* op) final;
- void VisitExpr_(const Select* op) final;
- void VisitExpr_(const Ramp* op) final;
- void VisitExpr_(const Broadcast* op) final;
- void VisitExpr_(const IntImm* op) final;
- void VisitExpr_(const UIntImm* op) final;
- void VisitExpr_(const FloatImm* op) final;
- void VisitExpr_(const StringImm* op) final;
+ void VisitExpr_(const VarNode* op) final;
+ void VisitExpr_(const LoadNode* op) final;
+ void VisitExpr_(const LetNode* op) final;
+ void VisitExpr_(const CallNode* op) final;
+ void VisitExpr_(const AddNode* op) final;
+ void VisitExpr_(const SubNode* op) final;
+ void VisitExpr_(const MulNode* op) final;
+ void VisitExpr_(const DivNode* op) final;
+ void VisitExpr_(const ModNode* op) final;
+ void VisitExpr_(const MinNode* op) final;
+ void VisitExpr_(const MaxNode* op) final;
+ void VisitExpr_(const EQNode* op) final;
+ void VisitExpr_(const NENode* op) final;
+ void VisitExpr_(const LTNode* op) final;
+ void VisitExpr_(const LENode* op) final;
+ void VisitExpr_(const GTNode* op) final;
+ void VisitExpr_(const GENode* op) final;
+ void VisitExpr_(const AndNode* op) final;
+ void VisitExpr_(const OrNode* op) final;
+ void VisitExpr_(const CastNode* op) final;
+ void VisitExpr_(const NotNode* op) final;
+ void VisitExpr_(const SelectNode* op) final;
+ void VisitExpr_(const RampNode* op) final;
+ void VisitExpr_(const BroadcastNode* op) final;
+ void VisitExpr_(const IntImmNode* op) final;
+ void VisitExpr_(const UIntImmNode* op) final;
+ void VisitExpr_(const FloatImmNode* op) final;
+ void VisitExpr_(const StringImmNode* op) final;
// statment
- void VisitStmt_(const LetStmt* op) final;
- void VisitStmt_(const Store* op) final;
- void VisitStmt_(const For* op) final;
- void VisitStmt_(const IfThenElse* op) final;
- void VisitStmt_(const Allocate* op) final;
- void VisitStmt_(const AttrStmt* op) final;
- void VisitStmt_(const AssertStmt* op) final;
- void VisitStmt_(const Evaluate* op) final;
+ void VisitStmt_(const LetStmtNode* op) final;
+ void VisitStmt_(const StoreNode* op) final;
+ void VisitStmt_(const ForNode* op) final;
+ void VisitStmt_(const IfThenElseNode* op) final;
+ void VisitStmt_(const AllocateNode* op) final;
+ void VisitStmt_(const AttrStmtNode* op) final;
+ void VisitStmt_(const AssertStmtNode* op) final;
+ void VisitStmt_(const EvaluateNode* op) final;
void VisitStmt_(const SeqStmtNode* op) final;
- void VisitStmt_(const ProducerConsumer* op) final;
+ void VisitStmt_(const ProducerConsumerNode* op) final;
private:
bool debug_{false};
/*! \brief The vm to be generated */
StackVM vm_;
/*! \brief id of each variable */
- std::unordered_map<const Variable*, int> var_idmap_;
+ std::unordered_map<const VarNode*, int> var_idmap_;
/*! \brief id of each string */
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each global function */
os << t.bits();
}
-void CodeGenHybrid::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*)
os << op->value;
}
-void CodeGenHybrid::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << op->value << ")";
}
-void CodeGenHybrid::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << std::setprecision(20) << op->value << ")";
}
-void CodeGenHybrid::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
os << "'" << op->value << "'";
}
}
}
-inline void PrintBinaryIntrinsitc(const Call* op,
+inline void PrintBinaryIntrinsitc(const CallNode* op,
const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenHybrid* p) {
os << ')';
}
-void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype == op->value.dtype()) {
PrintExpr(op->value, stream);
} else {
}
}
-void CodeGenHybrid::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*)
os << GetVarID(op);
}
-void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
-void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int())
PrintBinaryExpr(op, "//", os, this);
else
PrintBinaryExpr(op, "/", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
-void CodeGenHybrid::VisitExpr_(const FloorMod* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, this);
}
-void CodeGenHybrid::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
-void CodeGenHybrid::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
-void CodeGenHybrid::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
-void CodeGenHybrid::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
-void CodeGenHybrid::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
-void CodeGenHybrid::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
-void CodeGenHybrid::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
-void CodeGenHybrid::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*)
os << "not ";
PrintExpr(op->a, os);
}
-void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
- if (op->call_type == Call::Halide) {
+void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
+ if (op->call_type == CallNode::Halide) {
os << GetTensorID(op->func, op->value_index);
os << "[";
for (size_t i = 0; i < op->args.size(); ++i) {
os << idx.str();
}
os << "]";
- } else if (op->is_intrinsic(Call::bitwise_and)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_and)) {
PrintBinaryIntrinsitc(op, "&", os, this);
- } else if (op->is_intrinsic(Call::bitwise_xor)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
PrintBinaryIntrinsitc(op, "^", os, this);
- } else if (op->is_intrinsic(Call::bitwise_or)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_or)) {
PrintBinaryIntrinsitc(op, "|", os, this);
- } else if (op->is_intrinsic(Call::shift_left)) {
+ } else if (op->is_intrinsic(CallNode::shift_left)) {
PrintBinaryIntrinsitc(op, "<<", os, this);
- } else if (op->is_intrinsic(Call::shift_right)) {
+ } else if (op->is_intrinsic(CallNode::shift_right)) {
PrintBinaryIntrinsitc(op, ">>", os, this);
- } else if (op->is_intrinsic(Call::bitwise_not)) {
+ } else if (op->is_intrinsic(CallNode::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
os << "(~";
PrintExpr(op->args[0], os);
}
}
-void CodeGenHybrid::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Load(s)!";
}
-void CodeGenHybrid::VisitStmt_(const Store* op) {
+void CodeGenHybrid::VisitStmt_(const StoreNode* op) {
LOG(FATAL) << "Phase 0 has no Store(s)!";
}
-void CodeGenHybrid::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Phase 0 has no Let(s)!";
}
-void CodeGenHybrid::VisitStmt_(const Allocate* op) {
+void CodeGenHybrid::VisitStmt_(const AllocateNode* op) {
LOG(FATAL) << "Phase 0 has no Allocate(s)!";
}
-void CodeGenHybrid::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp to be supported yet";
}
-void CodeGenHybrid::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
-void CodeGenHybrid::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*)
PrintExpr(op->true_value, os);
os << " if ";
PrintExpr(op->condition, os);
os << "\n";
}
-void CodeGenHybrid::VisitStmt_(const LetStmt* op) {
+void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) {
std::string value = PrintExpr(op->value);
stream << GetVarID(op->var.get()) << " = " << value << ";\n";
PrintStmt(op->body);
}
-void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
+void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == ir::attr::thread_extent) {
auto iter_var = op->node.as<IterVarNode>();
CHECK(iter_var);
indent_ -= tab_;
} else if (op->attr_key == ir::attr::realize_scope) {
auto v = Downcast<FunctionRef>(op->node);
- alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+ alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
PrintStmt(op->body);
} else {
// For now we ignore the unsupported AttrStmt
}
}
-void CodeGenHybrid::VisitStmt_(const Realize* op) {
+void CodeGenHybrid::VisitStmt_(const RealizeNode* op) {
CHECK(alloc_storage_scope_.count(op->func));
if (!alloc_storage_scope_[op->func].empty()) {
PrintIndent();
PrintStmt(op->body);
}
-void CodeGenHybrid::VisitStmt_(const AssertStmt* op) {
+void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) {
PrintIndent();
stream << "assert ";
PrintExpr(op->condition, stream);
PrintStmt(op->body);
}
-void CodeGenHybrid::VisitStmt_(const Provide* op) {
+void CodeGenHybrid::VisitStmt_(const ProvideNode* op) {
PrintIndent();
stream << GetTensorID(op->func, op->value_index);
stream << "[";
stream << "\n";
}
-void CodeGenHybrid::VisitStmt_(const For* op) {
+void CodeGenHybrid::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = GetVarID(op->loop_var.get());
bool is_noop(const Stmt &stmt) {
if (!stmt.defined())
return true;
- if (auto eval = stmt.as<Evaluate>())
+ if (auto eval = stmt.as<EvaluateNode>())
return is_const(eval->value);
return false;
}
-void CodeGenHybrid::VisitStmt_(const IfThenElse* op) {
+void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if " << cond << ":\n";
}
}
-void CodeGenHybrid::VisitStmt_(const Evaluate* op) {
+void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty())
stream << str << "\n";
}
-void CodeGenHybrid::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
PrintStmt(op->body);
}
stream << std::string(indent_, ' ');
}
-std::string CodeGenHybrid::GetVarID(const Variable *v) {
+std::string CodeGenHybrid::GetVarID(const VarNode *v) {
if (binds_.count(v))
return binds_[v];
auto key = std::make_pair(static_cast<const Object*>(v), 0);
if (auto tensor = inputs[i].as<TensorNode>()) {
stream << GetTensorID(tensor->op, tensor->value_index);
} else {
- auto var = inputs[i].as<Variable>();
+ auto var = inputs[i].as<VarNode>();
CHECK(var) << "Input should either be a tensor or a variable!";
stream << GetVarID(var);
}
return os.str();
}
// expression
- void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
+ void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
- void VisitStmt_(const LetStmt* op) override;
- void VisitStmt_(const Store* op) override;
- void VisitStmt_(const Provide* op) override;
- void VisitStmt_(const For* op) override;
- void VisitStmt_(const IfThenElse* op) override;
- void VisitStmt_(const Allocate* op) override;
- void VisitStmt_(const Realize* op) override;
- void VisitStmt_(const AttrStmt* op) override;
- void VisitStmt_(const AssertStmt* op) override;
- void VisitStmt_(const Evaluate* op) override;
+ void VisitStmt_(const LetStmtNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitStmt_(const ProvideNode* op) override;
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const IfThenElseNode* op) override;
+ void VisitStmt_(const AllocateNode* op) override;
+ void VisitStmt_(const RealizeNode* op) override;
+ void VisitStmt_(const AttrStmtNode* op) override;
+ void VisitStmt_(const AssertStmtNode* op) override;
+ void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
- void VisitStmt_(const ProducerConsumer* op) override;
+ void VisitStmt_(const ProducerConsumerNode* op) override;
/*!
* \brief Print Type represetnation of type t.
* \param t The type representation.
* Values are the corresponding IDs.*/
std::map<std::pair<const Object *, int>, std::string> id_map_;
/*! \brief Variables (keys) binded to the threads (values). */
- std::map<const Variable *, std::string> binds_;
+ std::map<const VarNode *, std::string> binds_;
/*!
* \brief Find an unallocated name for the given prefix.
* \param prefix The given prefix.
* \brief Get or allocate the ID for the given variable.
* \param v The given variable.
*/
- std::string GetVarID(const Variable *v);
+ std::string GetVarID(const VarNode *v);
/*!
* \brief Get or allocate the ID for the given tensor.
* \param func The tensor to allocate a name.
virtual R VisitAttrDefault_(const Object* node, Args... args) = 0;
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
// deep comparison of symbolic integer expressions.
- virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+ virtual R VisitAttr_(const ir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
private:
// initialize the vtable.
// Set dispatch
ATTR_FUNCTOR_DISPATCH(StrMapNode);
ATTR_FUNCTOR_DISPATCH(ArrayNode);
- ATTR_FUNCTOR_DISPATCH(IntImm);
- ATTR_FUNCTOR_DISPATCH(UIntImm);
- ATTR_FUNCTOR_DISPATCH(FloatImm);
- ATTR_FUNCTOR_DISPATCH(StringImm);
- ATTR_FUNCTOR_DISPATCH(Variable);
- ATTR_FUNCTOR_DISPATCH(Add);
- ATTR_FUNCTOR_DISPATCH(Sub);
- ATTR_FUNCTOR_DISPATCH(Mul);
- ATTR_FUNCTOR_DISPATCH(Div);
- ATTR_FUNCTOR_DISPATCH(Mod);
- ATTR_FUNCTOR_DISPATCH(FloorDiv);
- ATTR_FUNCTOR_DISPATCH(FloorMod);
- ATTR_FUNCTOR_DISPATCH(Min);
- ATTR_FUNCTOR_DISPATCH(Max);
- ATTR_FUNCTOR_DISPATCH(GE);
- ATTR_FUNCTOR_DISPATCH(GT);
- ATTR_FUNCTOR_DISPATCH(LE);
- ATTR_FUNCTOR_DISPATCH(LT);
- ATTR_FUNCTOR_DISPATCH(EQ);
- ATTR_FUNCTOR_DISPATCH(NE);
- ATTR_FUNCTOR_DISPATCH(And);
- ATTR_FUNCTOR_DISPATCH(Or);
- ATTR_FUNCTOR_DISPATCH(Not);
- ATTR_FUNCTOR_DISPATCH(Cast);
- ATTR_FUNCTOR_DISPATCH(Call);
- ATTR_FUNCTOR_DISPATCH(Select);
+ ATTR_FUNCTOR_DISPATCH(IntImmNode);
+ ATTR_FUNCTOR_DISPATCH(UIntImmNode);
+ ATTR_FUNCTOR_DISPATCH(FloatImmNode);
+ ATTR_FUNCTOR_DISPATCH(StringImmNode);
+ ATTR_FUNCTOR_DISPATCH(VarNode);
+ ATTR_FUNCTOR_DISPATCH(AddNode);
+ ATTR_FUNCTOR_DISPATCH(SubNode);
+ ATTR_FUNCTOR_DISPATCH(MulNode);
+ ATTR_FUNCTOR_DISPATCH(DivNode);
+ ATTR_FUNCTOR_DISPATCH(ModNode);
+ ATTR_FUNCTOR_DISPATCH(FloorDivNode);
+ ATTR_FUNCTOR_DISPATCH(FloorModNode);
+ ATTR_FUNCTOR_DISPATCH(MinNode);
+ ATTR_FUNCTOR_DISPATCH(MaxNode);
+ ATTR_FUNCTOR_DISPATCH(GENode);
+ ATTR_FUNCTOR_DISPATCH(GTNode);
+ ATTR_FUNCTOR_DISPATCH(LENode);
+ ATTR_FUNCTOR_DISPATCH(LTNode);
+ ATTR_FUNCTOR_DISPATCH(EQNode);
+ ATTR_FUNCTOR_DISPATCH(NENode);
+ ATTR_FUNCTOR_DISPATCH(AndNode);
+ ATTR_FUNCTOR_DISPATCH(OrNode);
+ ATTR_FUNCTOR_DISPATCH(NotNode);
+ ATTR_FUNCTOR_DISPATCH(CastNode);
+ ATTR_FUNCTOR_DISPATCH(CallNode);
+ ATTR_FUNCTOR_DISPATCH(SelectNode);
return vtable;
}
};
bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::IntImm* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Mul* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::LE* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::EQ* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::NE* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::SubNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::MulNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::DivNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::ModNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::FloorDivNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::FloorModNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::MinNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::MaxNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::GENode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::GTNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::LTNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::LENode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::EQNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::NENode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::AndNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::OrNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::NotNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::CastNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::CallNode* lhs, const ObjectRef& other) final;
+ bool VisitAttr_(const ir::SelectNode* lhs, const ObjectRef& other) final;
};
class AttrsHashHandler :
protected:
size_t VisitAttrDefault_(const Object* lhs) final;
- size_t VisitAttr_(const ir::IntImm* lhs) final;
- size_t VisitAttr_(const ir::UIntImm* lhs) final;
- size_t VisitAttr_(const ir::FloatImm* lhs) final;
- size_t VisitAttr_(const ir::StringImm* lhs) final;
+ size_t VisitAttr_(const ir::IntImmNode* lhs) final;
+ size_t VisitAttr_(const ir::UIntImmNode* lhs) final;
+ size_t VisitAttr_(const ir::FloatImmNode* lhs) final;
+ size_t VisitAttr_(const ir::StringImmNode* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
size_t VisitAttr_(const StrMapNode* lhs) final;
- size_t VisitAttr_(const ir::Add* op) final;
- size_t VisitAttr_(const ir::Sub* op) final;
- size_t VisitAttr_(const ir::Mul* op) final;
- size_t VisitAttr_(const ir::Div* op) final;
- size_t VisitAttr_(const ir::Mod* op) final;
- size_t VisitAttr_(const ir::FloorDiv* op) final;
- size_t VisitAttr_(const ir::FloorMod* op) final;
- size_t VisitAttr_(const ir::Min* op) final;
- size_t VisitAttr_(const ir::Max* op) final;
- size_t VisitAttr_(const ir::GE* op) final;
- size_t VisitAttr_(const ir::GT* op) final;
- size_t VisitAttr_(const ir::LE* op) final;
- size_t VisitAttr_(const ir::LT* op) final;
- size_t VisitAttr_(const ir::EQ* op) final;
- size_t VisitAttr_(const ir::NE* op) final;
- size_t VisitAttr_(const ir::And* op) final;
- size_t VisitAttr_(const ir::Or* op) final;
- size_t VisitAttr_(const ir::Not* op) final;
- size_t VisitAttr_(const ir::Cast* op) final;
- size_t VisitAttr_(const ir::Call* op) final;
- size_t VisitAttr_(const ir::Select* op) final;
+ size_t VisitAttr_(const ir::AddNode* op) final;
+ size_t VisitAttr_(const ir::SubNode* op) final;
+ size_t VisitAttr_(const ir::MulNode* op) final;
+ size_t VisitAttr_(const ir::DivNode* op) final;
+ size_t VisitAttr_(const ir::ModNode* op) final;
+ size_t VisitAttr_(const ir::FloorDivNode* op) final;
+ size_t VisitAttr_(const ir::FloorModNode* op) final;
+ size_t VisitAttr_(const ir::MinNode* op) final;
+ size_t VisitAttr_(const ir::MaxNode* op) final;
+ size_t VisitAttr_(const ir::GENode* op) final;
+ size_t VisitAttr_(const ir::GTNode* op) final;
+ size_t VisitAttr_(const ir::LENode* op) final;
+ size_t VisitAttr_(const ir::LTNode* op) final;
+ size_t VisitAttr_(const ir::EQNode* op) final;
+ size_t VisitAttr_(const ir::NENode* op) final;
+ size_t VisitAttr_(const ir::AndNode* op) final;
+ size_t VisitAttr_(const ir::OrNode* op) final;
+ size_t VisitAttr_(const ir::NotNode* op) final;
+ size_t VisitAttr_(const ir::CastNode* op) final;
+ size_t VisitAttr_(const ir::CallNode* op) final;
+ size_t VisitAttr_(const ir::SelectNode* op) final;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
return lhs == other.get();
}
-bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<IntImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value;
}
return false;
}
-bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<UIntImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<UIntImmNode>()) {
return lhs->value == rhs->value;
}
return false;
}
-bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<FloatImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
}
return false;
}
-bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<StringImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value;
}
return false;
} \
} \
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
-
-bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<Not>()) {
+TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
+
+bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<NotNode>()) {
return Equal(lhs->a, rhs->a);
} else {
return false;
}
}
-bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<Cast>()) {
+bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<CastNode>()) {
if (lhs->dtype != rhs->dtype) return false;
return Equal(lhs->value, rhs->value);
} else {
}
}
-bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<Call>()) {
+bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<CallNode>()) {
return
lhs->name == rhs->name &&
lhs->dtype == rhs->dtype &&
}
}
-bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<Select>()) {
+bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
+ if (const auto* rhs = other.as<SelectNode>()) {
return
Equal(lhs->condition, rhs->condition) &&
Equal(lhs->true_value, rhs->true_value) &&
}
}
-size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
return std::hash<int64_t>()(op->value);
}
-size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
return std::hash<uint64_t>()(op->value);
}
-size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
return std::hash<double>()(op->value);
}
-size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
return std::hash<std::string>()(op->value);
}
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
-TVM_DEFINE_ATTRS_BINOP_HASH(Add);
-TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
-TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
-TVM_DEFINE_ATTRS_BINOP_HASH(Div);
-TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
-TVM_DEFINE_ATTRS_BINOP_HASH(Max);
-TVM_DEFINE_ATTRS_BINOP_HASH(Min);
-TVM_DEFINE_ATTRS_BINOP_HASH(GE);
-TVM_DEFINE_ATTRS_BINOP_HASH(GT);
-TVM_DEFINE_ATTRS_BINOP_HASH(LE);
-TVM_DEFINE_ATTRS_BINOP_HASH(LT);
-TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
-TVM_DEFINE_ATTRS_BINOP_HASH(NE);
-TVM_DEFINE_ATTRS_BINOP_HASH(And);
-TVM_DEFINE_ATTRS_BINOP_HASH(Or);
-
-size_t AttrsHashHandler::VisitAttr_(const Not* op) {
- static size_t key = std::hash<std::string>()(Not::_type_key);
+TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
+TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
+TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
+TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
+
+size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
+ static size_t key = std::hash<std::string>()(NotNode::_type_key);
return Combine(key, Hash(op->a));
}
-size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
- static size_t key = std::hash<std::string>()(Cast::_type_key);
+size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
+ static size_t key = std::hash<std::string>()(CastNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->dtype));
return res;
}
-size_t AttrsHashHandler::VisitAttr_(const Call* op) {
- static size_t key = std::hash<std::string>()(Call::_type_key);
+size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
+ static size_t key = std::hash<std::string>()(CallNode::_type_key);
AttrsHash hasher;
size_t res = key;
res = Combine(res, hasher(op->name));
return res;
}
-size_t AttrsHashHandler::VisitAttr_(const Select* op) {
- static size_t key = std::hash<std::string>()(Select::_type_key);
+size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
+ static size_t key = std::hash<std::string>()(SelectNode::_type_key);
size_t res = key;
res = Combine(res, Hash(op->condition));
res = Combine(res, Hash(op->true_value));
namespace tvm {
// TODO(tqchen): change to floormod/div
-using IndexMod = ir::FloorMod;
-using IndexDiv = ir::FloorDiv;
+using IndexMod = ir::FloorModNode;
+using IndexDiv = ir::FloorDivNode;
Array<Expr> SimplifyArray(Array<Expr> array) {
for (size_t i = 0; i < array.size(); ++i) {
while (!split_buffer.empty()) {
const Expr* top_ele = split_buffer.top();
split_buffer.pop();
- auto expr_add_match = top_ele->as<Add>();
+ auto expr_add_match = top_ele->as<AddNode>();
if (expr_add_match) {
split_buffer.push(&expr_add_match->b);
split_buffer.push(&expr_add_match->a);
const Expr &mod_l_expr,
const Expr &mod_r_expr) {
using namespace ir;
- const Mul* mult_ptr = mult_expr.as<Mul>();
+ const MulNode* mult_ptr = mult_expr.as<MulNode>();
if (!mult_ptr) return std::make_pair(false, Expr());
Expr mult_outer = mult_ptr->b;
const Expr* inner = &(mult_ptr->a);
// 1. Calculate the outer multiplier
while (true) {
- mult_ptr = inner->as<Mul>();
+ mult_ptr = inner->as<MulNode>();
if (mult_ptr) {
inner = &(mult_ptr->a);
mult_outer = mult_ptr->b * mult_outer;
Expr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) {
auto inner_div_ptr = search_ptr->as<IndexDiv>();
- auto inner_mult_ptr = search_ptr->as<Mul>();
- auto inner_add_ptr = search_ptr->as<Add>();
+ auto inner_mult_ptr = search_ptr->as<MulNode>();
+ auto inner_add_ptr = search_ptr->as<AddNode>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
return std::make_pair(false, Expr());
} else if (inner_div_ptr) {
*has_mod = false;
for (const Expr* ele : eles) {
auto mod_ptr = ele->as<IndexMod>();
- auto mult_ptr = ele->as<Mul>();
+ auto mult_ptr = ele->as<MulNode>();
if (mod_ptr) {
*has_mod = true;
mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b)));
if (n->strides.size() == 0) {
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
- auto is_int = index[0].as<IntImm>();
+ auto is_int = index[0].as<IntImmNode>();
CHECK(is_int && is_int->value == 0);
base = base + index[0];
} else {
offset = offset * make_const(offset.dtype(), dtype.lanes());
}
if (dtype.lanes() != 1) {
- return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
+ return ir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
} else {
return offset;
}
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
if (dtype == DataType::Bool()) {
- return ir::Cast::make(
+ return ir::CastNode::make(
DataType::Bool(),
- ir::Load::make(
+ ir::LoadNode::make(
DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
const_true()));
} else {
- return ir::Load::make(
+ return ir::LoadNode::make(
dtype, n->data, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
<< "Cannot load " << dtype
<< " from buffer of " << n->dtype;
if (value.dtype() == DataType::Bool()) {
- return ir::Store::make(n->data,
- ir::Cast::make(DataType::Int(8), value),
+ return ir::StoreNode::make(n->data,
+ ir::CastNode::make(DataType::Int(8), value),
BufferOffset(n, begin, DataType::Int(8)),
const_true());
} else {
- return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
+ return ir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
const_true(dtype.lanes()));
}
}
int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
- extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
+ extent = arith::ComputeReduce<ir::MulNode>(self->shape, Expr()) - offset;
}
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
Array<Expr> acc_args{
e_dtype, self->data, elem_offset,
extent, make_const(DataType::Int(32), access_mask)};
- return ir::Call::make(
- ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
+ return ir::CallNode::make(
+ ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::CallNode::Intrinsic);
}
Buffer BufferNode::make(Var data,
node->axes = axes;
std::ostringstream repr;
for (const IterVar& axis : axes) {
- if (const auto* factor = axis->dom->extent.as<IntImm>()) {
+ if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
CHECK_GT(factor->value, 0);
repr << factor->value;
}
if (!this->defined()) return -1;
for (const IterVar& itvar : operator->()->axes) {
if (sub == LayoutAxis::Get(itvar)) {
- const auto* factor = itvar->dom->extent.as<IntImm>();
+ const auto* factor = itvar->dom->extent.as<IntImmNode>();
CHECK(factor);
return factor->value;
}
const Array<IterVar>& src_axis,
const Array<Expr>& transform_rule) {
Array<Expr> result;
- std::unordered_map<const Variable*, Expr> bind_map;
+ std::unordered_map<const VarNode*, Expr> bind_map;
for (size_t i = 0; i < src_index.size(); ++i) {
bind_map[src_axis[i]->var.get()] = src_index[i];
}
// for major-axis, bind the corresponding size
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
- std::unordered_map<const Variable*, Expr> bind_map;
+ std::unordered_map<const VarNode*, Expr> bind_map;
std::unordered_set<size_t> symbolic_var_set;
for (size_t i = 0; i < src_shape.size(); ++i) {
Expr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
- if (orig_shape.as<ir::Any>()) {
+ if (orig_shape.as<ir::AnyNode>()) {
symbolic_var_set.insert(i);
}
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
if (orig_shape.defined()) {
- const auto* orig_shape_const = orig_shape.as<IntImm>();
- const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
+ const auto* orig_shape_const = orig_shape.as<IntImmNode>();
+ const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
if (orig_shape_const) {
CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
<< "Input shape mismatch at index " << i << ". Expected "
result.push_back(axis->dom->extent);
} else {
if (symbolic_var_set.count(i)) {
- result.push_back(ir::Any::make());
+ result.push_back(ir::AnyNode::make());
} else {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
}
namespace tvm {
Expr::Expr(int32_t value)
- : Expr(IntImm::make(DataType::Int(32), value)) {}
+ : Expr(IntImmNode::make(DataType::Int(32), value)) {}
Expr::Expr(float value)
- : Expr(ir::FloatImm::make(DataType::Float(32), value)) {}
+ : Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
Expr::Expr(std::string str)
- : Expr(ir::StringImm::make(str)) {}
+ : Expr(ir::StringImmNode::make(str)) {}
Var::Var(std::string name_hint, DataType t)
- : Var(Variable::make(t, name_hint)) {}
+ : Var(VarNode::make(t, name_hint)) {}
-Var Variable::make(DataType t, std::string name_hint) {
- ObjectPtr<Variable> node = make_object<Variable>();
+Var VarNode::make(DataType t, std::string name_hint) {
+ ObjectPtr<VarNode> node = make_object<VarNode>();
node->dtype = t;
node->name_hint = std::move(name_hint);
return Var(node);
is_zero(begin) ? end : (end - begin))) {
}
-Integer IntImm::make(DataType t, int64_t value) {
+Integer IntImmNode::make(DataType t, int64_t value) {
CHECK(t.is_int() && t.is_scalar())
<< "ValueError: IntImm can only take scalar.";
- ObjectPtr<IntImm> node = make_object<IntImm>();
+ ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
node->dtype = t;
node->value = value;
return Integer(node);
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<IntImm>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const IntImm*>(node.get());
+.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
} else {
// simple cast that only checks if type matches and cast
inline Expr SimpleCast(const DataType& t, Expr value) {
if (value.dtype() == t) return value;
- return ir::Cast::make(t, value);
+ return ir::CastNode::make(t, value);
}
// The public function with a quick checking path.
DataType ltype = lhs.dtype();
DataType rtype = rhs.dtype();
if (ltype.lanes() == 1 && rtype.lanes() != 1) {
- lhs = ir::Broadcast::make(lhs, rtype.lanes());
+ lhs = ir::BroadcastNode::make(lhs, rtype.lanes());
} else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
- rhs = ir::Broadcast::make(rhs, ltype.lanes());
+ rhs = ir::BroadcastNode::make(rhs, ltype.lanes());
} else {
CHECK(ltype.lanes() == rtype.lanes())
<< "Cannot match type " << ltype << " vs " << rtype;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_int()) {
if (dtype.bits() == 64) {
- return IntImm::make(dtype, std::numeric_limits<int64_t>::max());
+ return IntImmNode::make(dtype, std::numeric_limits<int64_t>::max());
} else if (dtype.bits() < 64) {
int64_t val = 1;
val = (val << (dtype.bits() - 1)) - 1;
- return IntImm::make(dtype, val);
+ return IntImmNode::make(dtype, val);
}
} else if (dtype.is_uint()) {
if (dtype.bits() == 64) {
- return UIntImm::make(dtype, std::numeric_limits<uint64_t>::max());
+ return UIntImmNode::make(dtype, std::numeric_limits<uint64_t>::max());
} else if (dtype.bits() < 64) {
uint64_t val = 1;
val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
- return UIntImm::make(dtype, val);
+ return UIntImmNode::make(dtype, val);
}
} else if (dtype.is_float()) {
if (dtype.bits() == 64) {
- return FloatImm::make(dtype, std::numeric_limits<double>::max());
+ return FloatImmNode::make(dtype, std::numeric_limits<double>::max());
} else if (dtype.bits() == 32) {
- return FloatImm::make(dtype, std::numeric_limits<float>::max());
+ return FloatImmNode::make(dtype, std::numeric_limits<float>::max());
} else if (dtype.bits() == 16) {
- return FloatImm::make(dtype, 65504.0);
+ return FloatImmNode::make(dtype, 65504.0);
}
}
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_int()) {
if (dtype.bits() == 64) {
- return IntImm::make(dtype, std::numeric_limits<int64_t>::lowest());
+ return IntImmNode::make(dtype, std::numeric_limits<int64_t>::lowest());
} else if (dtype.bits() < 64) {
int64_t val = 1;
val = -(val << (dtype.bits() - 1));
- return IntImm::make(dtype, val);
+ return IntImmNode::make(dtype, val);
}
} else if (dtype.is_uint()) {
- return UIntImm::make(dtype, 0);
+ return UIntImmNode::make(dtype, 0);
} else if (dtype.is_float()) {
if (dtype.bits() == 64) {
- return FloatImm::make(dtype, std::numeric_limits<double>::lowest());
+ return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest());
} else if (dtype.bits() == 32) {
- return FloatImm::make(dtype, std::numeric_limits<float>::lowest());
+ return FloatImmNode::make(dtype, std::numeric_limits<float>::lowest());
} else if (dtype.bits() == 16) {
- return FloatImm::make(dtype, -65504.0);
+ return FloatImmNode::make(dtype, -65504.0);
}
}
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
}
bool is_const_power_of_two_integer(const Expr& x, int* shift) {
- if (const auto* op = x.as<ir::IntImm>()) {
+ if (const auto* op = x.as<ir::IntImmNode>()) {
return ConstPowerHelper(op->value, shift);
- } else if (const auto* op = x.as<ir::UIntImm>()) {
+ } else if (const auto* op = x.as<ir::UIntImmNode>()) {
return ConstPowerHelper(op->value, shift);
} else {
return false;
}
Expr cast(const DataType& t, Expr value) {
- using ir::IntImm;
- using ir::UIntImm;
- using ir::FloatImm;
+ using ir::IntImmNode;
+ using ir::UIntImmNode;
+ using ir::FloatImmNode;
if (value.dtype() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
- if (const IntImm* op = value.as<IntImm>()) {
+ if (const IntImmNode* op = value.as<IntImmNode>()) {
return make_const(t, op->value);
- } else if (const UIntImm* op = value.as<UIntImm>()) {
+ } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
return make_const(t, op->value);
- } else if (const FloatImm* op = value.as<FloatImm>()) {
+ } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
return make_const(t, op->value);
}
- return ir::Cast::make(t, value);
+ return ir::CastNode::make(t, value);
} else {
if (value.dtype().lanes() == 1) {
// manually unroll cast
DataType vtype = t.element_of();
if (value.dtype() != vtype) {
- if (const IntImm* op = value.as<IntImm>()) {
+ if (const IntImmNode* op = value.as<IntImmNode>()) {
value = make_const(vtype, op->value);
- } else if (const UIntImm* op = value.as<UIntImm>()) {
+ } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
return make_const(t, op->value);
- } else if (const FloatImm* op = value.as<FloatImm>()) {
+ } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
value = make_const(vtype, op->value);
} else {
- value = ir::Cast::make(vtype, value);
+ value = ir::CastNode::make(vtype, value);
}
}
- return ir::Broadcast::make(value, t.lanes());
+ return ir::BroadcastNode::make(value, t.lanes());
} else {
CHECK(value.dtype().lanes() == t.lanes());
- return ir::Cast::make(t, value);
+ return ir::CastNode::make(t, value);
}
}
}
Expr reinterpret(const DataType& t, Expr value) {
if (value.dtype() == t) return value;
- return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(
+ t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic);
}
Expr operator+(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Add>(a, b);
+ Expr ret = arith::TryConstFold<ir::AddNode>(a, b);
if (ret.defined()) return ret;
- return ir::Add::make(a, b);
+ return ir::AddNode::make(a, b);
}
// negation
Expr operator-(Expr a) {
- using ir::IntImm;
- using ir::FloatImm;
- const IntImm* pa = a.as<IntImm>();
- const FloatImm* fa = a.as<FloatImm>();
- if (pa) return ir::IntImm::make(a.dtype(), -pa->value);
- if (fa) return ir::FloatImm::make(a.dtype(), -fa->value);
+ using ir::IntImmNode;
+ using ir::FloatImmNode;
+ const IntImmNode* pa = a.as<IntImmNode>();
+ const FloatImmNode* fa = a.as<FloatImmNode>();
+ if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value);
+ if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
return make_zero(a.dtype()) - a;
}
Expr operator-(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Sub>(a, b);
+ Expr ret = arith::TryConstFold<ir::SubNode>(a, b);
if (ret.defined()) return ret;
- return ir::Sub::make(a, b);
+ return ir::SubNode::make(a, b);
}
Expr operator*(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Mul>(a, b);
+ Expr ret = arith::TryConstFold<ir::MulNode>(a, b);
if (ret.defined()) return ret;
- return ir::Mul::make(a, b);
+ return ir::MulNode::make(a, b);
}
Expr div(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Div>(a, b);
+ Expr ret = arith::TryConstFold<ir::DivNode>(a, b);
if (ret.defined()) return ret;
- return ir::Div::make(a, b);
+ return ir::DivNode::make(a, b);
}
Expr truncdiv(Expr a, Expr b) {
Expr truncmod(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Mod>(a, b);
+ Expr ret = arith::TryConstFold<ir::ModNode>(a, b);
if (ret.defined()) return ret;
- return ir::Mod::make(a, b);
+ return ir::ModNode::make(a, b);
}
Expr operator/(Expr a, Expr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
+ Expr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
if (ret.defined()) return ret;
- return ir::FloorDiv::make(a, b);
+ return ir::FloorDivNode::make(a, b);
}
Expr floormod(Expr a, Expr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
+ Expr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
if (ret.defined()) return ret;
- return ir::FloorMod::make(a, b);
+ return ir::FloorModNode::make(a, b);
}
Expr min(Expr a, Expr b) {
if (is_pos_inf(b)) return a;
if (is_neg_inf(b)) return b;
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Min>(a, b);
+ Expr ret = arith::TryConstFold<ir::MinNode>(a, b);
if (ret.defined()) return ret;
- return ir::Min::make(a, b);
+ return ir::MinNode::make(a, b);
}
Expr max(Expr a, Expr b) {
if (is_pos_inf(b)) return b;
if (is_neg_inf(b)) return a;
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::Max>(a, b);
+ Expr ret = arith::TryConstFold<ir::MaxNode>(a, b);
if (ret.defined()) return ret;
- return ir::Max::make(a, b);
+ return ir::MaxNode::make(a, b);
}
Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
- using ir::IntImm;
- using ir::UIntImm;
+ using ir::IntImmNode;
+ using ir::UIntImmNode;
CHECK(cond.dtype() == DataType::Bool(1))
<< "if_then_else only accept the condition to be boolean type.";
BinaryOpMatchTypes(true_value, false_value);
- if (const UIntImm* op = cond.as<UIntImm>()) {
+ if (const UIntImmNode* op = cond.as<UIntImmNode>()) {
if (op->value != 0) {
return true_value;
} else {
return false_value;
}
- } else if (const IntImm* op = cond.as<IntImm>()) {
+ } else if (const IntImmNode* op = cond.as<IntImmNode>()) {
if (op->value != 0) {
return true_value;
} else {
return false_value;
}
}
- return ir::Call::make(
+ return ir::CallNode::make(
true_value.dtype(),
ir::intrinsic::tvm_if_then_else,
{cond, true_value, false_value},
- ir::Call::PureIntrinsic);
+ ir::CallNode::PureIntrinsic);
}
Expr likely(Expr cond) {
if (is_const(cond)) return cond;
- return ir::Call::make(cond.dtype(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(cond.dtype(),
+ ir::CallNode::likely,
+ { cond },
+ ir::CallNode::PureIntrinsic);
}
Expr operator>(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::GT>(a, b);
+ Expr ret = arith::TryConstFold<ir::GTNode>(a, b);
if (ret.defined()) return ret;
- return ir::GT::make(a, b);
+ return ir::GTNode::make(a, b);
}
Expr operator>=(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::GE>(a, b);
+ Expr ret = arith::TryConstFold<ir::GENode>(a, b);
if (ret.defined()) return ret;
- return ir::GE::make(a, b);
+ return ir::GENode::make(a, b);
}
Expr operator<(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::LT>(a, b);
+ Expr ret = arith::TryConstFold<ir::LTNode>(a, b);
if (ret.defined()) return ret;
- return ir::LT::make(a, b);
+ return ir::LTNode::make(a, b);
}
Expr operator<=(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::LE>(a, b);
+ Expr ret = arith::TryConstFold<ir::LENode>(a, b);
if (ret.defined()) return ret;
- return ir::LE::make(a, b);
+ return ir::LENode::make(a, b);
}
Expr operator==(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::EQ>(a, b);
+ Expr ret = arith::TryConstFold<ir::EQNode>(a, b);
if (ret.defined()) return ret;
- return ir::EQ::make(a, b);
+ return ir::EQNode::make(a, b);
}
Expr operator!=(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::NE>(a, b);
+ Expr ret = arith::TryConstFold<ir::NENode>(a, b);
if (ret.defined()) return ret;
- return ir::NE::make(a, b);
+ return ir::NENode::make(a, b);
}
Expr operator&&(Expr a, Expr b) {
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
- Expr ret = arith::TryConstFold<ir::And>(a, b);
+ Expr ret = arith::TryConstFold<ir::AndNode>(a, b);
if (ret.defined()) return ret;
- return ir::And::make(a, b);
+ return ir::AndNode::make(a, b);
}
Expr operator||(Expr a, Expr b) {
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
- Expr ret = arith::TryConstFold<ir::Or>(a, b);
+ Expr ret = arith::TryConstFold<ir::OrNode>(a, b);
if (ret.defined()) return ret;
- return ir::Or::make(a, b);
+ return ir::OrNode::make(a, b);
}
Expr operator!(Expr a) {
CHECK(a.dtype().is_bool());
- Expr ret = arith::TryConstFold<ir::Not>(a);
+ Expr ret = arith::TryConstFold<ir::NotNode>(a);
if (ret.defined()) return ret;
- return ir::Not::make(a);
+ return ir::NotNode::make(a);
}
Expr operator>>(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
+ if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value));
if (pb) {
if (pb->value == 0) return a;
}
});
- return ir::Call::make(a.dtype(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(
+ a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic);
}
Expr operator<<(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
+ if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value));
if (pb) {
if (pb->value == 0) return a;
}
});
- return ir::Call::make(a.dtype(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(
+ a.dtype(), ir::CallNode::shift_left, { a, b }, ir::CallNode::PureIntrinsic);
}
Expr operator&(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
+ if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value));
});
- return ir::Call::make(a.dtype(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(
+ a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
}
Expr operator|(Expr a, Expr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
+ if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value));
});
- return ir::Call::make(a.dtype(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(
+ a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
}
// Bitwise XOR a ^ b. Constant-folds integer immediates; otherwise
// lowers to the bitwise_xor pure intrinsic call.
Expr operator^(Expr a, Expr b) {
  BinaryOpMatchTypes(a, b);
  TVM_INDEX_CONST_PROPAGATION({
    const DataType& rtype = a.dtype();
    if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value));
  });
  return ir::CallNode::make(
      a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
}
// Bitwise NOT ~a. Only defined for (signed or unsigned) integer
// expressions; lowers to the bitwise_not pure intrinsic call.
Expr operator~(Expr a) {
  CHECK(a.dtype().is_int() || a.dtype().is_uint());
  return ir::CallNode::make(
      a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic);
}
// Elementwise power x**y. Float-only by contract (checked); lowers to
// the "pow" pure intrinsic call.
Expr pow(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.dtype().is_float()) << "power only applies to float";
  return ir::CallNode::make(
      x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic);
}
Expr abs(Expr x) {
if (x.dtype().is_int()) {
- using ir::IntImm;
- const IntImm* px = x.as<IntImm>();
+ using ir::IntImmNode;
+ const IntImmNode* px = x.as<IntImmNode>();
if (px) {
- return ir::IntImm::make(x.dtype(), std::abs(px->value));
+ return ir::IntImmNode::make(x.dtype(), std::abs(px->value));
}
- return ir::Select::make(x >= make_zero(x.dtype()), x, -x);
+ return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x);
} else if (x.dtype().is_float()) {
- using ir::FloatImm;
- const FloatImm* fx = x.as<FloatImm>();
+ using ir::FloatImmNode;
+ const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
- return ir::FloatImm::make(x.dtype(), std::fabs(fx->value));
+ return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value));
}
- return ir::Call::make(x.dtype(), "fabs", {x}, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic);
} else if (x.dtype().is_uint()) {
return x;
} else {
if (x.dtype().is_int() || x.dtype().is_uint()) {
return make_const(t, false);
} else if (x.dtype().is_float()) {
- using ir::FloatImm;
- const FloatImm* fx = x.as<FloatImm>();
+ using ir::FloatImmNode;
+ const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
return make_const(t, std::isnan(fx->value));
}
if (x.dtype().bits() == 16) {
- return ir::Call::make(t, ir::Call::isnan,
+ return ir::CallNode::make(t, ir::CallNode::isnan,
{cast(DataType::Float(32, t.lanes()), std::move(x))},
- ir::Call::PureIntrinsic);
+ ir::CallNode::PureIntrinsic);
} else {
- return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
+ return ir::CallNode::make(t, ir::CallNode::isnan, {x}, ir::CallNode::PureIntrinsic);
}
} else {
LOG(FATAL) << "Data type " << x.dtype()
// Sum reduction of `source` over the reduction axes `rdom`.
// Builds a commutative reducer (x + y, identity 0) and wraps it in a
// ReduceNode with an always-true condition and value index 0.
Expr sum(Expr source, Array<IterVar> rdom) {
  Var x("x", source.dtype()), y("y", source.dtype());
  Expr result = ir::AddNode::make(x, y);
  Expr identity_element = make_zero(source.dtype());
  ir::CommReducer combiner =
      ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
// Logical-AND reduction of a boolean `source` over `rdom`.
// Reducer is (x && y) with identity `true`.
Expr all(Expr source, Array<IterVar> rdom) {
  CHECK(source.dtype().is_bool());
  Var x("x", source.dtype()), y("y", source.dtype());
  Expr result = ir::AndNode::make(x, y);
  Expr identity_element = make_const(source.dtype(), true);
  ir::CommReducer combiner =
      ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
// Logical-OR reduction of a boolean `source` over `rdom`.
// Reducer is (x || y) with identity `false`.
Expr any(Expr source, Array<IterVar> rdom) {
  CHECK(source.dtype().is_bool());
  Var x("x", source.dtype()), y("y", source.dtype());
  Expr result = ir::OrNode::make(x, y);
  Expr identity_element = make_const(source.dtype(), false);
  ir::CommReducer combiner =
      ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
// Max reduction of `source` over `rdom`.
// Reducer is max(x, y); identity is the dtype's minimum value.
Expr max(Expr source, Array<IterVar> rdom) {
  Var x("x", source.dtype()), y("y", source.dtype());
  Expr result = ir::MaxNode::make(x, y);
  Expr identity_element = min_value(source.dtype());
  ir::CommReducer combiner =
      ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
// Min reduction of `source` over `rdom`.
// Reducer is min(x, y); identity is the dtype's maximum value.
Expr min(Expr source, Array<IterVar> rdom) {
  Var x("x", source.dtype()), y("y", source.dtype());
  Expr result = ir::MinNode::make(x, y);
  Expr identity_element = max_value(source.dtype());
  ir::CommReducer combiner =
      ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
// Product reduction of `source` over `rdom`.
// Reducer is (x * y) with identity 1.
Expr prod(Expr source, Array<IterVar> rdom) {
  Var x("x", source.dtype()), y("y", source.dtype());
  Expr result = ir::MulNode::make(x, y);
  Expr identity_element = make_const(source.dtype(), 1);
  ir::CommReducer combiner =
      ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
// Floating-point remainder fmod(x, y). Float-only (checked); lowers to
// the "fmod" pure intrinsic call.
Expr fmod(Expr x, Expr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.dtype().is_float()) << "fmod only applies to float";
  return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic);
}
// floor(x): constant-folds a float immediate via std::floor, otherwise
// lowers to the "floor" pure intrinsic call.
Expr floor(Expr x) {
  using ir::FloatImmNode;
  const FloatImmNode* fx = x.as<FloatImmNode>();
  if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value));
  return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
}
// ceil(x): constant-folds a float immediate via std::ceil, otherwise
// lowers to the "ceil" pure intrinsic call.
Expr ceil(Expr x) {
  using ir::FloatImmNode;
  const FloatImmNode* fx = x.as<FloatImmNode>();
  if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value));
  return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
}
// round(x): constant-folds a float immediate via std::nearbyint
// (round-to-even under the default FP environment), otherwise lowers
// to the "round" pure intrinsic call.
Expr round(Expr x) {
  using ir::FloatImmNode;
  const FloatImmNode* fx = x.as<FloatImmNode>();
  if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
  return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
}
// nearbyint(x): constant-folds a float immediate via std::nearbyint,
// otherwise lowers to the "nearbyint" pure intrinsic call.
Expr nearbyint(Expr x) {
  using ir::FloatImmNode;
  const FloatImmNode* fx = x.as<FloatImmNode>();
  if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
  return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
}
// trunc(x): rounds toward zero. Constant-folds a float immediate
// (ceil for negatives, floor for non-negatives); otherwise lowers to
// the "trunc" pure intrinsic call.
Expr trunc(Expr x) {
  using ir::FloatImmNode;
  const FloatImmNode* fx = x.as<FloatImmNode>();
  if (fx) {
    return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
                                          std::floor(fx->value)));
  }
  return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic);
}
} // namespace tvm
namespace ir {
// constructors
-Expr UIntImm::make(DataType t, uint64_t value) {
+Expr UIntImmNode::make(DataType t, uint64_t value) {
CHECK(t.is_uint() && t.lanes() == 1)
<< "ValueError: UIntImm can only take scalar";
- ObjectPtr<UIntImm> node = make_object<UIntImm>();
+ ObjectPtr<UIntImmNode> node = make_object<UIntImmNode>();
node->dtype = t;
node->value = value;
return Expr(node);
}
-Expr FloatImm::make(DataType t, double value) {
+Expr FloatImmNode::make(DataType t, double value) {
CHECK_EQ(t.lanes(), 1)
<< "ValueError: FloatImm can only take scalar";
- ObjectPtr<FloatImm> node = make_object<FloatImm>();
+ ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = t;
node->value = value;
return Expr(node);
}
-Expr StringImm::make(std::string value) {
- ObjectPtr<StringImm> node = make_object<StringImm>();
+Expr StringImmNode::make(std::string value) {
+ ObjectPtr<StringImmNode> node = make_object<StringImmNode>();
node->dtype = DataType::Handle();
node->value = std::move(value);
return Expr(node);
}
-Expr Cast::make(DataType t, Expr value) {
+Expr CastNode::make(DataType t, Expr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
- ObjectPtr<Cast> node = make_object<Cast>();
+ ObjectPtr<CastNode> node = make_object<CastNode>();
node->dtype = t;
node->value = std::move(value);
return Expr(node);
}
-Expr And::make(Expr a, Expr b) {
+Expr AndNode::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
- ObjectPtr<And> node = make_object<And>();
+ ObjectPtr<AndNode> node = make_object<AndNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
}
-Expr Or::make(Expr a, Expr b) {
+Expr OrNode::make(Expr a, Expr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
- ObjectPtr<Or> node = make_object<Or>();
+ ObjectPtr<OrNode> node = make_object<OrNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
return Expr(node);
}
-Expr Not::make(Expr a) {
+Expr NotNode::make(Expr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
- ObjectPtr<Not> node = make_object<Not>();
+ ObjectPtr<NotNode> node = make_object<NotNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
return Expr(node);
}
-Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
+Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
CHECK(false_value.defined()) << "ValueError: true_value is undefined";
CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes());
CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
- ObjectPtr<Select> node = make_object<Select>();
+ ObjectPtr<SelectNode> node = make_object<SelectNode>();
node->dtype = true_value.dtype();
node->condition = std::move(condition);
node->true_value = std::move(true_value);
return Expr(node);
}
-Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
+Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
CHECK(buffer_var.defined());
CHECK(predicate.defined());
CHECK(index.defined());
CHECK_EQ(dtype.lanes(), index.dtype().lanes());
CHECK_EQ(dtype.lanes(), predicate.dtype().lanes());
- ObjectPtr<Load> node = make_object<Load>();
+ ObjectPtr<LoadNode> node = make_object<LoadNode>();
node->dtype = dtype;
node->buffer_var = std::move(buffer_var);
node->index = std::move(index);
return Expr(node);
}
-Expr Ramp::make(Expr base, Expr stride, int lanes) {
+Expr RampNode::make(Expr base, Expr stride, int lanes) {
CHECK(base.defined());
CHECK(stride.defined());
CHECK(base.dtype().is_scalar());
CHECK_GT(lanes, 1);
CHECK_EQ(stride.dtype(), base.dtype());
- ObjectPtr<Ramp> node = make_object<Ramp>();
+ ObjectPtr<RampNode> node = make_object<RampNode>();
node->dtype = base.dtype().with_lanes(lanes);
node->base = base;
node->stride = stride;
return Expr(node);
}
-Expr Broadcast::make(Expr value, int lanes) {
+Expr BroadcastNode::make(Expr value, int lanes) {
CHECK(value.defined());
CHECK(value.dtype().is_scalar());
CHECK_GT(lanes, 1);
- ObjectPtr<Broadcast> node = make_object<Broadcast>();
+ ObjectPtr<BroadcastNode> node = make_object<BroadcastNode>();
node->dtype = value.dtype().with_lanes(lanes);
node->value = std::move(value);
node->lanes = lanes;
return Expr(node);
}
-Expr Let::make(Var var, Expr value, Expr body) {
+Expr LetNode::make(Var var, Expr value, Expr body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.dtype(), var.dtype());
- ObjectPtr<Let> node = make_object<Let>();
+ ObjectPtr<LetNode> node = make_object<LetNode>();
node->dtype = body.dtype();
node->var = std::move(var);
node->value = std::move(value);
return Expr(node);
}
-const char* Call::vectorizable_intrinsics[] = {
+const char* CallNode::vectorizable_intrinsics[] = {
"floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
- "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
- ir::Call::likely, ir::Call::popcount
+ "log", "sin", "cos", "pow", ir::CallNode::shift_left, ir::CallNode::shift_right,
+ ir::CallNode::likely, ir::CallNode::popcount
};
-bool Call::is_vectorizable() const {
- size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
+bool CallNode::is_vectorizable() const {
+ size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*);
for (size_t i = 0; i < cnt; ++i) {
- if (name == Call::vectorizable_intrinsics[i]) {
+ if (name == CallNode::vectorizable_intrinsics[i]) {
return true;
}
}
return false;
}
-Expr Call::make(DataType dtype,
+Expr CallNode::make(DataType dtype,
std::string name,
Array<Expr> args,
CallType call_type,
}
}
- ObjectPtr<Call> node = make_object<Call>();
+ ObjectPtr<CallNode> node = make_object<CallNode>();
node->dtype = dtype;
node->name = std::move(name);
node->args = std::move(args);
return Expr(node);
}
-Expr Shuffle::make(Array<Expr> vectors,
+Expr ShuffleNode::make(Array<Expr> vectors,
Array<Expr> indices) {
CHECK_NE(vectors.size(), 0U);
CHECK_NE(indices.size(), 0U);
}
CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
- ObjectPtr<Shuffle> node = make_object<Shuffle>();
+ ObjectPtr<ShuffleNode> node = make_object<ShuffleNode>();
node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
node->vectors = std::move(vectors);
node->indices = std::move(indices);
return Expr(node);
}
-Expr Shuffle::make_concat(Array<Expr> vectors) {
+Expr ShuffleNode::make_concat(Array<Expr> vectors) {
CHECK_NE(vectors.size(), 0);
if (vectors.size() == 1) {
return vectors[0];
int index = 0;
for (const Expr& e : vectors) {
for (int i = 0; i < e.dtype().lanes(); ++i) {
- indices.push_back(IntImm::make(DataType::Int(32), index++));
+ indices.push_back(IntImmNode::make(DataType::Int(32), index++));
}
}
return make(vectors, indices);
}
-Expr Shuffle::make_extract_element(Expr vector, int index) {
+Expr ShuffleNode::make_extract_element(Expr vector, int index) {
return make({vector}, {Integer(index)});
}
});
}
-Expr Reduce::make(CommReducer combiner, Array<Expr> source,
+Expr ReduceNode::make(CommReducer combiner, Array<Expr> source,
Array<IterVar> axis, Expr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
if (!condition.defined()) {
condition = const_true();
}
- auto n = make_object<Reduce>();
+ auto n = make_object<ReduceNode>();
CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
return Expr(n);
}
-Expr Any::make() {
- auto n = make_object<Any>();
+Expr AnyNode::make() {
+ auto n = make_object<AnyNode>();
return Expr(n);
}
-Stmt LetStmt::make(Var var, Expr value, Stmt body) {
+Stmt LetStmtNode::make(Var var, Expr value, Stmt body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.dtype(), var.dtype());
- ObjectPtr<LetStmt> node = make_object<LetStmt>();
+ ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
return Stmt(node);
}
-Stmt AttrStmt::make(ObjectRef node,
+Stmt AttrStmtNode::make(ObjectRef node,
std::string attr_key,
Expr value,
Stmt body) {
- auto n = make_object<AttrStmt>();
+ auto n = make_object<AttrStmtNode>();
n->node = node;
n->attr_key = std::move(attr_key);
n->value = std::move(value);
return Stmt(n);
}
-Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
+Stmt AssertStmtNode::make(Expr condition, Expr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
- message.as<StringImm>())
+ message.as<StringImmNode>())
<< "TypeError: AssertStmt message must be an int or string:"
<< message << "\n";
- ObjectPtr<AssertStmt> node = make_object<AssertStmt>();
+ ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
node->condition = std::move(condition);
node->message = std::move(message);
node->body = std::move(body);
return Stmt(node);
}
-Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
+Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
- ObjectPtr<ProducerConsumer> node = make_object<ProducerConsumer>();
+ ObjectPtr<ProducerConsumerNode> node = make_object<ProducerConsumerNode>();
node->func = std::move(func);
node->is_producer = is_producer;
node->body = std::move(body);
return Stmt(node);
}
-Stmt For::make(Var loop_var,
+Stmt ForNode::make(Var loop_var,
Expr min,
Expr extent,
ForType for_type,
CHECK(loop_var.dtype().is_scalar());
CHECK(body.defined());
- ObjectPtr<For> node = make_object<For>();
+ ObjectPtr<ForNode> node = make_object<ForNode>();
node->loop_var = std::move(loop_var);
node->min = std::move(min);
node->extent = std::move(extent);
return Stmt(node);
}
-Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
+Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
CHECK(value.defined());
CHECK(index.defined());
CHECK(predicate.defined());
CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
- ObjectPtr<Store> node = make_object<Store>();
+ ObjectPtr<StoreNode> node = make_object<StoreNode>();
node->buffer_var = std::move(buffer_var);
node->value = std::move(value);
node->index = std::move(index);
return Stmt(node);
}
-Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
+Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
CHECK(value_index >=0 && value_index < func->num_outputs())
<< "value index output function return value bound";
CHECK(value.defined()) << "Provide of undefined value\n";
CHECK(args[i].defined()) << "Provide to undefined location\n";
}
- ObjectPtr<Provide> node = make_object<Provide>();
+ ObjectPtr<ProvideNode> node = make_object<ProvideNode>();
node->func = std::move(func);
node->value_index = value_index;
node->value = std::move(value);
return Stmt(node);
}
-Stmt Allocate::make(Var buffer_var,
+Stmt AllocateNode::make(Var buffer_var,
DataType dtype,
Array<Expr> extents,
Expr condition,
CHECK(condition.defined());
CHECK(condition.dtype().is_bool());
- ObjectPtr<Allocate> node = make_object<Allocate>();
+ ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
node->buffer_var = std::move(buffer_var);
node->dtype = dtype;
node->extents = std::move(extents);
return Stmt(node);
}
-int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
+int32_t AllocateNode::constant_allocation_size(const Array<Expr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
- if (const IntImm *int_size = extents[i].as<IntImm>()) {
+ if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
result *= int_size->value;
if (result > std::numeric_limits<int32_t>::max()) {
return 0;
return static_cast<int32_t>(result);
}
-Stmt Free::make(Var buffer_var) {
- ObjectPtr<Free> node = make_object<Free>();
+Stmt FreeNode::make(Var buffer_var) {
+ ObjectPtr<FreeNode> node = make_object<FreeNode>();
node->buffer_var = buffer_var;
return Stmt(node);
}
-Stmt Realize::make(FunctionRef func,
+Stmt RealizeNode::make(FunctionRef func,
int value_index,
DataType dtype,
Region bounds,
CHECK(condition.defined());
CHECK(condition.dtype().is_bool());
- ObjectPtr<Realize> node = make_object<Realize>();
+ ObjectPtr<RealizeNode> node = make_object<RealizeNode>();
node->func = std::move(func);
node->value_index = value_index;
node->dtype = dtype;
return Stmt(node);
}
-Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
+Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
CHECK(bounds[i]->extent.defined());
CHECK(bounds[i]->extent.dtype().is_scalar());
}
- ObjectPtr<Prefetch> node = make_object<Prefetch>();
+ ObjectPtr<PrefetchNode> node = make_object<PrefetchNode>();
node->func = std::move(func);
node->value_index = value_index;
node->dtype = dtype;
data_ = std::move(node);
}
-Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
+Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) {
CHECK(condition.defined());
CHECK(then_case.defined());
// else_case may be null.
- ObjectPtr<IfThenElse> node = make_object<IfThenElse>();
+ ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
node->condition = std::move(condition);
node->then_case = std::move(then_case);
node->else_case = std::move(else_case);
return Stmt(node);
}
-Stmt Evaluate::make(Expr value) {
+Stmt EvaluateNode::make(Expr value) {
CHECK(value.defined());
- ObjectPtr<Evaluate> node = make_object<Evaluate>();
+ ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
node->value = std::move(value);
return Stmt(node);
}
// Printers
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<UIntImm>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const UIntImm*>(node.get());
+.set_dispatch<UIntImmNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const UIntImmNode*>(node.get());
p->stream << "(" << op->dtype << ")" << op->value;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<FloatImm>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const FloatImm*>(node.get());
+.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
case 64:
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<StringImm>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const StringImm*>(node.get());
+.set_dispatch<StringImmNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const StringImmNode*>(node.get());
auto& stream = p->stream;
stream << '"';
for (size_t i = 0; i < op->value.size(); ++i) {
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Cast>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Cast*>(node.get());
+.set_dispatch<CastNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const CastNode*>(node.get());
p->stream << op->dtype << '(';
p->Print(op->value);
p->stream << ')';
})
-.set_dispatch<Variable>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Variable*>(node.get());
+.set_dispatch<VarNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const VarNode*>(node.get());
// omit the type
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
})
-.set_dispatch<Add>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Add*>(node.get());
+.set_dispatch<AddNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const AddNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " + ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<Sub>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Sub*>(node.get());
+.set_dispatch<SubNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const SubNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " - ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<Mul>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Mul*>(node.get());
+.set_dispatch<MulNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const MulNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << "*";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<Div>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Div*>(node.get());
+.set_dispatch<DivNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const DivNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << "/";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<Mod>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Mod*>(node.get());
+.set_dispatch<ModNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ModNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " % ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<Min>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Min*>(node.get());
+.set_dispatch<MinNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const MinNode*>(node.get());
p->stream << "min(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ")";
})
-.set_dispatch<Max>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Max*>(node.get());
+.set_dispatch<MaxNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const MaxNode*>(node.get());
p->stream << "max(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ")";
})
-.set_dispatch<EQ>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const EQ*>(node.get());
+.set_dispatch<EQNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const EQNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " == ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<NE>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const NE*>(node.get());
+.set_dispatch<NENode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const NENode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " != ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<LT>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const LT*>(node.get());
+.set_dispatch<LTNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const LTNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " < ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<LE>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const LE*>(node.get());
+.set_dispatch<LENode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const LENode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " <= ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<GT>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const GT*>(node.get());
+.set_dispatch<GTNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const GTNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " > ";
p->Print(op->b);
p->stream << ')';
})
-.set_dispatch<GE>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const GE*>(node.get());
+.set_dispatch<GENode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const GENode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " >= ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<FloorDiv>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const FloorDiv*>(node.get());
+.set_dispatch<FloorDivNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const FloorDivNode*>(node.get());
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<FloorMod>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const FloorMod*>(node.get());
+.set_dispatch<FloorModNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const FloorModNode*>(node.get());
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<And>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const And*>(node.get());
+.set_dispatch<AndNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const AndNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " && ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Or>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Or*>(node.get());
+.set_dispatch<OrNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const OrNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " || ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Not>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Not*>(node.get());
+.set_dispatch<NotNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const NotNode*>(node.get());
p->stream << '!';
p->Print(op->a);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Select>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Select*>(node.get());
+.set_dispatch<SelectNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const SelectNode*>(node.get());
p->stream << "select(";
p->Print(op->condition);
p->stream << ", ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Load>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Load*>(node.get());
+.set_dispatch<LoadNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const LoadNode*>(node.get());
p->stream << op->buffer_var << "[";
p->Print(op->index);
p->stream << "]";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Ramp>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Ramp*>(node.get());
+.set_dispatch<RampNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const RampNode*>(node.get());
p->stream << "ramp(";
p->Print(op->base);
p->stream << ", ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Broadcast>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Broadcast*>(node.get());
+.set_dispatch<BroadcastNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const BroadcastNode*>(node.get());
p->stream << "x" << op->lanes << "(";
p->Print(op->value);
p->stream << ")";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Call>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Call*>(node.get());
+.set_dispatch<CallNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const CallNode*>(node.get());
p->stream << op->name << "(";
for (size_t i = 0; i < op->args.size(); ++i) {
p->Print(op->args[i]);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Let>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Let*>(node.get());
+.set_dispatch<LetNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const LetNode*>(node.get());
p->stream << "(let " << op->var << " = ";
p->Print(op->value);
p->stream << " in ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<LetStmt>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const LetStmt*>(node.get());
+.set_dispatch<LetStmtNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const LetStmtNode*>(node.get());
p->PrintIndent();
p->stream << "let " << op->var << " = ";
p->Print(op->value);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<AttrStmt>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const AttrStmt*>(node.get());
+.set_dispatch<AttrStmtNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const AttrStmtNode*>(node.get());
p->PrintIndent();
p->stream << "// attr [";
p->Print(op->node);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<AssertStmt>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const AssertStmt*>(node.get());
+.set_dispatch<AssertStmtNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const AssertStmtNode*>(node.get());
p->PrintIndent();
p->stream << "assert(";
p->Print(op->condition);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ProducerConsumer>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const ProducerConsumer*>(node.get());
+.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ProducerConsumerNode*>(node.get());
if (op->is_producer) {
p->PrintIndent();
p->stream << "produce " << op->func->func_name() << " {\n";
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<For>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const For*>(node.get());
+.set_dispatch<ForNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ForNode*>(node.get());
p->PrintIndent();
p->stream << op->for_type << " (" << op->loop_var << ", ";
p->Print(op->min);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Store>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Store*>(node.get());
+.set_dispatch<StoreNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const StoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer_var << "[";
p->Print(op->index);
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Provide>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Provide*>(node.get());
+.set_dispatch<ProvideNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ProvideNode*>(node.get());
p->PrintIndent();
p->stream << op->func->func_name() << "(";
for (size_t i = 0; i < op->args.size(); ++i) {
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Allocate>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Allocate*>(node.get());
+.set_dispatch<AllocateNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const AllocateNode*>(node.get());
p->PrintIndent();
p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
for (size_t i = 0; i < op->extents.size(); ++i) {
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Free>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Free*>(node.get());
+.set_dispatch<FreeNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const FreeNode*>(node.get());
p->PrintIndent();
p->stream << "free " << op->buffer_var;
p->stream << '\n';
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Realize>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Realize*>(node.get());
+.set_dispatch<RealizeNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const RealizeNode*>(node.get());
p->PrintIndent();
p->stream << "realize " << op->func->func_name() << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Prefetch>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Prefetch*>(node.get());
+.set_dispatch<PrefetchNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const PrefetchNode*>(node.get());
p->PrintIndent();
p->stream << "prefetch " << op->func->func_name() << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<IfThenElse>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const IfThenElse*>(node.get());
+.set_dispatch<IfThenElseNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const IfThenElseNode*>(node.get());
p->PrintIndent();
while (true) {
p->stream << "if (" << op->condition << ") {\n";
break;
}
- if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
+ if (const IfThenElseNode *nested_if = op->else_case.as<IfThenElseNode>()) {
p->PrintIndent();
p->stream << "} else ";
op = nested_if;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Evaluate>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Evaluate*>(node.get());
+.set_dispatch<EvaluateNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const EvaluateNode*>(node.get());
p->PrintIndent();
p->Print(op->value);
p->stream << "\n";
}
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Shuffle>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Shuffle*>(node.get());
+.set_dispatch<ShuffleNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ShuffleNode*>(node.get());
p->stream << "shuffle(";
PrintList(op->vectors, p);
p->stream << ", ";
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Reduce>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const Reduce*>(node.get());
+.set_dispatch<ReduceNode>([](const ObjectRef& node, NodePrinter* p) {
+ auto* op = static_cast<const ReduceNode*>(node.get());
p->stream << "reduce(combiner="
<< op->combiner;
p->stream << ", source=" << op->source;
});
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Any>([](const ObjectRef& node, NodePrinter* p) {
+.set_dispatch<AnyNode>([](const ObjectRef& node, NodePrinter* p) {
p->stream << "?";
});
TVM_REGISTER_NODE_TYPE(CommReducerNode);
-TVM_REGISTER_NODE_TYPE(Reduce);
-TVM_REGISTER_NODE_TYPE(Any);
-TVM_REGISTER_NODE_TYPE(AttrStmt);
-TVM_REGISTER_NODE_TYPE(FloatImm);
-TVM_REGISTER_NODE_TYPE(IntImm);
-TVM_REGISTER_NODE_TYPE(UIntImm);
-TVM_REGISTER_NODE_TYPE(StringImm);
-TVM_REGISTER_NODE_TYPE(Cast);
-TVM_REGISTER_NODE_TYPE(Variable);
-TVM_REGISTER_NODE_TYPE(Add);
-TVM_REGISTER_NODE_TYPE(Sub);
-TVM_REGISTER_NODE_TYPE(Mul);
-TVM_REGISTER_NODE_TYPE(Div);
-TVM_REGISTER_NODE_TYPE(Mod);
-TVM_REGISTER_NODE_TYPE(FloorDiv);
-TVM_REGISTER_NODE_TYPE(FloorMod);
-TVM_REGISTER_NODE_TYPE(Min);
-TVM_REGISTER_NODE_TYPE(Max);
-TVM_REGISTER_NODE_TYPE(EQ);
-TVM_REGISTER_NODE_TYPE(NE);
-TVM_REGISTER_NODE_TYPE(LT);
-TVM_REGISTER_NODE_TYPE(LE);
-TVM_REGISTER_NODE_TYPE(GT);
-TVM_REGISTER_NODE_TYPE(GE);
-TVM_REGISTER_NODE_TYPE(And);
-TVM_REGISTER_NODE_TYPE(Or);
-TVM_REGISTER_NODE_TYPE(Not);
-TVM_REGISTER_NODE_TYPE(Select);
-TVM_REGISTER_NODE_TYPE(Load);
-TVM_REGISTER_NODE_TYPE(Ramp);
-TVM_REGISTER_NODE_TYPE(Broadcast);
-TVM_REGISTER_NODE_TYPE(Shuffle);
-TVM_REGISTER_NODE_TYPE(Prefetch);
-TVM_REGISTER_NODE_TYPE(Call);
-TVM_REGISTER_NODE_TYPE(Let);
-TVM_REGISTER_NODE_TYPE(LetStmt);
-TVM_REGISTER_NODE_TYPE(AssertStmt);
-TVM_REGISTER_NODE_TYPE(ProducerConsumer);
-TVM_REGISTER_NODE_TYPE(For);
-TVM_REGISTER_NODE_TYPE(Store);
-TVM_REGISTER_NODE_TYPE(Provide);
-TVM_REGISTER_NODE_TYPE(Allocate);
-TVM_REGISTER_NODE_TYPE(Free);
-TVM_REGISTER_NODE_TYPE(Realize);
+TVM_REGISTER_NODE_TYPE(ReduceNode);
+TVM_REGISTER_NODE_TYPE(AnyNode);
+TVM_REGISTER_NODE_TYPE(AttrStmtNode);
+TVM_REGISTER_NODE_TYPE(FloatImmNode);
+TVM_REGISTER_NODE_TYPE(IntImmNode);
+TVM_REGISTER_NODE_TYPE(UIntImmNode);
+TVM_REGISTER_NODE_TYPE(StringImmNode);
+TVM_REGISTER_NODE_TYPE(CastNode);
+TVM_REGISTER_NODE_TYPE(VarNode);
+TVM_REGISTER_NODE_TYPE(AddNode);
+TVM_REGISTER_NODE_TYPE(SubNode);
+TVM_REGISTER_NODE_TYPE(MulNode);
+TVM_REGISTER_NODE_TYPE(DivNode);
+TVM_REGISTER_NODE_TYPE(ModNode);
+TVM_REGISTER_NODE_TYPE(FloorDivNode);
+TVM_REGISTER_NODE_TYPE(FloorModNode);
+TVM_REGISTER_NODE_TYPE(MinNode);
+TVM_REGISTER_NODE_TYPE(MaxNode);
+TVM_REGISTER_NODE_TYPE(EQNode);
+TVM_REGISTER_NODE_TYPE(NENode);
+TVM_REGISTER_NODE_TYPE(LTNode);
+TVM_REGISTER_NODE_TYPE(LENode);
+TVM_REGISTER_NODE_TYPE(GTNode);
+TVM_REGISTER_NODE_TYPE(GENode);
+TVM_REGISTER_NODE_TYPE(AndNode);
+TVM_REGISTER_NODE_TYPE(OrNode);
+TVM_REGISTER_NODE_TYPE(NotNode);
+TVM_REGISTER_NODE_TYPE(SelectNode);
+TVM_REGISTER_NODE_TYPE(LoadNode);
+TVM_REGISTER_NODE_TYPE(RampNode);
+TVM_REGISTER_NODE_TYPE(BroadcastNode);
+TVM_REGISTER_NODE_TYPE(ShuffleNode);
+TVM_REGISTER_NODE_TYPE(PrefetchNode);
+TVM_REGISTER_NODE_TYPE(CallNode);
+TVM_REGISTER_NODE_TYPE(LetNode);
+TVM_REGISTER_NODE_TYPE(LetStmtNode);
+TVM_REGISTER_NODE_TYPE(AssertStmtNode);
+TVM_REGISTER_NODE_TYPE(ProducerConsumerNode);
+TVM_REGISTER_NODE_TYPE(ForNode);
+TVM_REGISTER_NODE_TYPE(StoreNode);
+TVM_REGISTER_NODE_TYPE(ProvideNode);
+TVM_REGISTER_NODE_TYPE(AllocateNode);
+TVM_REGISTER_NODE_TYPE(FreeNode);
+TVM_REGISTER_NODE_TYPE(RealizeNode);
TVM_REGISTER_NODE_TYPE(SeqStmtNode);
-TVM_REGISTER_NODE_TYPE(IfThenElse);
-TVM_REGISTER_NODE_TYPE(Evaluate);
+TVM_REGISTER_NODE_TYPE(IfThenElseNode);
+TVM_REGISTER_NODE_TYPE(EvaluateNode);
} // namespace ir
} // namespace tvm
}
Expr Tensor::operator()(Array<Expr> indices) const {
- using ir::Call;
+ using ir::CallNode;
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
- auto n = Call::make(
- (*this)->dtype, (*this)->op->name, indices, Call::Halide,
+ auto n = CallNode::make(
+ (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
(*this)->op, (*this)->value_index);
return n;
}
/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode *op);
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) &&
n->attrs = std::move(attrs);
n->axis = std::move(axis);
n->body = std::move(body);
- if (n->body[0]->IsInstance<ir::Reduce>()) {
- const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
+ if (n->body[0]->IsInstance<ir::ReduceNode>()) {
+ const ir::ReduceNode* reduce = n->body[0].as<ir::ReduceNode>();
n->reduce_axis = reduce->axis;
}
VerifyComputeOp(n.get());
std::unordered_set<Tensor> visited;
for (auto& e : body) {
ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
- const ir::Call *call = n.as<ir::Call>();
+ const ir::CallNode *call = n.as<ir::CallNode>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
if (!visited.count(t)) {
CHECK_EQ(self.operator->(), this);
VerifyComputeOp(this);
Array<Expr> arr;
- if (this->body[0]->IsInstance<ir::Reduce>()) {
+ if (this->body[0]->IsInstance<ir::ReduceNode>()) {
// Specially handle reduce so the replaced op
// still share all the components
Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
if (!new_reduce.same_as(this->body[0])) {
- const ir::Reduce* r = new_reduce.as<ir::Reduce>();
+ const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
for (size_t k = 0; k < this->body.size(); ++k) {
- auto n = make_object<ir::Reduce>(*r);
+ auto n = make_object<ir::ReduceNode>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
arr.push_back(Expr(n));
void ComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
- auto *call = n.as<ir::Call>();
+ auto *call = n.as<ir::CallNode>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
if (t->op.defined() && out_dom_map->count(t)) {
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i-1);
- realize = ir::Realize::make(t->op, t->value_index,
+ realize = ir::RealizeNode::make(t->op, t->value_index,
t->dtype, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
Array<Expr> tuple = {static_cast<int>(i),
attr->dim_align_factor,
attr->dim_align_offset};
- realize = ir::AttrStmt::make(
+ realize = ir::AttrStmtNode::make(
t, ir::attr::buffer_dim_align,
- Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic),
realize);
}
}
std::vector<Stmt> inits, provides;
size_t size = op->body.size();
- const Reduce* reduce = op->body[0].as<Reduce>();
+ const ReduceNode* reduce = op->body[0].as<ReduceNode>();
CHECK(reduce);
const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
CHECK(combiner);
Array<Expr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
- inits.emplace_back(Provide::make(
+ inits.emplace_back(ProvideNode::make(
t->op, t->value_index, init_value[i], args));
- provides.emplace_back(Provide::make(
+ provides.emplace_back(ProvideNode::make(
t->op, t->value_index, update_value[i], args));
}
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
- *provide = IfThenElse::make(reduce->condition, *provide);
+ *provide = IfThenElseNode::make(reduce->condition, *provide);
}
}
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
- return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
+ return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
}
Stmt MakeComputeStmt(const ComputeOpNode* self,
/// Special member functions
//@{
explicit ComputeVerifier(const ComputeOpNode* compute)
- : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
+ : compute_(compute), reduce_(compute->body[0].as<ir::ReduceNode>()) {}
virtual ~ComputeVerifier() = default;
ComputeVerifier(const ComputeVerifier&) = delete;
ComputeVerifier(ComputeVerifier&&) = delete;
void Run() {
for (const Expr e : compute_->body) {
// Check for consistency of top level reductions
- const ir::Reduce* reduce = e.as<ir::Reduce>();
+ const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
CHECK((reduce && reduce_) || (!reduce && !reduce_))
<< "All ComputeOp should be consistent "
<< "with being Reduce operation or not.";
--level_;
}
- void VisitExpr_(const ir::Reduce* op) final {
+ void VisitExpr_(const ir::ReduceNode* op) final {
// Check for non top level reductions
CHECK(0 == level_)
<< "Reductions are only allowed at the top level of compute. "
private:
const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
- const ir::Reduce* reduce_{nullptr}; ///< Top level Reduce operation
+ const ir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
int level_{0}; ///< Level of op being processed
};
} // namespace
Stmt body,
Stmt update) {
Array<Expr> conds;
- std::unordered_set<const Variable*> banned;
+ std::unordered_set<const VarNode*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
auto iit = stage->iter_var_attrs.find(iv);
}
}
- return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
+ return IfThenElseNode::make(arith::ComputeReduce<ir::OrNode>(conds, const_true(1)),
update, body);
}
} // namespace tvm
size_t size = self->body.size();
CHECK_GT(size, 0);
- std::vector<const Reduce*> reduces(size);
+ std::vector<const ReduceNode*> reduces(size);
for (size_t i = 0; i < size; ++i) {
- const Reduce* reduce = self->body[i].as<Reduce>();
+ const ReduceNode* reduce = self->body[i].as<ReduceNode>();
CHECK(reduce);
reduces[i] = reduce;
}
thread_head_check.emplace_back(stage->store_predicate);
}
- Stmt reduce_body = Evaluate::make(Call::make(
+ Stmt reduce_body = EvaluateNode::make(CallNode::make(
DataType::Handle(),
ir::intrinsic::tvm_thread_allreduce,
- freduce_args, Call::Intrinsic));
- reduce_body = AttrStmt::make(
+ freduce_args, CallNode::Intrinsic));
+ reduce_body = AttrStmtNode::make(
reduces[0]->combiner,
attr::reduce_scope,
make_zero(DataType::Handle()),
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
DataType t = reduces[idx]->dtype;
- assigns[idx] = Provide::make(
+ assigns[idx] = ProvideNode::make(
stage->op, idx,
- Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+ LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = SeqStmt::Flatten(assigns);
assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
- body = Allocate::make(
+ body = AllocateNode::make(
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
- body = AttrStmt::make(
- res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
+ body = AttrStmtNode::make(
+ res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
}
body = op::Substitute(body, value_map);
return MergeNest(nest, body);
void ExternOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (Tensor t : this->inputs) {
auto it = out_dom_map->find(t);
Range::make_by_min_extent(
make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
- realize_body = ir::Realize::make(
+ realize_body = ir::RealizeNode::make(
t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+ Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<ObjectRef> bind_spec;
Array<Expr> tuple;
tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
tuple.push_back(buffer->shape[k]);
}
- ret = AttrStmt::make(
+ ret = AttrStmtNode::make(
bind_spec, attr::buffer_bind_scope,
- Call::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
+ CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
std::unordered_set<Tensor> visited;
Array<Tensor> curr_inputs;
ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
- const ir::Call *call = n.as<ir::Call>();
+ const ir::CallNode *call = n.as<ir::CallNode>();
if (call != nullptr && call->func.defined()) {
Tensor t = Downcast<Operation>(call->func).output(call->value_index);
if (orig_inputs.count(t) && !visited.count(t)) {
void HybridOpNode::PropBoundToInputs(
const Operation &self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet> &dom_map,
+ const std::unordered_map<const VarNode*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
auto curr_inputs = InputTensors();
for (Tensor t : curr_inputs) {
Range::make_by_min_extent(
make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
- realize_body = ir::Realize::make(
+ realize_body = ir::RealizeNode::make(
t->op, t->value_index, t->dtype,
bounds, const_true(), realize_body);
}
const std::unordered_map<IterVar, Range> &dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+ Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
class LoopSpliter : public StmtExprMutator {
Expr factor;
- const Variable *parent;
+ const VarNode *parent;
IterVar inner, outer;
public:
outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
}
- Stmt VisitStmt_(const For *op) final {
+ Stmt VisitStmt_(const ForNode *op) final {
if (op->loop_var.get() == parent) {
- std::unordered_map<const Variable *, Expr> rmap;
+ std::unordered_map<const VarNode *, Expr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor;
Stmt ret = ir::Substitute(op->body, rmap);
Expr cond = likely(outer * factor < (op->extent - inner));
- ret = IfThenElse::make(cond, ret);
- ret = For::make(inner->var, Expr(0), inner->dom->extent,
+ ret = IfThenElseNode::make(cond, ret);
+ ret = ForNode::make(inner->var, Expr(0), inner->dom->extent,
IterVarTypeToForType(inner->iter_type), op->device_api, ret);
- ret = For::make(outer->var, Expr(0), outer->dom->extent,
+ ret = ForNode::make(outer->var, Expr(0), outer->dom->extent,
IterVarTypeToForType(outer->iter_type), op->device_api, ret);
splitted = true;
return ret;
class LoopFuser : public StmtExprMutator {
const IterVar &parent;
- const Variable *inner;
- const Variable *outer;
+ const VarNode *inner;
+ const VarNode *outer;
bool under_outer;
Expr extent;
extent(0), fused(false) {}
// TODO(@were): Handle imperfect loops
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
if (op->loop_var.get() == inner) {
CHECK(under_outer);
- std::unordered_map<const Variable *, Expr> rmap;
+ std::unordered_map<const VarNode *, Expr> rmap;
rmap[op->loop_var.get()] = indexmod(parent, op->extent);
extent = op->extent;
fused = true;
} else if (op->loop_var.get() == outer) {
under_outer = true;
Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const Variable *, Expr> rmap;
+ std::unordered_map<const VarNode *, Expr> rmap;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap);
under_outer = false;
- return For::make(parent->var, Expr(0), extent * op->extent,
+ return ForNode::make(parent->var, Expr(0), extent * op->extent,
op->for_type, op->device_api, body);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const Variable *, Expr> rmap;
+ std::unordered_map<const VarNode *, Expr> rmap;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
Stmt ApplyLoopAnnotations(const Stage &stage,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
class LoopAnnotator : public StmtMutator {
- const Variable *var;
+ const VarNode *var;
const IterVarAttr &attr;
public:
- LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
+ LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
- Stmt VisitStmt_(const For *op) final {
+ Stmt VisitStmt_(const ForNode *op) final {
if (op->loop_var.get() == var) {
if (attr->bind_thread.defined()) {
const auto &iter_var = attr->bind_thread;
CHECK(Equal(iter_var->dom->extent, op->extent))
<< "Thread extent and loop extent mismatch!\n";
}
- std::unordered_map<const Variable *, Expr> rmap;
+ std::unordered_map<const VarNode *, Expr> rmap;
rmap[op->loop_var.get()] = iter_var;
Stmt body = ir::Substitute(op->body, rmap);
- return AttrStmt::make(iter_var, "thread_extent", op->extent, body);
+ return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
} else {
- return For::make(op->loop_var, op->min, op->extent,
+ return ForNode::make(op->loop_var, op->min, op->extent,
IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
}
}
int found = 0;
const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
- const Variable *var = actual->var.get();
+ const VarNode *var = actual->var.get();
ForType expected = IterVarTypeToForType(iter_var->iter_type);
IterVarAttr attr;
if (stage->iter_var_attrs.count(iter_var)) {
PostOrderVisit(stmt,
[&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
- if (const For *op = node.as<For>()) {
+ if (const ForNode *op = node.as<ForNode>()) {
if (op->loop_var.get() == var) {
++found;
need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
Stmt ApplyLoopOrder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
- std::vector<const Variable*> current_order;
+ std::vector<const VarNode*> current_order;
PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
- if (const For *op = node.as<For>())
+ if (const ForNode *op = node.as<ForNode>())
current_order.push_back(op->loop_var.get());
});
std::reverse(current_order.begin(), current_order.end());
auto &required_ord = stage->leaf_iter_vars;
CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
- std::unordered_map<const Variable *, IterVar> reorder;
+ std::unordered_map<const VarNode *, IterVar> reorder;
bool need_reorder = false;
for (size_t i = 0; i < current_order.size(); ++i) {
auto &current = current_order[i];
class LoopReorder : public StmtMutator {
const Stage &stage;
const std::unordered_map<IterVar, Range> &dom_map;
- const std::unordered_map<const Variable *, IterVar> &reorder;
+ const std::unordered_map<const VarNode *, IterVar> &reorder;
public:
LoopReorder(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map,
- const std::unordered_map<const Variable*, IterVar> &reorder)
+ const std::unordered_map<const VarNode*, IterVar> &reorder)
: stage(stage), dom_map(dom_map), reorder(reorder) {}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
// Reorder from in to out
Stmt body_ = this->VisitStmt(op->body);
CHECK(reorder.count(op->loop_var.get()));
for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
}
const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
- return For::make(target->var, range->min, range->extent,
+ return ForNode::make(target->var, range->min, range->extent,
for_type, DeviceAPI::None, body);
}
};
// TODO(@were): Write a comprehensive pass to analyze iter var types
std::vector<IterVar> res_;
PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
- if (const For *op = node.as<For>()) {
+ if (const ForNode *op = node.as<ForNode>()) {
Var loop_var(op->loop_var);
Range dom = Range::make_by_min_extent(op->min, op->extent);
res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
: vmap_(vmap) {}
- Stmt VisitStmt_(const ir::Provide* op) final {
+ Stmt VisitStmt_(const ir::ProvideNode* op) final {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
- Stmt ret = ir::Provide::make(
+ Stmt ret = ir::ProvideNode::make(
it->second->op, it->second->value_index, op->value, op->args);
found = true;
return this->VisitStmt(ret);
std::unordered_map<IterVar, Expr>* p_value_map,
bool debug_keep_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars;
- Stmt no_op = Evaluate::make(0);
+ Stmt no_op = EvaluateNode::make(0);
// create the loop nest
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
}
CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
- const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
+ const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
Expr pvalue = it_attr->pragma_values[k];
if (!pvalue.defined()) {
pvalue = make_const(DataType::Int(32), 1);
}
nest[i + 1].emplace_back(
- AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
+ AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
}
}
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
nest[i + 1].emplace_back(
- LetStmt::make(var, dom->min, no_op));
+ LetStmtNode::make(var, dom->min, no_op));
value_map[iv] = dom->min;
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
- For::make(var, 0, dom->extent,
+ ForNode::make(var, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
nest[i + 1].emplace_back(
- For::make(idx, 0, dom->extent,
+ ForNode::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
Expr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(
- LetStmt::make(var, new_value, no_op));
+ LetStmtNode::make(var, new_value, no_op));
}
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
CHECK(!is_one(dom->extent))
it_attr->prefetch_offset.size());
for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
nest[i + 1].emplace_back(
- AttrStmt::make(it_attr->prefetch_data[j],
+ AttrStmtNode::make(it_attr->prefetch_data[j],
ir::attr::prefetch_scope,
it_attr->prefetch_offset[j], no_op));
}
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
- AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
+ AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker.
CHECK(is_one(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
- AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
+ AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
- AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
+ AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
// annotate the extent of the IterVar
if (!new_loop_var) {
nest[i + 1].emplace_back(
- AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
+ AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
}
std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
- Stmt no_op = Evaluate::make(0);
+ Stmt no_op = EvaluateNode::make(0);
std::vector<Stmt> nest;
for (const Expr& cond : predicates) {
- nest.emplace_back(IfThenElse::make(cond, no_op));
+ nest.emplace_back(IfThenElseNode::make(cond, no_op));
}
return nest;
}
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
- Expr VisitExpr_(const ir::Call* op) final {
- if (op->call_type == ir::Call::Halide) {
+ Expr VisitExpr_(const ir::CallNode* op) final {
+ if (op->call_type == ir::CallNode::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
- Expr ret = ir::Call::make(
+ Expr ret = ir::CallNode::make(
op->dtype, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
Stmt Substitute(Stmt s,
const std::unordered_map<IterVar, Expr>& value_map) {
- std::unordered_map<const Variable*, Expr> init;
+ std::unordered_map<const VarNode*, Expr> init;
for (const auto& kv : value_map) {
init[kv.first->var.get()] = kv.second;
}
void PlaceholderOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
}
void ScanOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
CHECK_EQ(self.operator->(), this);
for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
IterVar sp_ax = this->spatial_axis_[sp_idx];
CHECK(!out_dom_map->count(sp_ax));
CHECK(fix_pt.count(sp_ax));
- if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
+ if (fix_pt[sp_ax].as<ir::IntImmNode>()->value) {
// fix point, we can slice it.
(*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
} else {
IterVar sp_ax = this->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
- ret = ir::Realize::make(t->op, t->value_index, t->dtype,
+ ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype,
bounds, const_true(), ret);
}
return ret;
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt provide = AttrStmt::make(
+ Stmt provide = AttrStmtNode::make(
stage->op, attr::scan_update_scope, this->scan_axis->var,
- Evaluate::make(0));
- Stmt init = AttrStmt::make(
+ EvaluateNode::make(0));
+ Stmt init = AttrStmtNode::make(
stage->op, attr::scan_init_scope, 0,
- Evaluate::make(0));
+ EvaluateNode::make(0));
size_t begin_scan = 0;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
void TensorComputeOpNode::PropBoundToInputs(
const Operation& self,
arith::Analyzer* analyzer,
- const std::unordered_map<const Variable*, IntSet>& dom_map,
+ const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
for (size_t i = 0; i < this->inputs.size(); ++i) {
Tensor t = this->inputs[i];
CHECK_EQ(stage->op.operator->(), this);
// Start bind data.
- Stmt nop = Evaluate::make(0);
+ Stmt nop = EvaluateNode::make(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = this->InputTensors();
tuple.push_back(region[i]->min);
tuple.push_back(region[i]->extent);
}
- input_bind_nest.emplace_back(AttrStmt::make(
+ input_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
}
// output binding
}
}
- output_bind_nest.emplace_back(AttrStmt::make(
+ output_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
}
// Check variable remap
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
ir::ArgBinder binder(&vmap);
// Map the expressions passed in the call to the TensorIntrin, to the placeholder
schedule::PassUpDomain(stage, dom_map, &up_state);
// Get domains of inputs
std::unordered_map<Tensor, TensorDom> in_dom;
- std::unordered_map<const Variable*, IntSet> temp_dmap;
+ std::unordered_map<const VarNode*, IntSet> temp_dmap;
arith::Analyzer analyzer;
Array<Tensor> inputs = self->InputTensors();
for (Tensor t : inputs) {
const ComputeLoopNest& n,
size_t tloc) {
// Verification step.
- std::unordered_set<const Variable*> banned;
+ std::unordered_set<const VarNode*> banned;
CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
n.init_nest.size() == 0);
auto f_push_banned = [&banned](const Stmt& s) {
- if (const For* op = s.as<For>()) {
+ if (const ForNode* op = s.as<ForNode>()) {
banned.insert(op->loop_var.get());
- } else if (const AttrStmt* op = s.as<AttrStmt>()) {
+ } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
banned.insert(iv->var.get());
}
- } else if (const LetStmt* op = s.as<LetStmt>()) {
+ } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
banned.insert(op->var.get());
}
};
// Remap the tensor placeholder, index and inline things.
class TensorIntrinMatcher final : public StmtExprMutator {
public:
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
- if (op->call_type == Call::Halide) {
+ op = expr.as<CallNode>();
+ if (op->call_type == CallNode::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = in_remap_.find(t);
if (it != in_remap_.end()) {
for (size_t i = e.start; i < e.region.size(); ++i) {
args.push_back(op->args[i] - e.region[i]->min);
}
- return Call::make(
+ return CallNode::make(
op->dtype, e.tensor->op->name, args,
op->call_type, e.tensor->op, e.tensor->value_index);
}
return expr;
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
}
}
- Expr VisitExpr_(const Reduce* op) final {
+ Expr VisitExpr_(const ReduceNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Reduce>();
+ op = expr.as<ReduceNode>();
Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
auto it = axis_remap_.find(op->axis[i]);
axis.push_back(it->second);
}
}
- return Reduce::make(
+ return ReduceNode::make(
op->combiner, op->source, axis, op->condition, op->value_index);
}
// input data remap
std::unordered_map<Tensor, InputEntry> in_remap_;
// variable remap.
- std::unordered_map<const Variable*, Expr> var_remap_;
+ std::unordered_map<const VarNode*, Expr> var_remap_;
// IterVar remap.
std::unordered_map<IterVar, IterVar> axis_remap_;
};
VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
// Start bind data.
- Stmt nop = Evaluate::make(0);
+ Stmt nop = EvaluateNode::make(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size())
tuple.push_back(r->min);
tuple.push_back(r->extent);
}
- input_bind_nest.emplace_back(AttrStmt::make(
+ input_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
}
// output binding
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
Tensor tensor = stage->op.output(i - intrin->inputs.size());
Buffer buffer = intrin->buffers[i];
Array<ObjectRef> bind_spec{buffer, tensor};
- output_bind_nest.emplace_back(AttrStmt::make(
+ output_bind_nest.emplace_back(AttrStmtNode::make(
bind_spec, ir::attr::buffer_bind_scope,
- Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+ CallNode::make(DataType::Handle(),
+ ir::intrinsic::tvm_tuple,
+ tuple, CallNode::Intrinsic), nop));
}
// Check variable remap
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
ir::ArgBinder binder(&vmap);
CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
<< "Tensorization fail: reduction axis size do not match";
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
- asserts->emplace_back(AssertStmt::make(scond, os.str(), Evaluate::make(0)));
+ asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
}
}
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.dtype(), value.dtype());
- if (const Variable* v = arg.as<Variable>()) {
+ if (const VarNode* v = arg.as<VarNode>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg = Downcast<Var>(arg);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
- init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
+ init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0)));
} else {
(*def_map_)[v] = value;
}
const std::string& arg_name) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
- const Stmt nop = Evaluate::make(0);
+ const Stmt nop = EvaluateNode::make(0);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
Expr a_ndim = make_const(tvm_ndim_type,
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
- asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+ asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
// type checks
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
- UIntImm::make(DataType::UInt(8), dtype.code()) &&
+ UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
- UIntImm::make(DataType::UInt(8), dtype.bits()) &&
+ UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
- UIntImm::make(DataType::UInt(16), dtype.lanes()));
- asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop));
+ UIntImmNode::make(DataType::UInt(16), dtype.lanes()));
+ asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, ir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
- init_nest_.emplace_back(AttrStmt::make(
+ init_nest_.emplace_back(AttrStmtNode::make(
vptr, ir::attr::storage_alignment,
- IntImm::make(DataType::Int(32), buffer->data_alignment), nop));
+ IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop));
}
Var v_shape(arg_name + ".shape", DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
- init_nest_.emplace_back(LetStmt::make(
+ init_nest_.emplace_back(LetStmtNode::make(
v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
cast(buffer->shape[k].dtype(),
- Load::make(tvm_shape_type, v_shape,
- IntImm::make(DataType::Int(32), k), const_true(1))),
+ LoadNode::make(tvm_shape_type, v_shape,
+ IntImmNode::make(DataType::Int(32), k), const_true(1))),
field_name.str(), true);
}
// strides field
Var v_strides(arg_name + ".strides", DataType::Handle());
def_handle_dtype_.Set(v_strides, ir::TypeAnnotation(tvm_shape_type));
- init_nest_.emplace_back(LetStmt::make(
+ init_nest_.emplace_back(LetStmtNode::make(
v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
nop));
- Expr is_null = Call::make(
+ Expr is_null = CallNode::make(
DataType::Bool(1), intrinsic::tvm_handle_is_null,
- {v_strides}, Call::PureIntrinsic);
+ {v_strides}, CallNode::PureIntrinsic);
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
size_t k = i - 1;
Expr svalue = cast(
stype,
- Load::make(tvm_shape_type, v_strides,
- IntImm::make(DataType::Int(32), k), const_true(1)));
+ LoadNode::make(tvm_shape_type, v_strides,
+ IntImmNode::make(DataType::Int(32), k), const_true(1)));
conds.push_back(expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
<< " expected to be compact array";
if (conds.size() != 0) {
Stmt check =
- AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
- stride_err_msg.str(), Evaluate::make(0));
- check = IfThenElse::make(Not::make(is_null), check, Stmt());
- asserts_.emplace_back(SeqStmt({check, Evaluate::make(0)}));
+ AssertStmtNode::make(arith::ComputeReduce<ir::AndNode>(conds, Expr()),
+ stride_err_msg.str(), EvaluateNode::make(0));
+ check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
+ asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
DataType stype = buffer->DefaultIndexType();
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Expr value = cast(buffer->shape[k].dtype(),
- Load::make(tvm_shape_type, v_strides,
- IntImm::make(DataType::Int(32), k), const_true(1)));
+ LoadNode::make(tvm_shape_type, v_strides,
+ IntImmNode::make(DataType::Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
} else {
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
- asserts_.emplace_back(AssertStmt::make(Not::make(is_null), stride_null_err_msg.str(), nop));
+ asserts_.emplace_back(
+ AssertStmtNode::make(
+ NotNode::make(is_null), stride_null_err_msg.str(), nop));
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Bind_(buffer->strides[k],
cast(buffer->shape[k].dtype(),
- Load::make(tvm_shape_type, v_strides,
- IntImm::make(DataType::Int(32), k), const_true(1))),
+ LoadNode::make(tvm_shape_type, v_strides,
+ IntImmNode::make(DataType::Int(32), k), const_true(1))),
field_name.str(), true);
}
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* ArgBinder will update this def_map when adding new definitions.
*/
explicit ArgBinder(
- std::unordered_map<const Variable*, Expr>* def_map)
+ std::unordered_map<const VarNode*, Expr>* def_map)
: def_map_(def_map) {
}
/*!
const std::string& arg_name,
bool with_lets);
/*! \brief The definition map, can be uses to substitute */
- std::unordered_map<const Variable*, Expr>* def_map_;
+ std::unordered_map<const VarNode*, Expr>* def_map_;
/*! \brief defs generated in the current binder */
std::vector<Var> defs_;
/*! \brief Initialize nest */
public:
BoundCollector() {}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == ir::attr::buffer_bound) {
- if (const Variable *key = op->node.as<Variable>()) {
+ if (const VarNode *key = op->node.as<VarNode>()) {
mem_to_shape[key] = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
// Hashtable which maps buffer_var to shape.
- std::unordered_map<const Variable *, Expr> mem_to_shape;
+ std::unordered_map<const VarNode *, Expr> mem_to_shape;
};
class BoundChecker : public StmtExprMutator {
public:
explicit BoundChecker(
- const std::unordered_map<const Variable *, Expr> &mem_to_shape)
+ const std::unordered_map<const VarNode *, Expr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
// If the shape was updated we should update the hashtable.
if (UpdateIsNeeded(op->buffer_var)) {
Update(op->buffer_var, op->extents, op->dtype);
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
return StmtExprMutator::VisitExpr_(op);
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
store_scope_bound_collector_.clear();
process_store_ = true;
unsafe_rewritten_ = false;
// The collector should have at least one item.
if (store_scope_bound_collector_.size()) {
Expr condition = MakeCondition();
- if (!condition.as<StringImm>()) {
- Stmt nop = Evaluate::make(1);
+ if (!condition.as<StringImmNode>()) {
+ Stmt nop = EvaluateNode::make(1);
Stmt then_case =
- Store::make(op->buffer_var, op->value, op->index, op->predicate);
+ StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
Stmt else_case =
- AssertStmt::make(condition, StringImm::make(error_message_), nop);
- Stmt body = IfThenElse::make(condition, then_case, else_case);
+ AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop);
+ Stmt body = IfThenElseNode::make(condition, then_case, else_case);
return body;
}
}
return GetRef<Stmt>(op);
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
}
// Scalarize the shape.
- Expr shape = Mul::make(make_const(DataType::UInt(64), type.lanes()),
- Cast::make(DataType::UInt(64), new_shape[0]));
+ Expr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
+ CastNode::make(DataType::UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned first to avoid integer overflow.
- shape = Mul::make(shape, Mul::make(make_const(DataType::UInt(64), type.lanes()),
- Cast::make(DataType::UInt(64), new_shape[i])));
+ shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()),
+ CastNode::make(DataType::UInt(64), new_shape[i])));
}
mem_to_shape_[buffer_var.get()] = shape;
}
return false;
}
- if (const Ramp *ramp_index = index.as<Ramp>()) {
+ if (const RampNode *ramp_index = index.as<RampNode>()) {
return ramp_index->base.defined() &&
ramp_index->base.dtype().is_scalar() &&
ramp_index->stride.defined() &&
Expr index = buffer_to_mem.first;
Expr upper_bound = buffer_to_mem.second;
- if (const Ramp *ramp_index = index.as<Ramp>()) {
+ if (const RampNode *ramp_index = index.as<RampNode>()) {
// In case index is base + stride * i.
// Non inclusive range.
- index = Add::make(
+ index = AddNode::make(
ramp_index->base,
- Mul::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
+ MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
ramp_index->lanes - 1)));
}
upper_bound = ir::Simplify(upper_bound);
// Cast to the same type - signed, to be able to check lower bound.
- index = Cast::make(DataType::Int(64), index);
- upper_bound = Cast::make(DataType::Int(64), upper_bound);
+ index = CastNode::make(DataType::Int(64), index);
+ upper_bound = CastNode::make(DataType::Int(64), upper_bound);
// Looks like a lower bound should always be zero after normalization.
Expr lower_bound = make_zero(DataType::Int(64));
Expr current_condition =
- And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
+ AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound));
condition =
- !i ? current_condition : And::make(condition, current_condition);
+ !i ? current_condition : AndNode::make(condition, current_condition);
}
return condition;
}
// Error message.
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
- std::unordered_map<const Variable *, Expr> mem_to_shape_;
+ std::unordered_map<const VarNode *, Expr> mem_to_shape_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
}
};
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U);
Expr ctx = op->args[0];
} else {
CHECK(ctx.dtype().is_handle());
std::string name;
- if (const Call* call = ctx.as<Call>()) {
+ if (const CallNode* call = ctx.as<CallNode>()) {
name = call->name + "_cache";
} else {
name = "ctx_cache_";
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
}
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Parallel) {
// Map of comparison expression to variable
std::map<Expr, Var, CompareExpr> temp;
static Stmt BuildContext(const std::map<Expr, Var, CompareExpr>& cmap,
Stmt body) {
for (const auto& kv : cmap) {
- body = LetStmt::make(kv.second, kv.first, body);
+ body = LetStmtNode::make(kv.second, kv.first, body);
}
return body;
}
// Visitor to find touched set by co-processor scope.
class CoProcTouchedBuffer : public StmtExprVisitor {
public:
- void VisitExpr_(const Load* op) final {
+ void VisitExpr_(const LoadNode* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
}
StmtExprVisitor::VisitExpr_(op);
}
- void VisitStmt_(const Store* op) final {
+ void VisitStmt_(const StoreNode* op) final {
if (in_scope_) {
touched_[op->buffer_var.get()].coproc = true;
} else {
}
StmtExprVisitor::VisitStmt_(op);
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
- const Variable* buffer = op->args[1].as<Variable>();
+ const VarNode* buffer = op->args[1].as<VarNode>();
if (in_scope_) {
touched_[buffer].coproc = true;
} else {
}
StmtExprVisitor::VisitExpr_(op);
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::coproc_scope && !in_scope_) {
in_scope_ = true;
IterVar iv = Downcast<IterVar>(op->node);
bool normal{false};
bool coproc{false};
};
- std::unordered_map<const Variable*, TouchEntry> touched_;
+ std::unordered_map<const VarNode*, TouchEntry> touched_;
std::unordered_set<IterVar> coproc_;
private:
class CoProcSyncPlanner : public StorageAccessVisitor {
public:
explicit CoProcSyncPlanner(
- const std::unordered_set<const Variable*>& touched,
+ const std::unordered_set<const VarNode*>& touched,
const std::string& coproc_name)
: touched_(touched), coproc_name_(coproc_name) {
}
std::unordered_map<const Object*, std::vector<Stmt> > sync_;
protected:
- bool Enabled(const Variable* buf,
+ bool Enabled(const VarNode* buf,
const StorageScope& scope) const final {
return touched_.count(buf);
}
// Plan the sync
std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const For* loop) final {
+ std::vector<StmtEntry> seq, const ForNode* loop) final {
return PlanSync(seq, loop, false);
}
private:
// Plan write synchronization if write is not coherent
std::vector<AccessEntry> PlanSync(
- std::vector<StmtEntry> seq, const For* loop,
+ std::vector<StmtEntry> seq, const ForNode* loop,
bool force_sync_at_end) {
// detect write barriers
// access by the co-processor.
}
std::vector<Stmt> GetSync(std::string sync_name) {
- return {Evaluate::make(Call::make(
+ return {EvaluateNode::make(CallNode::make(
DataType::Int(32),
sync_name,
- {}, Call::Intrinsic))};
+ {}, CallNode::Intrinsic))};
}
- const std::unordered_set<const Variable*>& touched_;
+ const std::unordered_set<const VarNode*>& touched_;
std::string coproc_name_;
};
class CoProcBarrierDetector : public StorageAccessVisitor {
public:
explicit CoProcBarrierDetector(
- const std::unordered_set<const Variable*>& touched,
+ const std::unordered_set<const VarNode*>& touched,
const std::string& coproc_name)
: touched_(touched) {
read_barrier_name_ = coproc_name + ".coproc_read_barrier";
std::unordered_map<const Object*, std::vector<Stmt> > barrier_after_;
protected:
- bool Enabled(const Variable* buf,
+ bool Enabled(const VarNode* buf,
const StorageScope& scope) const final {
return touched_.count(buf);
}
// Plan the sync
std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const For* loop) final {
+ std::vector<StmtEntry> seq, const ForNode* loop) final {
if (read_barrier_) {
return PlanReadBarrier(seq, loop);
} else {
private:
// Plan write barrier at Read after write point.
std::vector<AccessEntry> PlanWriteBarrier(
- std::vector<StmtEntry> seq, const For* loop) {
+ std::vector<StmtEntry> seq, const ForNode* loop) {
std::vector<AccessEntry> read_seq;
- std::unordered_map<const Variable*, std::vector<AccessEntry> > write_set;
+ std::unordered_map<const VarNode*, std::vector<AccessEntry> > write_set;
auto fupdate = [&](size_t i, const AccessEntry& acc) {
auto it = write_set.find(acc.buffer.get());
}
std::vector<AccessEntry> PlanReadBarrier(
- std::vector<StmtEntry> seq, const For* loop) {
+ std::vector<StmtEntry> seq, const ForNode* loop) {
std::vector<AccessEntry> write_seq;
- std::unordered_map<const Variable*, std::vector<AccessEntry> > read_set;
+ std::unordered_map<const VarNode*, std::vector<AccessEntry> > read_set;
auto fupdate = [&](size_t i, const AccessEntry& acc) {
auto it = read_set.find(acc.buffer.get());
<< "Cannot deduce write range of " << wvec[0].buffer;
Expr min = r->min;
Expr extent = r->extent;
- return Evaluate::make(Call::make(
+ return EvaluateNode::make(CallNode::make(
DataType::Int(32), func,
- {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, Call::Intrinsic));
+ {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic));
}
// Write barrier name
bool read_barrier_{false};
std::string read_barrier_name_;
std::string write_barrier_name_;
- const std::unordered_set<const Variable*>& touched_;
+ const std::unordered_set<const VarNode*>& touched_;
};
}
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::coproc_scope &&
op->node.same_as(coproc_axis_)) {
- const IntImm* ctx_id = op->value.as<IntImm>();
+ const IntImmNode* ctx_id = op->value.as<IntImmNode>();
CHECK(ctx_id != nullptr);
curr_state_.clear();
curr_state_.node = op->body.get();
}
}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
SyncState temp_first, temp_last;
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
}
}
- void VisitStmt_(const IfThenElse* op) final {
+ void VisitStmt_(const IfThenElseNode* op) final {
SyncState temp_first, temp_last, curr_state;
std::swap(first_state_, temp_first);
std::swap(last_state_, temp_last);
}
Stmt MakePush(int from, int to) {
- return Evaluate::make(Call::make(
+ return EvaluateNode::make(CallNode::make(
DataType::Int(32), sync_push_name_,
{make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- Call::Intrinsic));
+ CallNode::Intrinsic));
}
Stmt MakePop(int from, int to) {
- return Evaluate::make(Call::make(
+ return EvaluateNode::make(CallNode::make(
DataType::Int(32), sync_pop_name_,
{make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- Call::Intrinsic));
+ CallNode::Intrinsic));
}
// sync states.
SyncState first_state_, last_state_, curr_state_;
CoProcTouchedBuffer visitor;
visitor(stmt);
if (visitor.coproc_.size() == 0) return stmt;
- std::unordered_set<const Variable*> touched;
+ std::unordered_set<const VarNode*> touched;
for (const auto &kv : visitor.touched_) {
if (kv.second.normal && kv.second.coproc) {
namespace tvm {
namespace ir {
Stmt DecorateDeviceScope(Stmt stmt) {
- Stmt body = AttrStmt::make(make_zero(DataType::Int(32)),
+ Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)),
ir::attr::device_scope,
0,
stmt);
// in a For stmt.
bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
std::vector<const Object*> if_node_list;
- const For* for_node = for_stmt.as<For>();
+ const ForNode* for_node = for_stmt.as<ForNode>();
CHECK(for_node);
- CHECK(if_stmt.as<IfThenElse>());
+ CHECK(if_stmt.as<IfThenElseNode>());
PostOrderVisit(for_node->body, [&](const ObjectRef& node) {
- if (node.as<IfThenElse>()) {
+ if (node.as<IfThenElseNode>()) {
if_node_list.push_back(node.get());
}
});
// in the main VisitAndMutate function.
Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
const Object* top_for_node;
- const For* parent_for_node = parent_for_stmt.as<For>();
+ const ForNode* parent_for_node = parent_for_stmt.as<ForNode>();
CHECK(parent_for_node);
- CHECK(new_if_stmt.as<IfThenElse>());
+ CHECK(new_if_stmt.as<IfThenElseNode>());
PostOrderVisit(parent_for_node->body, [&](const ObjectRef& node) {
- if (node.as<For>()) {
+ if (node.as<ForNode>()) {
top_for_node = node.get();
}
});
std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
Stmt then_for;
Stmt else_for;
- CHECK(if_stmt.as<IfThenElse>());
+ CHECK(if_stmt.as<IfThenElseNode>());
PackedFunc replace_then_case = PackedFunc(
[&](TVMArgs args, TVMRetValue *ret){
const ObjectRef& node = args[0];
if (node == if_stmt) {
- *ret = node.as<IfThenElse>()->then_case;
+ *ret = node.as<IfThenElseNode>()->then_case;
}
});
[&](TVMArgs args, TVMRetValue *ret){
const ObjectRef& node = args[0];
if (node == if_stmt) {
- *ret = node.as<IfThenElse>()->else_case;
+ *ret = node.as<IfThenElseNode>()->else_case;
}
});
then_for = IRTransform(for_stmt, nullptr, replace_then_case,
{Expr("IfThenElse")});
- if (if_stmt.as<IfThenElse>()->else_case) {
+ if (if_stmt.as<IfThenElseNode>()->else_case) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case,
{Expr("IfThenElse")});
}
// Locate all For nodes and capture child IfThenElse nodes.
void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
PostOrderVisit(stmt, [&](const ObjectRef& node){
- const For* for_node = node.as<For>();
+ const ForNode* for_node = node.as<ForNode>();
if (!for_node) return;
std::queue<Stmt> tracker;
while (!tracker.empty()) {
Stmt head = tracker.front();
tracker.pop();
- if (head->IsInstance<For>()) {
+ if (head->IsInstance<ForNode>()) {
for (const auto& if_stmt : for2if_map_.at(head.get())) {
for2if_map_[for_stmt.get()].push_back(if_stmt);
}
- } else if (head->IsInstance<AttrStmt>()) {
- const AttrStmt* attr_node = head.as<AttrStmt>();
+ } else if (head->IsInstance<AttrStmtNode>()) {
+ const AttrStmtNode* attr_node = head.as<AttrStmtNode>();
tracker.push(attr_node->body);
- } else if (head->IsInstance<IfThenElse>()) {
+ } else if (head->IsInstance<IfThenElseNode>()) {
for2if_map_[for_stmt.get()].push_back(head);
- const IfThenElse* if_node = head.as<IfThenElse>();
+ const IfThenElseNode* if_node = head.as<IfThenElseNode>();
tracker.push(if_node->then_case);
if (if_node->else_case) {
tracker.push(if_node->else_case);
std::unordered_set<const Object*> new_var_set;
cond_var_map_.insert({head.get(), new_var_set});
PostOrderVisit(if_node->condition, [&](const ObjectRef& cond_node) {
- if (cond_node.as<Variable>()) {
+ if (cond_node.as<VarNode>()) {
cond_var_map_[head.get()].insert(cond_node.get());
}
});
// Create IfThenElse -> For map.
for (const Stmt& for_stmt : ordered_for_list_) {
std::vector<Stmt> if_list = for2if_map_[for_stmt.get()];
- const For* for_node = for_stmt.as<For>();
+ const ForNode* for_node = for_stmt.as<ForNode>();
CHECK(for_node);
top_for_var_map_.insert({for_node->loop_var.get(), if_list});
for (const Stmt& if_stmt : if_list) {
std::vector<Stmt> for_list = item.second;
for (size_t i = 0; i < for_list.size(); ++i) {
const Stmt& for_stmt = for_list.at(i);
- const For* for_node = for_stmt.as<For>();
+ const ForNode* for_node = for_stmt.as<ForNode>();
CHECK(for_node);
std::vector<Stmt> new_for_list{for_stmt};
for_tracking_map_.insert({for_stmt.get(), new_for_list});
top_for = for_stmt;
}
}
- if (top_for.as<For>()) {
+ if (top_for.as<ForNode>()) {
if_position_map.insert({if_stmt, top_for});
}
}
for (const auto& item : if_position_map) {
- top_for_var_set.insert(item.second.as<For>()->loop_var.get());
+ top_for_var_set.insert(item.second.as<ForNode>()->loop_var.get());
}
std::vector<const Object*> removed_for_var_list;
for_tracking_map_[for_stmt.get()].push_back(else_for);
}
- const IfThenElse* new_if_node = new_if.as<IfThenElse>();
+ const IfThenElseNode* new_if_node = new_if.as<IfThenElseNode>();
CHECK(new_if_node);
- new_if = IfThenElse::make(new_if_node->condition, then_for, else_for);
+ new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for);
if (i < if2for_map_[if_stmt.get()].size() - 1) {
const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
const Stmt& actual_next_for =
PackedFunc replace_top_for = PackedFunc(
[&](TVMArgs args, TVMRetValue *ret){
const ObjectRef& current_for = args[0];
- const For* for_node = current_for.as<For>();
+ const ForNode* for_node = current_for.as<ForNode>();
if (!for_node) return;
if (top_for_var_map_.count(for_node->loop_var.get())) {
new_if_list.emplace_back(HoistIf(if_stmt));
}
- const IfThenElse* next_if_node;
- const IfThenElse* current_if_node =
- new_if_list.back().as<IfThenElse>();
+ const IfThenElseNode* next_if_node;
+ const IfThenElseNode* current_if_node =
+ new_if_list.back().as<IfThenElseNode>();
Stmt new_for = Stmt();
for (size_t i = new_if_list.size() - 1; i > 0; --i) {
CHECK(current_if_node);
const Stmt current_if_stmt =
- IfThenElse::make(current_if_node->condition,
+ IfThenElseNode::make(current_if_node->condition,
current_if_node->then_case,
current_if_node->else_case);
- next_if_node = new_if_list[i - 1].as<IfThenElse>();
+ next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
CHECK(next_if_node);
- new_for = IfThenElse::make(next_if_node->condition, current_if_stmt,
+ new_for = IfThenElseNode::make(next_if_node->condition, current_if_stmt,
next_if_node->else_case);
- current_if_node = new_for.as<IfThenElse>();
+ current_if_node = new_for.as<IfThenElseNode>();
}
if (!new_for.get()) {
- const IfThenElse* first_if_node = new_if_list[0].as<IfThenElse>();
+ const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
CHECK(first_if_node);
- new_for = IfThenElse::make(first_if_node->condition,
+ new_for = IfThenElseNode::make(first_if_node->condition,
first_if_node->then_case,
first_if_node->else_case);
}
: m(_m), n(_n), k(_k), layout(_layout) {}
};
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
// Get shape and layout information from load and store intrinsic
CHECK_EQ(op->args.size(), 8U);
- const Variable* buffer_var = op->args[0].as<Variable>();
+ const VarNode* buffer_var = op->args[0].as<VarNode>();
CHECK(buffer_var);
// Get shape
- const IntImm* m = op->args[1].as<IntImm>();
- const IntImm* n = op->args[2].as<IntImm>();
- const IntImm* k = op->args[3].as<IntImm>();
- const StringImm* layout = op->args[7].as<StringImm>();
+ const IntImmNode* m = op->args[1].as<IntImmNode>();
+ const IntImmNode* n = op->args[2].as<IntImmNode>();
+ const IntImmNode* k = op->args[3].as<IntImmNode>();
+ const StringImmNode* layout = op->args[7].as<StringImmNode>();
CHECK(m);
CHECK(n);
CHECK(k);
} else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
// Get shape information from fill intrinsic
CHECK_EQ(op->args.size(), 6U);
- const Variable* buffer_var = op->args[0].as<Variable>();
+ const VarNode* buffer_var = op->args[0].as<VarNode>();
CHECK(buffer_var);
// Get shape
- const IntImm* m = op->args[1].as<IntImm>();
- const IntImm* n = op->args[2].as<IntImm>();
- const IntImm* k = op->args[3].as<IntImm>();
+ const IntImmNode* m = op->args[1].as<IntImmNode>();
+ const IntImmNode* n = op->args[2].as<IntImmNode>();
+ const IntImmNode* k = op->args[3].as<IntImmNode>();
CHECK(m);
CHECK(n);
CHECK(k);
}
// Get memory scope
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
- const Variable* buffer = op->node.as<Variable>();
+ const VarNode* buffer = op->node.as<VarNode>();
CHECK(buffer);
- scopes[buffer] = op->value.as<StringImm>()->value;
+ scopes[buffer] = op->value.as<StringImmNode>()->value;
}
StmtExprVisitor::VisitStmt_(op);
}
// Memory scope for allocations
- std::unordered_map<const Variable*, std::string> scopes;
+ std::unordered_map<const VarNode*, std::string> scopes;
// Fragment metadata for all fragments
- std::unordered_map<const Variable*, FragmentInfo> fragments;
+ std::unordered_map<const VarNode*, FragmentInfo> fragments;
};
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
public:
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
CHECK_EQ(op->args.size(), 8U);
- const Variable* buffer_var_d = op->args[0].as<Variable>();
- const Variable* buffer_var_a = op->args[2].as<Variable>();
- const Variable* buffer_var_b = op->args[4].as<Variable>();
- const Variable* buffer_var_c = op->args[6].as<Variable>();
+ const VarNode* buffer_var_d = op->args[0].as<VarNode>();
+ const VarNode* buffer_var_a = op->args[2].as<VarNode>();
+ const VarNode* buffer_var_b = op->args[4].as<VarNode>();
+ const VarNode* buffer_var_c = op->args[6].as<VarNode>();
CHECK(buffer_var_d);
CHECK(buffer_var_a);
CHECK(buffer_var_b);
private:
// A tool for checking shapes of two fragments
- bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
+ bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) {
CHECK(fragment_getter.fragments.count(buffer1));
CHECK(fragment_getter.fragments.count(buffer2));
FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
public:
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- const Variable* buffer = op->buffer_var.get();
+ const VarNode* buffer = op->buffer_var.get();
if (fragment_getter.fragments.count(buffer)) {
// Add attribute to fragments allocation
FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
std::string shape = std::to_string(info.m) + ", " +
std::to_string(info.n) + ", " +
std::to_string(info.k);
- Expr shape_expr = StringImm::make(shape);
- Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
+ Expr shape_expr = StringImmNode::make(shape);
+ Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
- Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout,
- StringImm::make(info.layout), shape_attr);
+ Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout,
+ StringImmNode::make(info.layout), shape_attr);
return layout_attr;
} else {
return shape_attr;
flower_copy_fromto_(flower_copy_fromto) {
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
- storage_scope_[buf] = op->value.as<StringImm>()->value;
+ const VarNode* buf = op->node.as<VarNode>();
+ storage_scope_[buf] = op->value.as<StringImmNode>()->value;
} else if (op->attr_key == pragma_key_) {
Stmt ret;
CHECK(MatchCopyPattern(op->body, &ret))
Stmt body = stmt;
// strip the loops
- std::vector<const For*> loops;
- while (const For* op = body.as<For>()) {
+ std::vector<const ForNode*> loops;
+ while (const ForNode* op = body.as<ForNode>()) {
if (!is_zero(op->min)) return false;
loops.push_back(op);
body = op->body;
}
- const Store* store = body.as<Store>();
+ const StoreNode* store = body.as<StoreNode>();
if (store == nullptr) return false;
// Expr sel_cond, sel_true_value, sel_false_value;
// match select or if
if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
- const Cast* cast = store->value.as<Cast>();
- const Load* load = store->value.as<Load>();
+ const CastNode* cast = store->value.as<CastNode>();
+ const LoadNode* load = store->value.as<LoadNode>();
if (0 == loops.size()) {
CHECK(!has_cond);
}
// for now only support true condition matching
if (has_cond) {
- load = sel_true_value.Eval().as<Load>();
+ load = sel_true_value.Eval().as<LoadNode>();
}
// cast can be part of the pattern
if (cast != nullptr) {
- load = cast->value.as<Load>();
+ load = cast->value.as<LoadNode>();
}
if (load == nullptr) return false;
if (load->dtype.lanes() != 1) return false;
Array<Var> loop_vars;
- for (const For* op : loops) {
+ for (const ForNode* op : loops) {
loop_vars.push_back(op->loop_var);
}
Array<Expr> store_strides =
if (loop_var_size == 0) {
dst_shape.push_back(make_const(DataType::Int(32), 1));
} else {
- for (const For* op : loops) {
+ for (const ForNode* op : loops) {
dst_shape.push_back(op->extent);
}
}
DataType t = loop_vars[i].dtype();
Expr svalue = src_shape[i];
if (min_value.defined()) {
- Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
+ Expr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
src_elem_offset = src_elem_offset + pbefore * load_strides[i];
svalue = svalue - pbefore;
pad_before.push_back(pbefore);
pad_before.push_back(make_zero(t));
}
if (max_value.defined()) {
- Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1),
+ Expr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
make_zero(t)));
svalue = svalue - pafter;
pad_after.push_back(pafter);
return true;
}
// Get storage scope
- std::string GetStorageScope(const Variable* var) const {
+ std::string GetStorageScope(const VarNode* var) const {
auto it = storage_scope_.find(var);
if (it != storage_scope_.end()) {
return it->second;
// function to lower copy intrinsics.
const PackedFunc& flower_copy_fromto_;
// Storage scope
- std::unordered_map<const Variable*, std::string> storage_scope_;
+ std::unordered_map<const VarNode*, std::string> storage_scope_;
};
Stmt InjectCopyIntrin(Stmt stmt,
// Detect double buffer variables.
class DoubleBufferDetector : public StmtExprVisitor {
public:
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::double_buffer_scope) {
- touched_.insert(op->node.as<Variable>());
+ touched_.insert(op->node.as<VarNode>());
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
if (touched_.count(op)) {
touched_.erase(op);
}
}
// The set of touched variable.
- std::unordered_set<const Variable*> touched_;
+ std::unordered_set<const VarNode*> touched_;
};
class StripDoubleBufferWrite : public StmtMutator {
public:
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::double_buffer_write) {
return VisitStmt(op->body);
} else {
DoubleBufferDetector detector;
detector(stmt);
if (detector.touched_.empty()) return stmt;
- for (const Variable* v : detector.touched_) {
+ for (const VarNode* v : detector.touched_) {
dbuffer_info_[v] = StorageEntry();
}
return ConvertSSA(operator()(std::move(stmt)));
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
+ const VarNode* buf = op->node.as<VarNode>();
auto it = dbuffer_info_.find(buf);
if (it != dbuffer_info_.end()) {
- it->second.scope = op->value.as<StringImm>()->value;
+ it->second.scope = op->value.as<StringImmNode>()->value;
return this->VisitStmt(op->body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
- it->second.stride = arith::ComputeReduce<Mul>(
+ it->second.stride = arith::ComputeReduce<MulNode>(
op->extents, Expr()) * op->dtype.lanes();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
for (Expr e : op->extents) {
new_extents.push_back(e);
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
- alloc_nest.emplace_back(AttrStmt::make(
+ alloc_nest.emplace_back(AttrStmtNode::make(
op->buffer_var, attr::storage_scope,
- StringImm::make(it->second.scope),
- Evaluate::make(0)));
- alloc_nest.emplace_back(Allocate::make(
+ StringImmNode::make(it->second.scope),
+ EvaluateNode::make(0)));
+ alloc_nest.emplace_back(AllocateNode::make(
op->buffer_var, op->dtype, new_extents, op->condition,
- Evaluate::make(0)));
+ EvaluateNode::make(0)));
return op->body;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
loop_nest_.push_back(op);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = loop_pre_.find(op);
if (it != loop_pre_.end()) {
- const For* old_loop = stmt.as<For>();
+ const ForNode* old_loop = stmt.as<ForNode>();
if (split_loop_ != 0) {
// Explicitly unroll the loop
CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
Expr outer_ext = new_ext / factor;
Expr tail_base = outer_ext * factor;
Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype());
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
std::vector<Stmt> loop_seq;
for (int32_t i = 0; i < split_loop_; ++i) {
vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
- Stmt loop = For::make(
+ Stmt loop = ForNode::make(
outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
SeqStmt::Flatten(loop_seq));
// tail
Expr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(
- IfThenElse::make(idx < old_loop->extent,
+ IfThenElseNode::make(idx < old_loop->extent,
Substitute(tail_body, vmap)));
}
stmt = SeqStmt::Flatten(loop, tail_seq);
return stmt;
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Store>();
+ op = stmt.as<StoreNode>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.stride.defined());
- return Store::make(op->buffer_var,
+ return StoreNode::make(op->buffer_var,
op->value,
e.switch_write_var * e.stride + op->index,
op->predicate);
}
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Load>();
+ op = expr.as<LoadNode>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
const StorageEntry& e = it->second;
CHECK(e.stride.defined());
CHECK(e.switch_read_var.defined());
- return Load::make(op->dtype,
+ return LoadNode::make(op->dtype,
op->buffer_var,
e.switch_read_var * e.stride + op->index,
op->predicate);
}
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
CHECK(!dbuffer_info_.count(op));
return GetRef<Expr>(op);
}
private:
- Stmt MakeProducer(const AttrStmt* op) {
+ Stmt MakeProducer(const AttrStmtNode* op) {
const VarExpr buffer = Downcast<VarExpr>(op->node);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
in_double_buffer_scope_ = true;
Stmt body = this->VisitStmt(op->body);
in_double_buffer_scope_ = false;
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
vmap[e.switch_write_var.get()] = zero;
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
body = Substitute(body, vmap);
- body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
- body = IfThenElse::make(loop_shift < e.loop->extent, body);
+ body = AttrStmtNode::make(buffer, attr::double_buffer_write, 1, body);
+ body = IfThenElseNode::make(loop_shift < e.loop->extent, body);
return body;
}
// Storage entry for those who need double buffering.
// The size of the buffer
Expr stride;
// The loop we need
- const For* loop{nullptr};
+ const ForNode* loop{nullptr};
// The switch variable.
VarExpr switch_write_var;
// The switch variable for reading.
// Whether we are inside double buffer scope.
bool in_double_buffer_scope_{false};
// The current loop next
- std::vector<const For*> loop_nest_;
+ std::vector<const ForNode*> loop_nest_;
// The allocs to be appended before the loop
- std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
+ std::unordered_map<const ForNode*, std::vector<Stmt> > loop_allocs_;
// The stmt to be appended before the loop
- std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
+ std::unordered_map<const ForNode*, std::vector<Stmt> > loop_pre_;
// The allocation size of the buffer
- std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
+ std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_;
};
class PrefetchInjector : public StmtMutator {
public:
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt ret = StmtMutator::VisitStmt_(op);
- op = ret.as<AttrStmt>();
+ op = ret.as<AttrStmtNode>();
if (op && op->attr_key == attr::prefetch_scope) {
Tensor ts = Downcast<Tensor>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
vectorized_.erase(iter_var);
- Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region);
+ Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region);
return SeqStmt({prefetch, op->body});
}
return ret;
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
auto &var = op->loop_var;
loop_nest_.push_back(var);
if (op->for_type == ForType::Vectorized) {
private:
std::vector<VarExpr> loop_nest_;
- std::unordered_map<const Variable *, IntSet> vectorized_;
+ std::unordered_map<const VarNode *, IntSet> vectorized_;
static const Range none;
};
// If expression is touched by var.
class ExprTouched final : public StmtExprVisitor {
public:
- explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
+ explicit ExprTouched(const std::unordered_set<const VarNode*> &touched,
bool check_write)
: touched_var_(touched), check_write_(check_write) {}
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitStmt(n);
}
- void VisitExpr_(const Load *op) final {
+ void VisitExpr_(const LoadNode *op) final {
HandleUseVar(op->buffer_var.get());
StmtExprVisitor::VisitExpr_(op);
}
- void VisitExpr_(const Variable *op) final {
+ void VisitExpr_(const VarNode *op) final {
HandleUseVar(op);
}
- void VisitExpr_(const Call *op) final {
+ void VisitExpr_(const CallNode *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
int rw_mask = 0;
CHECK(arith::GetConstInt(op->args[4], &rw_mask));
- const Variable* buffer_var = op->args[1].as<Variable>();
+ const VarNode* buffer_var = op->args[1].as<VarNode>();
CHECK(buffer_var);
// read
if (rw_mask & 1) {
StmtExprVisitor::VisitExpr_(op);
}
}
- void HandleUseVar(const Variable* var) {
+ void HandleUseVar(const VarNode* var) {
auto it = touched_var_.find(var);
if (it != touched_var_.end()) {
expr_touched_ = true;
used_vars_.push_back(var);
}
}
- void HandleWriteVar(const Variable* var) {
+ void HandleWriteVar(const VarNode* var) {
write_vars_.push_back(var);
}
// the fields.
bool expr_touched_{false};
- std::vector<const Variable*> used_vars_;
- std::vector<const Variable*> write_vars_;
- const std::unordered_set<const Variable*>& touched_var_;
+ std::vector<const VarNode*> used_vars_;
+ std::vector<const VarNode*> write_vars_;
+ const std::unordered_set<const VarNode*>& touched_var_;
bool check_write_;
};
// Analyze if the buffers are invariant to value of var
class VarTouchedAnalysis : public StmtVisitor {
public:
- void VisitStmt_(const LetStmt* op) final {
+ void VisitStmt_(const LetStmtNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->value);
Record(op->var.get(), tc);
this->VisitStmt(op->body);
}
- void VisitStmt_(const Store* op) final {
+ void VisitStmt_(const StoreNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->value);
tc(op->index);
Record(op->buffer_var.get(), tc);
}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
ExprTouched tc(touched_var_, false);
tc(op->min);
tc(op->extent);
this->VisitStmt(op->body);
}
// external function call
- void VisitStmt_(const Evaluate* op) final {
+ void VisitStmt_(const EvaluateNode* op) final {
ExprTouched tc(touched_var_, true);
tc(op->value);
- for (const Variable* var : tc.write_vars_) {
+ for (const VarNode* var : tc.write_vars_) {
Record(var, tc);
}
}
- void VisitStmt_(const Allocate* op) final {
+ void VisitStmt_(const AllocateNode* op) final {
ExprTouched tc(touched_var_, false);
for (size_t i = 0; i < op->extents.size(); ++i) {
tc(op->extents[i]);
Record(op->buffer_var.get(), tc);
this->VisitStmt(op->body);
}
- void Record(const Variable* var,
+ void Record(const VarNode* var,
const ExprTouched& tc) {
if (touched_var_.count(var)) return;
if (tc.expr_touched_) {
touched_var_.insert(var);
} else {
- for (const Variable* r : tc.used_vars_) {
+ for (const VarNode* r : tc.used_vars_) {
if (r != var) {
affect_[r].push_back(var);
}
}
}
- std::unordered_set<const Variable*>
+ std::unordered_set<const VarNode*>
TouchedVar(const Stmt& stmt,
- const Variable* var) {
+ const VarNode* var) {
touched_var_.insert(var);
this->VisitStmt(stmt);
// do a DFS to push affect around dependency.
- std::vector<const Variable*> pending(
+ std::vector<const VarNode*> pending(
touched_var_.begin(), touched_var_.end());
while (!pending.empty()) {
- const Variable* v = pending.back();
+ const VarNode* v = pending.back();
pending.pop_back();
- for (const Variable* r : affect_[v]) {
+ for (const VarNode* r : affect_[v]) {
if (!touched_var_.count(r)) {
touched_var_.insert(r);
pending.push_back(r);
private:
// Whether variable is touched by the thread variable.
- std::unordered_set<const Variable*> touched_var_;
+ std::unordered_set<const VarNode*> touched_var_;
// x -> all the buffers x read from
- std::unordered_map<const Variable*,
- std::vector<const Variable*> > affect_;
+ std::unordered_map<const VarNode*,
+ std::vector<const VarNode*> > affect_;
};
// constructor
VTInjector(Var var,
int num_threads,
- const std::unordered_set<const Variable*>& touched_var,
+ const std::unordered_set<const VarNode*>& touched_var,
bool allow_share)
: var_(var), num_threads_(num_threads),
touched_var_(touched_var), allow_share_(allow_share) {
return stmt;
}
// Variable
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
CHECK(!alloc_remap_.count(op))
<< "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) {
return index + var_ * alloc_extent;
}
// Load
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Load>();
+ op = expr.as<LoadNode>();
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
}
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- return Load::make(op->dtype, op->buffer_var,
+ return LoadNode::make(op->dtype, op->buffer_var,
RewriteIndex(op->index, it->second),
op->predicate);
} else {
}
}
// Expression.
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
- const Variable* buffer = op->args[1].as<Variable>();
+ const VarNode* buffer = op->args[1].as<VarNode>();
auto it = alloc_remap_.find(buffer);
if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
visit_touched_var_ = true;
Expr stride =
it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
- return Call::make(
+ return CallNode::make(
op->dtype, op->name,
{op->args[0], op->args[1], offset, extent, op->args[4]},
op->call_type);
return StmtExprMutator::VisitExpr_(op);
}
}
- Stmt VisitStmt_(const Evaluate* op) final {
+ Stmt VisitStmt_(const EvaluateNode* op) final {
trigger_base_inject_ = !allow_share_;
return StmtExprMutator::VisitStmt_(op);
}
// Store
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Store>();
+ op = stmt.as<StoreNode>();
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
}
trigger_base_inject_ = !allow_share_;
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- return Store::make(op->buffer_var,
+ return StoreNode::make(op->buffer_var,
op->value,
RewriteIndex(op->index, it->second),
op->predicate);
}
}
// Attribute
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
Expr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return AttrStmt::make(op->node, op->attr_key, value, body);
+ return AttrStmtNode::make(op->node, op->attr_key, value, body);
}
}
}
// LetStmt
- Stmt VisitStmt_(const LetStmt* op) final {
+ Stmt VisitStmt_(const LetStmtNode* op) final {
Expr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return LetStmt::make(op->var, value, body);
+ return LetStmtNode::make(op->var, value, body);
}
}
// For
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
CHECK(is_zero(op->min));
Expr extent = this->VisitExpr(op->extent);
if (visit_touched_var_ && !vt_loop_injected_) {
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return For::make(
+ return ForNode::make(
op->loop_var, op->min, extent, op->for_type, op->device_api, body);
}
}
// IfThenElse
- Stmt VisitStmt_(const IfThenElse* op) final {
+ Stmt VisitStmt_(const IfThenElseNode* op) final {
Expr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElse::make(condition, then_case, else_case);
+ return IfThenElseNode::make(condition, then_case, else_case);
}
}
return StmtMutator::VisitSeqStmt_(op, false, fmutate);
}
// Allocate
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
if (op->new_expr.defined() && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
// always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
- Expr stride = arith::ComputeReduce<Mul>(
+ Expr stride = arith::ComputeReduce<MulNode>(
op->extents, Expr()) * op->dtype.lanes();
Array<Expr> other;
other.push_back(make_const(op->extents[0].dtype(), num_threads_));
condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
- return Allocate::make(
+ return AllocateNode::make(
op->buffer_var, op->dtype,
extents, condition, body,
op->new_expr, op->free_function);
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, Expr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return For::make(idx, make_zero(idx.dtype()),
+ return ForNode::make(idx, make_zero(idx.dtype()),
make_const(idx.dtype(), num_threads_),
ForType::Serial, DeviceAPI::None, stmt);
}
// the counter of loops in after mutation.
int max_loop_depth_{0};
// The variables that get touched.
- const std::unordered_set<const Variable*>& touched_var_;
+ const std::unordered_set<const VarNode*>& touched_var_;
// Whether allow shareding.
bool allow_share_;
// The allocations that get touched -> extent
- std::unordered_map<const Variable*, Expr> alloc_remap_;
+ std::unordered_map<const VarNode*, Expr> alloc_remap_;
};
class VirtualThreadInjector : public StmtMutator {
public:
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<AttrStmt>();
+ op = stmt.as<AttrStmtNode>();
if (op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
bool allow_share = iv->thread_tag == "vthread";
- int nthread = static_cast<int>(op->value.as<IntImm>()->value);
+ int nthread = static_cast<int>(op->value.as<IntImmNode>()->value);
VarTouchedAnalysis vs;
auto touched = vs.TouchedVar(op->body, iv->var.get());
VTInjector injecter(iv->var, nthread, touched, allow_share);
}
}
- Stmt VisitStmt_(const Provide* op) final {
+ Stmt VisitStmt_(const ProvideNode* op) final {
LOG(FATAL) << "Need to call StorageFlatten first";
return GetRef<Stmt>(op);
}
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
}
if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
- expr = Let::make(args_[i], op->args[i], expr);
+ expr = LetNode::make(args_[i], op->args[i], expr);
}
} else {
Map<Var, Expr> vmap;
vmap.Set(args_[i], op->args[i]);
}
expr = Substitute(
- Evaluate::make(expr), vmap).as<Evaluate>()->value;
+ EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
}
return expr;
} else {
StmtComparator::VisitStmt(n, other);
}
// Stmt
- void VisitStmt_(const LetStmt* op, const Stmt& other) final {
- const LetStmt* rhs = other.as<LetStmt>();
+ void VisitStmt_(const LetStmtNode* op, const Stmt& other) final {
+ const LetStmtNode* rhs = other.as<LetStmtNode>();
if (CompareExpr(op->value, rhs->value) != 0) return;
if (tie_def_) {
vmap_[op->var.get()] = rhs->var.get();
if (CompareStmt(op->body, rhs->body) != 0) return;
}
- void VisitStmt_(const AttrStmt* op, const Stmt& other) final {
- const AttrStmt* rhs = other.as<AttrStmt>();
+ void VisitStmt_(const AttrStmtNode* op, const Stmt& other) final {
+ const AttrStmtNode* rhs = other.as<AttrStmtNode>();
if (CompareString(op->attr_key, rhs->attr_key) != 0) return;
if (CompareNodeRef(op->node, rhs->node) != 0) return;
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
- void VisitStmt_(const IfThenElse* op, const Stmt& other) final {
- const IfThenElse* rhs = other.as<IfThenElse>();
+ void VisitStmt_(const IfThenElseNode* op, const Stmt& other) final {
+ const IfThenElseNode* rhs = other.as<IfThenElseNode>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareStmt(op->then_case, rhs->then_case) != 0) return;
if (CompareStmt(op->else_case, rhs->else_case) != 0) return;
}
- void VisitStmt_(const For* op, const Stmt& other) final {
- const For* rhs = other.as<For>();
+ void VisitStmt_(const ForNode* op, const Stmt& other) final {
+ const ForNode* rhs = other.as<ForNode>();
if (CompareExpr(op->min, rhs->min) != 0) return;
if (CompareExpr(op->extent, rhs->extent) != 0) return;
if (tie_def_) {
if (CompareStmt(op->body, rhs->body) != 0) return;
}
- void VisitStmt_(const Allocate* op, const Stmt& other) final {
- const Allocate* rhs = other.as<Allocate>();
+ void VisitStmt_(const AllocateNode* op, const Stmt& other) final {
+ const AllocateNode* rhs = other.as<AllocateNode>();
if (tie_def_) {
vmap_[op->buffer_var.get()] = rhs->buffer_var.get();
} else {
if (CompareString(op->free_function, rhs->free_function) != 0) return;
}
- void VisitStmt_(const Store* op, const Stmt& other) final {
- const Store* rhs = other.as<Store>();
+ void VisitStmt_(const StoreNode* op, const Stmt& other) final {
+ const StoreNode* rhs = other.as<StoreNode>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareExpr(op->index, rhs->index) != 0) return;
if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
}
- void VisitStmt_(const Free* op, const Stmt& other) final {
- const Free* rhs = other.as<Free>();
+ void VisitStmt_(const FreeNode* op, const Stmt& other) final {
+ const FreeNode* rhs = other.as<FreeNode>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
}
- void VisitStmt_(const AssertStmt* op, const Stmt& other) final {
- const AssertStmt* rhs = other.as<AssertStmt>();
+ void VisitStmt_(const AssertStmtNode* op, const Stmt& other) final {
+ const AssertStmtNode* rhs = other.as<AssertStmtNode>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->message, rhs->message) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
- void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
- const ProducerConsumer* rhs = other.as<ProducerConsumer>();
+ void VisitStmt_(const ProducerConsumerNode* op, const Stmt& other) final {
+ const ProducerConsumerNode* rhs = other.as<ProducerConsumerNode>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->is_producer, rhs->is_producer) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
- void VisitStmt_(const Provide* op, const Stmt& other) final {
- const Provide* rhs = other.as<Provide>();
+ void VisitStmt_(const ProvideNode* op, const Stmt& other) final {
+ const ProvideNode* rhs = other.as<ProvideNode>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareArray(op->args, rhs->args) != 0) return;
}
- void VisitStmt_(const Realize* op, const Stmt& other) final {
- const Realize* rhs = other.as<Realize>();
+ void VisitStmt_(const RealizeNode* op, const Stmt& other) final {
+ const RealizeNode* rhs = other.as<RealizeNode>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareType(op->dtype, rhs->dtype) != 0) return;
if (CompareStmt(op->body, rhs->body) != 0) return;
}
- void VisitStmt_(const Prefetch* op, const Stmt& other) final {
- const Prefetch* rhs = other.as<Prefetch>();
+ void VisitStmt_(const PrefetchNode* op, const Stmt& other) final {
+ const PrefetchNode* rhs = other.as<PrefetchNode>();
if (CompareNodeRef(op->func, rhs->func) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareType(op->dtype, rhs->dtype) != 0) return;
}
}
- void VisitStmt_(const Evaluate* op, const Stmt& other) final {
- const Evaluate* rhs = other.as<Evaluate>();
+ void VisitStmt_(const EvaluateNode* op, const Stmt& other) final {
+ const EvaluateNode* rhs = other.as<EvaluateNode>();
CompareExpr(op->value, rhs->value);
}
// Exprs
- void VisitExpr_(const Variable* op, const Expr& other) final {
- const Variable* rhs = other.as<Variable>();
+ void VisitExpr_(const VarNode* op, const Expr& other) final {
+ const VarNode* rhs = other.as<VarNode>();
auto it = vmap_.find(op);
if (it != vmap_.end()) op = it->second;
if (op < rhs) {
order_ = +1;
}
}
- void VisitExpr_(const Load* op, const Expr& other) final {
- const Load* rhs = other.as<Load>();
+ void VisitExpr_(const LoadNode* op, const Expr& other) final {
+ const LoadNode* rhs = other.as<LoadNode>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
if (CompareExpr(op->index, rhs->index) != 0) return;
if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
}
- void VisitExpr_(const Let* op, const Expr& other) final {
- const Let* rhs = other.as<Let>();
+ void VisitExpr_(const LetNode* op, const Expr& other) final {
+ const LetNode* rhs = other.as<LetNode>();
if (tie_def_) {
vmap_[op->var.get()] = rhs->var.get();
} else {
if (CompareExpr(op->body, rhs->body) != 0) return;
}
- void VisitExpr_(const Call* op, const Expr& other) final {
- const Call* rhs = other.as<Call>();
+ void VisitExpr_(const CallNode* op, const Expr& other) final {
+ const CallNode* rhs = other.as<CallNode>();
if (CompareString(op->name, rhs->name)) return;
if (CompareArray(op->args, rhs->args)) return;
if (CompareValue(op->call_type, rhs->call_type) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
}
- void VisitExpr_(const Reduce *op, const Expr& other) final {
- const Reduce* rhs = other.as<Reduce>();
+ void VisitExpr_(const ReduceNode *op, const Expr& other) final {
+ const ReduceNode* rhs = other.as<ReduceNode>();
if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
if (CompareArray(op->source, rhs->source) != 0) return;
}
- void VisitExpr_(const IntImm *op, const Expr& other) final {
- CompareValue(op->value, other.as<IntImm>()->value);
+ void VisitExpr_(const IntImmNode *op, const Expr& other) final {
+ CompareValue(op->value, other.as<IntImmNode>()->value);
}
- void VisitExpr_(const UIntImm *op, const Expr& other) final {
- CompareValue(op->value, other.as<UIntImm>()->value);
+ void VisitExpr_(const UIntImmNode *op, const Expr& other) final {
+ CompareValue(op->value, other.as<UIntImmNode>()->value);
}
- void VisitExpr_(const FloatImm *op, const Expr& other) final {
- CompareValue(op->value, other.as<FloatImm>()->value);
+ void VisitExpr_(const FloatImmNode *op, const Expr& other) final {
+ CompareValue(op->value, other.as<FloatImmNode>()->value);
}
- void VisitExpr_(const StringImm *op, const Expr& other) final {
- CompareString(op->value, other.as<StringImm>()->value);
+ void VisitExpr_(const StringImmNode *op, const Expr& other) final {
+ CompareString(op->value, other.as<StringImmNode>()->value);
}
- void VisitExpr_(const Cast *op, const Expr& other) final {
- CompareExpr(op->value, other.as<Cast>()->value);
+ void VisitExpr_(const CastNode *op, const Expr& other) final {
+ CompareExpr(op->value, other.as<CastNode>()->value);
}
- void VisitExpr_(const Not *op, const Expr& other) final {
- CompareExpr(op->a, other.as<Not>()->a);
+ void VisitExpr_(const NotNode *op, const Expr& other) final {
+ CompareExpr(op->a, other.as<NotNode>()->a);
}
- void VisitExpr_(const Select *op, const Expr& other) final {
- const Select* rhs = other.as<Select>();
+ void VisitExpr_(const SelectNode *op, const Expr& other) final {
+ const SelectNode* rhs = other.as<SelectNode>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
}
- void VisitExpr_(const Ramp *op, const Expr& other) final {
- const Ramp* rhs = other.as<Ramp>();
+ void VisitExpr_(const RampNode *op, const Expr& other) final {
+ const RampNode* rhs = other.as<RampNode>();
if (CompareExpr(op->base, rhs->base) != 0) return;
if (CompareExpr(op->stride, rhs->stride) != 0) return;
if (CompareValue(op->lanes, rhs->lanes) != 0) return;
}
- void VisitExpr_(const Broadcast *op, const Expr& other) final {
- const Broadcast* rhs = other.as<Broadcast>();
+ void VisitExpr_(const BroadcastNode *op, const Expr& other) final {
+ const BroadcastNode* rhs = other.as<BroadcastNode>();
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareValue(op->lanes, rhs->lanes) != 0) return;
}
- void VisitExpr_(const Shuffle *op, const Expr& other) final {
- const Shuffle* rhs = other.as<Shuffle>();
+ void VisitExpr_(const ShuffleNode *op, const Expr& other) final {
+ const ShuffleNode* rhs = other.as<ShuffleNode>();
if (CompareArray(op->vectors, rhs->vectors) != 0) return;
if (CompareArray(op->indices, rhs->indices) != 0) return;
}
- DEFINE_BIOP_EXPR_CMP_(Add)
- DEFINE_BIOP_EXPR_CMP_(Sub)
- DEFINE_BIOP_EXPR_CMP_(Mul)
- DEFINE_BIOP_EXPR_CMP_(Div)
- DEFINE_BIOP_EXPR_CMP_(Mod)
- DEFINE_BIOP_EXPR_CMP_(FloorDiv)
- DEFINE_BIOP_EXPR_CMP_(FloorMod)
- DEFINE_BIOP_EXPR_CMP_(Min)
- DEFINE_BIOP_EXPR_CMP_(Max)
- DEFINE_BIOP_EXPR_CMP_(EQ)
- DEFINE_BIOP_EXPR_CMP_(NE)
- DEFINE_BIOP_EXPR_CMP_(LT)
- DEFINE_BIOP_EXPR_CMP_(LE)
- DEFINE_BIOP_EXPR_CMP_(GT)
- DEFINE_BIOP_EXPR_CMP_(GE)
- DEFINE_BIOP_EXPR_CMP_(And)
- DEFINE_BIOP_EXPR_CMP_(Or)
+ DEFINE_BIOP_EXPR_CMP_(AddNode)
+ DEFINE_BIOP_EXPR_CMP_(SubNode)
+ DEFINE_BIOP_EXPR_CMP_(MulNode)
+ DEFINE_BIOP_EXPR_CMP_(DivNode)
+ DEFINE_BIOP_EXPR_CMP_(ModNode)
+ DEFINE_BIOP_EXPR_CMP_(FloorDivNode)
+ DEFINE_BIOP_EXPR_CMP_(FloorModNode)
+ DEFINE_BIOP_EXPR_CMP_(MinNode)
+ DEFINE_BIOP_EXPR_CMP_(MaxNode)
+ DEFINE_BIOP_EXPR_CMP_(EQNode)
+ DEFINE_BIOP_EXPR_CMP_(NENode)
+ DEFINE_BIOP_EXPR_CMP_(LTNode)
+ DEFINE_BIOP_EXPR_CMP_(LENode)
+ DEFINE_BIOP_EXPR_CMP_(GTNode)
+ DEFINE_BIOP_EXPR_CMP_(GENode)
+ DEFINE_BIOP_EXPR_CMP_(AndNode)
+ DEFINE_BIOP_EXPR_CMP_(OrNode)
private:
int CompareExpr(const Expr& lhs, const Expr& rhs) {
// Only equality/non-equality information is valid.
bool tie_def_{false};
// variable remap, if any
- std::unordered_map<const Variable*, const Variable*> vmap_;
+ std::unordered_map<const VarNode*, const VarNode*> vmap_;
};
const Array<Expr>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
for (Expr s : only_enable) {
- only_type_index.insert(Object::TypeKey2Index(s.as<StringImm>()->value.c_str()));
+ only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
}
}
-void StmtVisitor::VisitStmt_(const LetStmt* op) {
+void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
this->VisitExpr(op->value);
this->VisitStmt(op->body);
}
-void StmtVisitor::VisitStmt_(const AttrStmt* op) {
+void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
this->VisitExpr(op->value);
this->VisitStmt(op->body);
}
-void StmtVisitor::VisitStmt_(const For* op) {
+void StmtVisitor::VisitStmt_(const ForNode* op) {
this->VisitExpr(op->min);
this->VisitExpr(op->extent);
this->VisitStmt(op->body);
}
-void StmtVisitor::VisitStmt_(const Allocate* op) {
+void StmtVisitor::VisitStmt_(const AllocateNode* op) {
VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); });
this->VisitStmt(op->body);
this->VisitExpr(op->condition);
}
}
-void StmtVisitor::VisitStmt_(const Store* op) {
+void StmtVisitor::VisitStmt_(const StoreNode* op) {
this->VisitExpr(op->value);
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
}
-void StmtVisitor::VisitStmt_(const IfThenElse* op) {
+void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
if (op->else_case.defined()) {
}
}
-void StmtVisitor::VisitStmt_(const Free* op) {}
+void StmtVisitor::VisitStmt_(const FreeNode* op) {}
-void StmtVisitor::VisitStmt_(const AssertStmt* op) {
+void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->message);
this->VisitStmt(op->body);
}
-void StmtVisitor::VisitStmt_(const ProducerConsumer* op) {
+void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) {
this->VisitStmt(op->body);
}
-void StmtVisitor::VisitStmt_(const Provide* op) {
+void StmtVisitor::VisitStmt_(const ProvideNode* op) {
VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
this->VisitExpr(op->value);
}
-void StmtVisitor::VisitStmt_(const Realize* op) {
+void StmtVisitor::VisitStmt_(const RealizeNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
this->VisitExpr(r->extent);
this->VisitExpr(op->condition);
}
-void StmtVisitor::VisitStmt_(const Prefetch* op) {
+void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
VisitArray(op->bounds, [this](const Range& r) {
this->VisitExpr(r->min);
this->VisitExpr(r->extent);
});
}
-void StmtVisitor::VisitStmt_(const Evaluate* op) {
+void StmtVisitor::VisitStmt_(const EvaluateNode* op) {
this->VisitExpr(op->value);
}
-void ExprVisitor::VisitExpr_(const Variable* op) {}
+void ExprVisitor::VisitExpr_(const VarNode* op) {}
-void ExprVisitor::VisitExpr_(const Load* op) {
+void ExprVisitor::VisitExpr_(const LoadNode* op) {
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
}
-void ExprVisitor::VisitExpr_(const Let* op) {
+void ExprVisitor::VisitExpr_(const LetNode* op) {
this->VisitExpr(op->value);
this->VisitExpr(op->body);
}
-void ExprVisitor::VisitExpr_(const Call* op) {
+void ExprVisitor::VisitExpr_(const CallNode* op) {
VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
}
this->VisitExpr(op->b); \
}
-DEFINE_BINOP_VISIT_(Add);
-DEFINE_BINOP_VISIT_(Sub);
-DEFINE_BINOP_VISIT_(Mul);
-DEFINE_BINOP_VISIT_(Div);
-DEFINE_BINOP_VISIT_(Mod);
-DEFINE_BINOP_VISIT_(FloorDiv);
-DEFINE_BINOP_VISIT_(FloorMod);
-DEFINE_BINOP_VISIT_(Min);
-DEFINE_BINOP_VISIT_(Max);
-DEFINE_BINOP_VISIT_(EQ);
-DEFINE_BINOP_VISIT_(NE);
-DEFINE_BINOP_VISIT_(LT);
-DEFINE_BINOP_VISIT_(LE);
-DEFINE_BINOP_VISIT_(GT);
-DEFINE_BINOP_VISIT_(GE);
-DEFINE_BINOP_VISIT_(And);
-DEFINE_BINOP_VISIT_(Or);
-
-void ExprVisitor::VisitExpr_(const IntImm* op) {}
-void ExprVisitor::VisitExpr_(const UIntImm* op) {}
-void ExprVisitor::VisitExpr_(const FloatImm* op) {}
-void ExprVisitor::VisitExpr_(const StringImm* op) {}
-
-void ExprVisitor::VisitExpr_(const Reduce* op) {
+DEFINE_BINOP_VISIT_(AddNode);
+DEFINE_BINOP_VISIT_(SubNode);
+DEFINE_BINOP_VISIT_(MulNode);
+DEFINE_BINOP_VISIT_(DivNode);
+DEFINE_BINOP_VISIT_(ModNode);
+DEFINE_BINOP_VISIT_(FloorDivNode);
+DEFINE_BINOP_VISIT_(FloorModNode);
+DEFINE_BINOP_VISIT_(MinNode);
+DEFINE_BINOP_VISIT_(MaxNode);
+DEFINE_BINOP_VISIT_(EQNode);
+DEFINE_BINOP_VISIT_(NENode);
+DEFINE_BINOP_VISIT_(LTNode);
+DEFINE_BINOP_VISIT_(LENode);
+DEFINE_BINOP_VISIT_(GTNode);
+DEFINE_BINOP_VISIT_(GENode);
+DEFINE_BINOP_VISIT_(AndNode);
+DEFINE_BINOP_VISIT_(OrNode);
+
+void ExprVisitor::VisitExpr_(const IntImmNode* op) {}
+void ExprVisitor::VisitExpr_(const UIntImmNode* op) {}
+void ExprVisitor::VisitExpr_(const FloatImmNode* op) {}
+void ExprVisitor::VisitExpr_(const StringImmNode* op) {}
+
+void ExprVisitor::VisitExpr_(const ReduceNode* op) {
VisitArray(op->axis, [this](const IterVar& r) {
this->VisitExpr(r->dom->min);
this->VisitExpr(r->dom->extent);
this->VisitExpr(op->condition);
}
-void ExprVisitor::VisitExpr_(const Cast* op) {
+void ExprVisitor::VisitExpr_(const CastNode* op) {
this->VisitExpr(op->value);
}
-void ExprVisitor::VisitExpr_(const Not* op) {
+void ExprVisitor::VisitExpr_(const NotNode* op) {
this->VisitExpr(op->a);
}
-void ExprVisitor::VisitExpr_(const Select* op) {
+void ExprVisitor::VisitExpr_(const SelectNode* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->true_value);
this->VisitExpr(op->false_value);
}
-void ExprVisitor::VisitExpr_(const Ramp* op) {
+void ExprVisitor::VisitExpr_(const RampNode* op) {
this->VisitExpr(op->base);
this->VisitExpr(op->stride);
}
-void ExprVisitor::VisitExpr_(const Shuffle* op) {
+void ExprVisitor::VisitExpr_(const ShuffleNode* op) {
VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); });
VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); });
}
-void ExprVisitor::VisitExpr_(const Broadcast* op) {
+void ExprVisitor::VisitExpr_(const BroadcastNode* op) {
this->VisitExpr(op->value);
}
}
};
-Stmt StmtMutator::VisitStmt_(const AttrStmt* op) {
+Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
Expr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
}
}
-Stmt StmtMutator::VisitStmt_(const LetStmt* op) {
+Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
Expr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
}
}
-Stmt StmtMutator::VisitStmt_(const For* op) {
+Stmt StmtMutator::VisitStmt_(const ForNode* op) {
Expr min = this->VisitExpr(op->min);
Expr extent = this->VisitExpr(op->extent);
Stmt body = this->VisitStmt(op->body);
}
}
-Stmt StmtMutator::VisitStmt_(const Allocate* op) {
+Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
Array<Expr> extents = Internal::Mutate(this, op->extents);
Stmt body = this->VisitStmt(op->body);
Expr condition = this->VisitExpr(op->condition);
}
}
-Stmt StmtMutator::VisitStmt_(const IfThenElse* op) {
+Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
Expr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
}
}
-Stmt StmtMutator::VisitStmt_(const Store* op) {
+Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
Expr value = this->VisitExpr(op->value);
Expr index = this->VisitExpr(op->index);
Expr predicate = this->VisitExpr(op->predicate);
}
}
-Stmt StmtMutator::VisitStmt_(const Provide* op) {
+Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
Array<Expr> args = Internal::Mutate(this, op->args);
Expr value = this->VisitExpr(op->value);
if (args.same_as(op->args) &&
}
}
-Stmt StmtMutator::VisitStmt_(const Realize* op) {
+Stmt StmtMutator::VisitStmt_(const RealizeNode* op) {
Region bounds = Internal::Mutate(this, op->bounds);
Stmt body = this->VisitStmt(op->body);
Expr condition = this->VisitExpr(op->condition);
}
}
-Stmt StmtMutator::VisitStmt_(const Prefetch* op) {
+Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) {
Region bounds = Internal::Mutate(this, op->bounds);
if (bounds.same_as(op->bounds)) {
return GetRef<Stmt>(op);
}
}
-Stmt StmtMutator::VisitStmt_(const AssertStmt* op) {
+Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
Expr condition = this->VisitExpr(op->condition);
Expr message = this->VisitExpr(op->message);
Stmt body = this->VisitStmt(op->body);
}
}
-Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) {
+Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) {
Stmt body = this->VisitStmt(op->body);
if (body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
}
-Stmt StmtMutator::VisitStmt_(const Evaluate* op) {
+Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Stmt>(op);
}
}
-Stmt StmtMutator::VisitStmt_(const Free* op) {
+Stmt StmtMutator::VisitStmt_(const FreeNode* op) {
return GetRef<Stmt>(op);
}
-Expr ExprMutator::VisitExpr_(const Variable* op) {
+Expr ExprMutator::VisitExpr_(const VarNode* op) {
return GetRef<Expr>(op);
}
-Expr ExprMutator::VisitExpr_(const Load* op) {
+Expr ExprMutator::VisitExpr_(const LoadNode* op) {
Expr index = this->VisitExpr(op->index);
Expr predicate = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
return GetRef<Expr>(op);
} else {
- return Load::make(op->dtype, op->buffer_var, index, predicate);
+ return LoadNode::make(op->dtype, op->buffer_var, index, predicate);
}
}
-Expr ExprMutator::VisitExpr_(const Let* op) {
+Expr ExprMutator::VisitExpr_(const LetNode* op) {
Expr value = this->VisitExpr(op->value);
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return Let::make(op->var, value, body);
+ return LetNode::make(op->var, value, body);
}
}
-Expr ExprMutator::VisitExpr_(const Call* op) {
+Expr ExprMutator::VisitExpr_(const CallNode* op) {
auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); };
Array<Expr> args = MutateArray(op->args, fmutate);
if (args.same_as(op->args)) {
return GetRef<Expr>(op);
} else {
- return Call::make(op->dtype,
+ return CallNode::make(op->dtype,
op->name,
args,
op->call_type,
return GetRef<Expr>(op); \
}
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
Expr ExprMutator::VisitExpr_(const OP* op) { \
} \
}
-DEFINE_BIOP_EXPR_MUTATE_(Add);
-DEFINE_BIOP_EXPR_MUTATE_(Sub);
-DEFINE_BIOP_EXPR_MUTATE_(Mul);
-DEFINE_BIOP_EXPR_MUTATE_(Div);
-DEFINE_BIOP_EXPR_MUTATE_(Mod);
-DEFINE_BIOP_EXPR_MUTATE_(FloorDiv);
-DEFINE_BIOP_EXPR_MUTATE_(FloorMod);
-DEFINE_BIOP_EXPR_MUTATE_(Min);
-DEFINE_BIOP_EXPR_MUTATE_(Max);
-DEFINE_BIOP_EXPR_MUTATE_(EQ);
-DEFINE_BIOP_EXPR_MUTATE_(NE);
-DEFINE_BIOP_EXPR_MUTATE_(LT);
-DEFINE_BIOP_EXPR_MUTATE_(LE);
-DEFINE_BIOP_EXPR_MUTATE_(GT);
-DEFINE_BIOP_EXPR_MUTATE_(GE);
-DEFINE_BIOP_EXPR_MUTATE_(And);
-DEFINE_BIOP_EXPR_MUTATE_(Or);
-
-Expr ExprMutator::VisitExpr_(const Reduce* op) {
+DEFINE_BIOP_EXPR_MUTATE_(AddNode);
+DEFINE_BIOP_EXPR_MUTATE_(SubNode);
+DEFINE_BIOP_EXPR_MUTATE_(MulNode);
+DEFINE_BIOP_EXPR_MUTATE_(DivNode);
+DEFINE_BIOP_EXPR_MUTATE_(ModNode);
+DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode);
+DEFINE_BIOP_EXPR_MUTATE_(FloorModNode);
+DEFINE_BIOP_EXPR_MUTATE_(MinNode);
+DEFINE_BIOP_EXPR_MUTATE_(MaxNode);
+DEFINE_BIOP_EXPR_MUTATE_(EQNode);
+DEFINE_BIOP_EXPR_MUTATE_(NENode);
+DEFINE_BIOP_EXPR_MUTATE_(LTNode);
+DEFINE_BIOP_EXPR_MUTATE_(LENode);
+DEFINE_BIOP_EXPR_MUTATE_(GTNode);
+DEFINE_BIOP_EXPR_MUTATE_(GENode);
+DEFINE_BIOP_EXPR_MUTATE_(AndNode);
+DEFINE_BIOP_EXPR_MUTATE_(OrNode);
+
+Expr ExprMutator::VisitExpr_(const ReduceNode* op) {
auto fitervar = [this](const IterVar& v) {
Range r = v->dom;
Expr min = this->VisitExpr(r->min);
condition.same_as(op->condition)) {
return GetRef<Expr>(op);
} else {
- return Reduce::make(
+ return ReduceNode::make(
op->combiner, source, axis, condition, op->value_index);
}
}
-Expr ExprMutator::VisitExpr_(const Cast* op) {
+Expr ExprMutator::VisitExpr_(const CastNode* op) {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
- return Cast::make(op->dtype, value);
+ return CastNode::make(op->dtype, value);
}
}
-Expr ExprMutator::VisitExpr_(const Not* op) {
+Expr ExprMutator::VisitExpr_(const NotNode* op) {
Expr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<Expr>(op);
} else {
- return Not::make(a);
+ return NotNode::make(a);
}
}
-Expr ExprMutator::VisitExpr_(const Select* op) {
+Expr ExprMutator::VisitExpr_(const SelectNode* op) {
Expr condition = this->VisitExpr(op->condition);
Expr true_value = this->VisitExpr(op->true_value);
Expr false_value = this->VisitExpr(op->false_value);
false_value.same_as(op->false_value)) {
return GetRef<Expr>(op);
} else {
- return Select::make(condition, true_value, false_value);
+ return SelectNode::make(condition, true_value, false_value);
}
}
-Expr ExprMutator::VisitExpr_(const Ramp* op) {
+Expr ExprMutator::VisitExpr_(const RampNode* op) {
Expr base = this->VisitExpr(op->base);
Expr stride = this->VisitExpr(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
return GetRef<Expr>(op);
} else {
- return Ramp::make(base, stride, op->lanes);
+ return RampNode::make(base, stride, op->lanes);
}
}
-Expr ExprMutator::VisitExpr_(const Broadcast* op) {
+Expr ExprMutator::VisitExpr_(const BroadcastNode* op) {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
- return Broadcast::make(value, op->lanes);
+ return BroadcastNode::make(value, op->lanes);
}
}
-Expr ExprMutator::VisitExpr_(const Shuffle* op) {
+Expr ExprMutator::VisitExpr_(const ShuffleNode* op) {
auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
auto vectors = MutateArray(op->vectors, fexpr);
if (vectors.same_as(op->vectors)) {
return GetRef<Expr>(op);
} else {
- return Shuffle::make(vectors, op->indices);
+ return ShuffleNode::make(vectors, op->indices);
}
}
// use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri;
- if (const auto* for_ = s.as<For>()) {
- auto n = make_object<For>(*for_);
+ if (const auto* for_ = s.as<ForNode>()) {
+ auto n = make_object<ForNode>(*for_);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
- } else if (const auto* let = s.as<LetStmt>()) {
- auto n = make_object<LetStmt>(*let);
+ } else if (const auto* let = s.as<LetStmtNode>()) {
+ auto n = make_object<LetStmtNode>(*let);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
- } else if (const auto* attr = s.as<AttrStmt>()) {
- auto n = make_object<AttrStmt>(*attr);
+ } else if (const auto* attr = s.as<AttrStmtNode>()) {
+ auto n = make_object<AttrStmtNode>(*attr);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
- } else if (const auto* ite = s.as<IfThenElse>()) {
- auto n = make_object<IfThenElse>(*ite);
+ } else if (const auto* ite = s.as<IfThenElseNode>()) {
+ auto n = make_object<IfThenElseNode>(*ite);
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
n->seq.Set(n->size() - 1, body);
body = Stmt(n);
- } else if (const auto* assert_ = s.as<AssertStmt>()) {
- auto n = make_object<AssertStmt>(*assert_);
+ } else if (const auto* assert_ = s.as<AssertStmtNode>()) {
+ auto n = make_object<AssertStmtNode>(*assert_);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
- } else if (const auto* alloc = s.as<Allocate>()) {
- auto n = make_object<Allocate>(*alloc);
+ } else if (const auto* alloc = s.as<AllocateNode>()) {
+ auto n = make_object<AllocateNode>(*alloc);
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
handle,
make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind))};
- return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
+ return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic);
}
/*!
* \param offset the offset index.
*/
inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
- return Call::make(
+ return CallNode::make(
DataType::Handle(), intrinsic::tvm_address_of,
- {Load::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
+ {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
const_true(dtype.lanes()))},
- Call::PureIntrinsic);
+ CallNode::PureIntrinsic);
}
/*!
inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) {
if (dtype.lanes() != 1) {
offset = offset * make_const(offset.dtype(), dtype.lanes());
- offset = Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
+ offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
}
- return Call::make(
+ return CallNode::make(
DataType::Handle(), intrinsic::tvm_address_of,
- {Load::make(dtype, handle, offset,
+ {LoadNode::make(dtype, handle, offset,
const_true(dtype.lanes()))},
- Call::PureIntrinsic);
+ CallNode::PureIntrinsic);
}
/*!
make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind)),
value};
- return Evaluate::make(
- Call::make(DataType::Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
+ return EvaluateNode::make(
+ CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
}
/*!
* \return true if pattern match success and store the base to base.
*/
inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
- const Ramp* r = index.as<Ramp>();
+ const RampNode* r = index.as<RampNode>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
Stmt Lift(Stmt stmt) {
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
- stmt = AttrStmt::make(
+ stmt = AttrStmtNode::make(
attr_node_, attr_key_, attr_value_, stmt);
}
return stmt;
}
// do not go beyond
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
if (attr_node_.defined()) {
- Stmt body = AttrStmt::make(
+ Stmt body = AttrStmtNode::make(
attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = Expr();
- return Allocate::make(
+ return AllocateNode::make(
op->buffer_var, op->dtype,
op->extents, op->condition, body,
op->new_expr, op->free_function);
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr_key_) {
attr_node_ = op->node;
attr_value_ = op->value;
}
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
- stmt = AttrStmt::make(
+ stmt = AttrStmtNode::make(
attr_node[begin], attr_key_, attr_value[begin], stmt);
}
reorg.push_back(stmt);
return SeqStmt::Flatten(reorg);
}
- Stmt VisitStmt_(const IfThenElse* op) final {
+ Stmt VisitStmt_(const IfThenElseNode* op) final {
if (!op->else_case.defined()) {
return StmtMutator::VisitStmt_(op);
}
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElse::make(op->condition, then_case, else_case);
+ return IfThenElseNode::make(op->condition, then_case, else_case);
}
} else {
if (first_node.defined()) {
- then_case = AttrStmt::make(
+ then_case = AttrStmtNode::make(
first_node, attr_key_, first_value, then_case);
}
if (attr_node_.defined()) {
- else_case = AttrStmt::make(
+ else_case = AttrStmtNode::make(
attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = ObjectRef();
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElse::make(op->condition, then_case, else_case);
+ return IfThenElseNode::make(op->condition, then_case, else_case);
}
}
}
if (!a.defined() || !b.defined()) return false;
if (a->type_index() != b->type_index()) return false;
if (a.dtype() != b.dtype()) return false;
- if (const IntImm* op = a.as<IntImm>()) {
- return op->value == b.as<IntImm>()->value;
+ if (const IntImmNode* op = a.as<IntImmNode>()) {
+ return op->value == b.as<IntImmNode>()->value;
}
- if (const UIntImm* op = a.as<UIntImm>()) {
- return op->value == b.as<UIntImm>()->value;
+ if (const UIntImmNode* op = a.as<UIntImmNode>()) {
+ return op->value == b.as<UIntImmNode>()->value;
}
return false;
}
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
-bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
+bool ExprUseVars(Expr expr, const std::unordered_set<const VarNode*>& vars) {
bool success = false;
PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
- if (const Variable* v = node.as<Variable>()) {
+ if (const VarNode* v = node.as<VarNode>()) {
if (vars.count(v)) {
success = true;
return;
explicit CandidateSelector(bool split_const_loop)
: split_const_loop_(split_const_loop) {}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
// partition const loop when sets split_const_loop_
if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
- const Variable* var = op->loop_var.get();
+ const VarNode* var = op->loop_var.get();
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
}
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const IterVarNode *iv = op->node.as<IterVarNode>();
CHECK(iv);
}
}
- void VisitExpr_(const Call* op) final {
- if (op->is_intrinsic(Call::likely)) {
+ void VisitExpr_(const CallNode* op) final {
+ if (op->is_intrinsic(CallNode::likely)) {
in_likely_ = true;
StmtExprVisitor::VisitExpr_(op);
in_likely_ = false;
}
}
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
if (in_likely_ && record_.count(op)) {
record_.at(op) = true;
}
bool in_likely_{false};
bool no_split_{false};
bool split_const_loop_{false};
- std::unordered_map<const Variable*, VarIsUsed> record_;
+ std::unordered_map<const VarNode*, VarIsUsed> record_;
};
// Populate partitions data structure, i.e., for a specific variable,
class PartitionFinder : public StmtExprVisitor {
public:
explicit PartitionFinder(VarExpr current_var,
- const std::unordered_map<const Variable*, IntSet>& hint_map,
- const std::unordered_map<const Variable*, IntSet>& relax_map)
+ const std::unordered_map<const VarNode*, IntSet>& hint_map,
+ const std::unordered_map<const VarNode*, IntSet>& relax_map)
: current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
for (const auto& kv : hint_map) {
out_vars_.insert(kv.first);
}
}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
- const Variable* var = op->loop_var.get();
+ const VarNode* var = op->loop_var.get();
hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
StmtExprVisitor::VisitStmt_(op);
hint_map_.erase(var);
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
// handle thread_axis
if (op->attr_key == attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
- const Variable* var = thread_axis->var.get();
+ const VarNode* var = thread_axis->var.get();
IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
hint_map_.insert({var, dom});
relax_map_.insert({var, dom});
}
}
- void VisitExpr_(const Call* op) final {
- if (op->is_intrinsic(Call::likely)) {
+ void VisitExpr_(const CallNode* op) final {
+ if (op->is_intrinsic(CallNode::likely)) {
Expr cond = op->args[0];
if (ExprUseVars(cond,
- std::unordered_set<const Variable*>({current_var_.get()}))) {
+ std::unordered_set<const VarNode*>({current_var_.get()}))) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// true. Also find the interval, if exists, in which we can prove that cond is
// false.
private:
Expr InverseCond(const Expr& cond) {
Expr inverse_cond;
- if (const LT* op = cond.as<LT>()) {
+ if (const LTNode* op = cond.as<LTNode>()) {
// a < b -> a >= b
- inverse_cond = GE::make(op->a, op->b);
- } else if (const GT* op = cond.as<GT>()) {
+ inverse_cond = GENode::make(op->a, op->b);
+ } else if (const GTNode* op = cond.as<GTNode>()) {
// a > b -> a <= b
- inverse_cond = LE::make(op->a, op->b);
- } else if (const LE* op = cond.as<LE>()) {
+ inverse_cond = LENode::make(op->a, op->b);
+ } else if (const LENode* op = cond.as<LENode>()) {
// a <= b -> a > b
- inverse_cond = GT::make(op->a, op->b);
- } else if (const GE* op = cond.as<GE>()) {
+ inverse_cond = GTNode::make(op->a, op->b);
+ } else if (const GENode* op = cond.as<GENode>()) {
// a >= b -> a < b
- inverse_cond = LT::make(op->a, op->b);
- } else if (const EQ* op = cond.as<EQ>()) {
+ inverse_cond = LTNode::make(op->a, op->b);
+ } else if (const EQNode* op = cond.as<EQNode>()) {
// a == b -> a != b
- inverse_cond = NE::make(op->a, op->b);
+    inverse_cond = NENode::make(op->a, op->b);
-    // a != b -> a == b
-  } else if (const NE* op = cond.as<NE>()) {
-    inverse_cond = EQ::make(op->a, op->b);
+  } else if (const NENode* op = cond.as<NENode>()) {
+    // a != b -> a == b
+    inverse_cond = EQNode::make(op->a, op->b);
}
return inverse_cond;
}
VarExpr current_var_;
- std::unordered_set<const Variable*> out_vars_;
- std::unordered_map<const Variable*, IntSet> hint_map_;
- std::unordered_map<const Variable*, IntSet> relax_map_;
+ std::unordered_set<const VarNode*> out_vars_;
+ std::unordered_map<const VarNode*, IntSet> hint_map_;
+ std::unordered_map<const VarNode*, IntSet> relax_map_;
};
// Replace the set of conditions given by ps with cond_value (true or false)
explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps,
Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
innermost_thread_scope_ = true;
Stmt stmt = StmtMutator::VisitStmt_(op);
// add branch code inside the innermost thread scope
if (innermost_thread_scope_) {
Stmt simplified_body = ConditionEliminator(ps_)(op->body);
- Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
+ Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body);
Expr value = this->VisitExpr(op->value);
- stmt = AttrStmt::make(op->node, op->attr_key, value, body);
+ stmt = AttrStmtNode::make(op->node, op->attr_key, value, body);
}
innermost_thread_scope_ = false;
return stmt;
return operator()(std::move(stmt));
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var,
op->min, op->min + op->extent - 1, op->body, false);
return res;
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key != attr::thread_extent) {
return StmtMutator::VisitStmt_(op);
}
inline Stmt MakeFor(const Object* op, Expr extent, Stmt body);
/* Candidate IRs that may be partitioned potentially */
- std::unordered_map<const Variable*, IntSet> hint_map_;
- std::unordered_map<const Variable*, IntSet> relax_map_;
+ std::unordered_map<const VarNode*, IntSet> hint_map_;
+ std::unordered_map<const VarNode*, IntSet> relax_map_;
arith::Analyzer analyzer_;
CandidateSelector selector;
};
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
- body_begin = Max::make(body_begin, min);
+ body_begin = MaxNode::make(body_begin, min);
// stop recursing on this interval if we can't prove it has non-negative length
pre_stmt_recurse = false;
}
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
- post_doubt_begin = Min::make(post_doubt_begin, max+1);
+ post_doubt_begin = MinNode::make(post_doubt_begin, max+1);
// stop recursing on this interval if we can't prove it has non-negative length
post_stmt_recurse = false;
}
}
inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) {
- const For *for_node = static_cast<const For*>(node);
+ const ForNode *for_node = static_cast<const ForNode*>(node);
CHECK(for_node);
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
- return For::make(for_node->loop_var, 0, extent,
+ return ForNode::make(for_node->loop_var, 0, extent,
for_node->for_type, for_node->device_api, body);
}
}
class RemoveLikelyTags : public StmtExprMutator {
public:
- Expr VisitExpr_(const Call *op) final {
- if (op->is_intrinsic(Call::likely)) {
+ Expr VisitExpr_(const CallNode *op) final {
+ if (op->is_intrinsic(CallNode::likely)) {
CHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
} else {
public:
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
- inline Expr VisitExpr_(const Cast* op) final {
+ inline Expr VisitExpr_(const CastNode* op) final {
auto type_code = op->dtype.code();
auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower.
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Cast>();
+ op = expr.as<CastNode>();
if (toBeLowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
CHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
return expr;
}
- inline Expr VisitExpr_(const FloatImm* imm) final {
+ inline Expr VisitExpr_(const FloatImmNode* imm) final {
auto type_code = imm->dtype.code();
auto e = GetRef<Expr>(imm);
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
return e;
}
- inline Stmt VisitStmt_(const Allocate* allocate) final {
+ inline Stmt VisitStmt_(const AllocateNode* allocate) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
- allocate = stmt.as<Allocate>();
+ allocate = stmt.as<AllocateNode>();
if (toBeLowered) {
auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
- return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents,
+ return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
allocate->condition, allocate->body, allocate->new_expr,
allocate->free_function);
}
return stmt;
}
- inline Expr VisitExpr_(const Load* load) final {
+ inline Expr VisitExpr_(const LoadNode* load) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
Expr expr = StmtExprMutator::VisitExpr_(load);
- load = expr.as<Load>();
+ load = expr.as<LoadNode>();
if (toBeLowered) {
auto new_load_type = DataType::UInt(load->dtype.bits());
- return Load::make(new_load_type, load->buffer_var, load->index, load->predicate);
+ return LoadNode::make(new_load_type, load->buffer_var, load->index, load->predicate);
}
return expr;
}
-#define DEFINE_MUTATE__(OP) \
- inline Expr VisitExpr_(const OP* op) final { \
+#define DEFINE_MUTATE__(OP, NodeName) \
+ inline Expr VisitExpr_(const NodeName* op) final { \
auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
Expr expr = StmtExprMutator::VisitExpr_(op); \
- op = expr.as<OP>(); \
+ op = expr.as<NodeName>(); \
if (toBeLowered) { \
auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
CHECK(lower) << #OP " lowering function for target " << target_ << " type " \
return expr; \
}
- DEFINE_MUTATE__(Add)
- DEFINE_MUTATE__(Sub)
- DEFINE_MUTATE__(Mul)
- DEFINE_MUTATE__(Div)
- DEFINE_MUTATE__(Mod)
- DEFINE_MUTATE__(Min)
- DEFINE_MUTATE__(Max)
- DEFINE_MUTATE__(EQ)
- DEFINE_MUTATE__(NE)
- DEFINE_MUTATE__(LT)
- DEFINE_MUTATE__(LE)
- DEFINE_MUTATE__(GT)
- DEFINE_MUTATE__(GE)
+ DEFINE_MUTATE__(Add, AddNode);
+ DEFINE_MUTATE__(Sub, SubNode);
+ DEFINE_MUTATE__(Mul, MulNode);
+ DEFINE_MUTATE__(Div, DivNode);
+ DEFINE_MUTATE__(Mod, ModNode);
+ DEFINE_MUTATE__(Min, MinNode);
+ DEFINE_MUTATE__(Max, MaxNode);
+ DEFINE_MUTATE__(EQ, EQNode);
+ DEFINE_MUTATE__(NE, NENode);
+ DEFINE_MUTATE__(LT, LTNode);
+ DEFINE_MUTATE__(LE, LENode);
+ DEFINE_MUTATE__(GT, GTNode);
+ DEFINE_MUTATE__(GE, GENode);
// Later changes may need to add more mutate functions as we support workloads with more ops.
private:
}
}
- Expr VisitExpr_(const Call* op) final {
- if (op->call_type == Call::Intrinsic ||
- op->call_type == Call::PureIntrinsic) {
+ Expr VisitExpr_(const CallNode* op) final {
+ if (op->call_type == CallNode::Intrinsic ||
+ op->call_type == CallNode::PureIntrinsic) {
Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
if (r.defined()) return r;
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr VisitExpr_(const Add* op) final {
- if (const Mul* mb = op->b.as<Mul>()) {
+ Expr VisitExpr_(const AddNode* op) final {
+ if (const MulNode* mb = op->b.as<MulNode>()) {
return MakeFMA(mb->a, mb->b, op->a, op);
- } else if (const Mul* ma = op->a.as<Mul>()) {
+ } else if (const MulNode* ma = op->a.as<MulNode>()) {
return MakeFMA(ma->a, ma->b, op->b, op);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
- Expr VisitExpr_(const FloorDiv* op) final {
+ Expr VisitExpr_(const FloorDivNode* op) final {
auto e = GetRef<Expr>(op);
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<FloorDiv>();
+ op = ret.as<FloorDivNode>();
if (op == nullptr) return ret;
int shift;
const DataType& dtype = op->dtype;
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
} else {
- return ir::Select::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
+ return ir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
}
}
} else {
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
Expr rdiv = truncdiv(op->a, op->b);
Expr rmod = truncmod(op->a, op->b);
- return ir::Select::make(
+ return ir::SelectNode::make(
(op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rdiv, rdiv - make_const(dtype, 1));
}
}
- Expr VisitExpr_(const FloorMod* op) final {
+ Expr VisitExpr_(const FloorModNode* op) final {
Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
- op = ret.as<FloorMod>();
+ op = ret.as<FloorModNode>();
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
int shift;
// -> rmod >= 0 ? 0 : b
return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
} else {
- return ir::Select::make(rmod >= 0, rmod, rmod + op->b);
+ return ir::SelectNode::make(rmod >= 0, rmod, rmod + op->b);
}
}
} else {
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
// b < 0 && rmod > 0 -> rmod + b
- return ir::Select::make(
+ return ir::SelectNode::make(
(op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rmod, rmod + op->b);
}
}
- Expr VisitExpr_(const Max* op) final {
+ Expr VisitExpr_(const MaxNode* op) final {
using namespace arith;
PVar<Expr> x, y;
PVar<Integer> c;
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr VisitExpr_(const EQ* op) final {
+ Expr VisitExpr_(const EQNode* op) final {
using namespace arith;
PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr VisitExpr_(const NE* op) final {
+ Expr VisitExpr_(const NENode* op) final {
using namespace arith;
PVar<Expr> x, y;
auto e = GetRef<Expr>(op);
// For some targets, LLVM will generate more efficient FMA
// instruction with the latter. For example, vmla vs. vmlal
// on ARM.
- if (const Broadcast* bcast = e.as<Broadcast>()) {
- if (const Cast* cast = bcast->value.as<Cast>()) {
+ if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
+ if (const CastNode* cast = bcast->value.as<CastNode>()) {
auto should_swap = [&]() {
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
};
if (should_swap()) {
- Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
- return Cast::make(bcast->dtype, new_bcast);
+ Expr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
+ return CastNode::make(bcast->dtype, new_bcast);
}
}
}
}
Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
- const Add* op) {
+ const AddNode* op) {
// emit fma instruction: a * b + c
Expr lhs = SwapBroadcastCast(a);
Expr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
- Expr r = (*fma_)(Call::make(
- op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
+ Expr r = (*fma_)(CallNode::make(
+ op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
- Expr mul = this->VisitExpr(Mul::make(lhs, rhs));
- return Add::make(mul, this->VisitExpr(c));
+ Expr mul = this->VisitExpr(MulNode::make(lhs, rhs));
+ return AddNode::make(mul, this->VisitExpr(c));
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
explicit ThreadAllreduceBuilder(int warp_size)
: warp_size_(warp_size) {}
- Stmt VisitStmt_(const AttrStmt *op) final {
+ Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
return ret;
} else if (op->attr_key == attr::storage_scope) {
Stmt ret = StmtExprMutator::VisitStmt_(op);
- op = ret.as<AttrStmt>();
- const Variable* v = op->node.as<Variable>();
+ op = ret.as<AttrStmtNode>();
+ const VarNode* v = op->node.as<VarNode>();
if (alloc_remap_.count(v)) {
return op->body;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const Evaluate* op) final {
+ Stmt VisitStmt_(const EvaluateNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Evaluate>();
- const Call* call = op->value.as<Call>();
+ op = stmt.as<EvaluateNode>();
+ const CallNode* call = op->value.as<CallNode>();
if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
return MakeAllreduce(call);
} else {
return stmt;
}
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- const Allocate* repl = it->second.as<Allocate>();
+ const AllocateNode* repl = it->second.as<AllocateNode>();
// use volatile access to shared buffer.
- stmt = AttrStmt::make(
+ stmt = AttrStmtNode::make(
repl->buffer_var, attr::volatile_scope, 1, op->body);
- stmt = Allocate::make(
+ stmt = AllocateNode::make(
repl->buffer_var, repl->dtype,
repl->extents, repl->condition, stmt);
- stmt = AttrStmt::make(
+ stmt = AttrStmtNode::make(
repl->buffer_var, attr::storage_scope,
- StringImm::make("shared"), stmt);
+ StringImmNode::make("shared"), stmt);
return stmt;
} else {
return stmt;
}
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) {
CHECK(is_zero(op->index));
}
};
// make allreduce.
- Stmt MakeAllreduce(const Call* call) {
+ Stmt MakeAllreduce(const CallNode* call) {
CHECK(!reduce_combiner_.empty());
const CommReducerNode *combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
- const UIntImm *size_of_args = call->args[0].as<UIntImm>();
+ const UIntImmNode *size_of_args = call->args[0].as<UIntImmNode>();
CHECK(size_of_args) << call->args[0]->GetTypeKey();
CHECK_EQ(size, size_of_args->value);
Array<Expr> inits = combiner->identity_element;
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1+idx];
if (!is_one(cond)) {
- values[idx] = Select::make(cond, values[idx], inits[idx]);
+ values[idx] = SelectNode::make(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].dtype();
}
- std::vector<const Variable*> buffers(size);
+ std::vector<const VarNode*> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
- const Variable* buffer = call->args[2+size+idx].as<Variable>();
+ const VarNode* buffer = call->args[2+size+idx].as<VarNode>();
CHECK(buffer);
buffers[idx] = buffer;
}
- std::unordered_set<const Variable*> reduce_set;
+ std::unordered_set<const VarNode*> reduce_set;
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
- const Variable* v = call->args[i].as<Variable>();
+ const VarNode* v = call->args[i].as<VarNode>();
CHECK(v);
reduce_set.insert(v);
}
size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
- for (const AttrStmt* attr : thread_extents_) {
+ for (const AttrStmtNode* attr : thread_extents_) {
ThreadEntry e;
IterVar iv = Downcast<IterVar>(attr->node);
e.scope = runtime::ThreadScope::make(iv->thread_tag);
for (size_t i = 0; i < size; ++i) {
Expr pred = const_true(types[i].lanes());
Var buffer_var = Downcast<Var>(call->args[2+size+i]);
- stores[i] = Store::make(buffer_var, values[i], 0, pred);
+ stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
}
return SeqStmt::Flatten(stores);
}
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
Expr pred = const_true(types[idx].lanes());
- seq.emplace_back(Store::make(
+ seq.emplace_back(StoreNode::make(
shared_bufs[idx], values[idx],
BufIndex(reduce_index, group_index, reduce_extent), pred));
}
for (size_t idx = 0; idx < size; ++idx) {
CHECK(!load_remap_.count(buffers[idx]));
Expr pred = const_true(types[idx].lanes());
- load_remap_[buffers[idx]] = Load::make(
+ load_remap_[buffers[idx]] = LoadNode::make(
types[idx], shared_bufs[idx],
BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
- alloc_remap_[buffers[idx]] = Allocate::make(
+ alloc_remap_[buffers[idx]] = AllocateNode::make(
shared_bufs[idx], types[idx],
{Expr(group_extent), Expr(reduce_extent)},
- pred, Evaluate::make(0));
+ pred, EvaluateNode::make(0));
}
return SeqStmt::Flatten(seq);
}
auto freduce = [&](int offset) {
Array<Expr> a, b;
for (size_t i = 0; i < size; ++i) {
- b.push_back(Load::make(types[i], shared_bufs[i],
+ b.push_back(LoadNode::make(types[i], shared_bufs[i],
BufIndex(reduce_index + offset, group_index, reduce_extent),
const_true()));
- a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
+ a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true()));
}
Array<Expr> ret = (*combiner)(a, b);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
- stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
+ stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true());
}
return SeqStmt::Flatten(stores);
};
// reduction with the boundary condition
reduce_align = reduce_align >> 1;
Expr cond = reduce_index < (reduce_extent - reduce_align);
- seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+ seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
CHECK(threadx_extent >= 1 && warp_size_ >= 1);
reduce_align > warp_size_) {
reduce_align = reduce_align >> 1;
Expr cond = reduce_index < reduce_align;
- seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+ seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
// in warp synchronization.
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
- seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
+ seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body));
seq.emplace_back(SyncThread("shared"));
}
return SeqStmt::Flatten(seq);
}
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
- return Evaluate::make(
- Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImm::make(sync)},
- Call::Intrinsic));
+ return EvaluateNode::make(
+ CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImmNode::make(sync)},
+ CallNode::Intrinsic));
}
// The local buffer index.
static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
int warp_size_{1};
// surrounding scope of thread extent.
- std::vector<const AttrStmt*> thread_extents_;
+ std::vector<const AttrStmtNode*> thread_extents_;
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
- std::unordered_map<const Variable *, Expr> load_remap_;
+ std::unordered_map<const VarNode *, Expr> load_remap_;
// Allocate remap
- std::unordered_map<const Variable *, Stmt> alloc_remap_;
+ std::unordered_map<const VarNode *, Stmt> alloc_remap_;
};
LoweredFunc
}
inline Expr StackAlloca(std::string type, size_t num) {
- Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
- return Call::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
+ Array<Expr> args = {StringImmNode::make(type), ConstInt32(num)};
+ return CallNode::make(
+ DataType::Handle(),
+ intrinsic::tvm_stack_alloca,
+ args, CallNode::Intrinsic);
}
// Calculate the statistics of packed function.
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
- stmt = LetStmt::make(
+ stmt = LetStmtNode::make(
stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
}
if (max_array_stack_ != 0) {
- stmt = LetStmt::make(
+ stmt = LetStmtNode::make(
stack_array_, StackAlloca("array", max_array_stack_), stmt);
}
if (max_arg_stack_ != 0) {
- stmt = LetStmt::make(
+ stmt = LetStmtNode::make(
stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
- stmt = LetStmt::make(
+ stmt = LetStmtNode::make(
stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
}
return stmt;
}
}
- Stmt VisitStmt_(const Allocate* op) {
+ Stmt VisitStmt_(const AllocateNode* op) {
// Lower allocate to device allocate when needed.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
if (op->new_expr.defined()) return stmt;
// Get constant allocation bound.
int64_t dev_type;
}
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- Stmt throw_last_error = Evaluate::make(Call::make(DataType::Int(32),
- intrinsic::tvm_throw_last_error, {},
- Call::Intrinsic));
+ Stmt throw_last_error = EvaluateNode::make(
+ CallNode::make(DataType::Int(32),
+ intrinsic::tvm_throw_last_error, {},
+ CallNode::Intrinsic));
Stmt body = SeqStmt({
- IfThenElse::make(Call::make(DataType::Bool(1),
- intrinsic::tvm_handle_is_null,
- {op->buffer_var}, Call::PureIntrinsic),
- throw_last_error),
+ IfThenElseNode::make(
+ CallNode::make(DataType::Bool(1),
+ intrinsic::tvm_handle_is_null,
+ {op->buffer_var}, CallNode::PureIntrinsic),
+ throw_last_error),
op->body});
- Stmt alloca = LetStmt::make(
+ Stmt alloca = LetStmtNode::make(
op->buffer_var,
- Call::make(op->buffer_var.dtype(),
- "TVMBackendAllocWorkspace",
- {cast(DataType::Int(32), device_type_),
- cast(DataType::Int(32), device_id_),
- cast(DataType::UInt(64), total_bytes),
- IntImm::make(DataType::Int(32), op->dtype.code()),
- IntImm::make(DataType::Int(32), op->dtype.bits())},
- Call::Extern),
+ CallNode::make(op->buffer_var.dtype(),
+ "TVMBackendAllocWorkspace",
+ {cast(DataType::Int(32), device_type_),
+ cast(DataType::Int(32), device_id_),
+ cast(DataType::UInt(64), total_bytes),
+ IntImmNode::make(DataType::Int(32), op->dtype.code()),
+ IntImmNode::make(DataType::Int(32), op->dtype.bits())},
+ CallNode::Extern),
body);
- Expr free_op = Call::make(DataType::Int(32),
- "TVMBackendFreeWorkspace",
- {cast(DataType::Int(32), device_type_),
- cast(DataType::Int(32), device_id_),
- op->buffer_var},
- Call::Extern);
- Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
+ Expr free_op = CallNode::make(DataType::Int(32),
+ "TVMBackendFreeWorkspace",
+ {cast(DataType::Int(32), device_type_),
+ cast(DataType::Int(32), device_id_),
+ op->buffer_var},
+ CallNode::Extern);
+ Stmt free_stmt = IfThenElseNode::make(
+ free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
- body = AttrStmt::make(
+ body = AttrStmtNode::make(
op->buffer_var, attr::storage_alignment,
make_const(DataType::Int(32), runtime::kTempAllocaAlignment),
body);
return body;
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_context_id) {
CHECK(!device_id_.defined());
device_id_ = op->value;
return StmtExprMutator::VisitStmt_(op);
}
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
}
}
// call shape
- Expr MakeShape(const Call* op) {
+ Expr MakeShape(const CallNode* op) {
size_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
- Store::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
+ StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin +i), const_true(1)));
}
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
// make array
- Expr MakeArray(const Call* op) {
+ Expr MakeArray(const CallNode* op) {
size_t idx = run_array_stack_;
run_array_stack_ += 1;
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
prep_seq_.emplace_back(
return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
- Expr MakeCallPacked(const Call* op) {
+ Expr MakeCallPacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
// Specially handle the buffer packed intrinsic
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
Expr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
- arg = Cast::make(api_type, arg);
+ arg = CastNode::make(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(
stack_value_, static_cast<int>(arg_stack_begin + i - 1),
intrinsic::kTVMValueContent, arg));
int arg_tcode = api_type.code();
- if (api_type.is_handle() && arg.as<StringImm>()) {
+ if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kStr;
}
if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
prep_seq_.emplace_back(
- Store::make(stack_tcode_,
+ StoreNode::make(stack_tcode_,
ConstInt32(arg_tcode),
stack_index, const_true(1)));
}
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)
};
- return Call::make(
+ return CallNode::make(
DataType::Int(32), intrinsic::tvm_call_packed_lowered,
- packed_args, Call::Intrinsic);
+ packed_args, CallNode::Intrinsic);
}
- Expr MakeCallTracePacked(const Call *op) {
+ Expr MakeCallTracePacked(const CallNode *op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
size_t args_size = op->args.size();
CHECK_GT(args_size, 0);
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
Expr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
- arg = Cast::make(api_type, arg);
+ arg = CastNode::make(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(
stack_value_, static_cast<int>(arg_stack_begin + i - 1),
int arg_tcode = api_type.code();
CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq_.emplace_back(
- Store::make(stack_tcode_,
+ StoreNode::make(stack_tcode_,
ConstInt32(arg_tcode),
stack_index, const_true(1)));
}
// Pass traced value.
op->args[args_size - 1]
};
- return Call::make(
+ return CallNode::make(
op->dtype, intrinsic::tvm_call_trace_packed_lowered,
- packed_args, Call::Intrinsic);
+ packed_args, CallNode::Intrinsic);
}
private:
bool IsArrayHandle(const Expr& arg) {
// specially set array handle.
- if (const Call* buf = arg.as<Call>()) {
+ if (const CallNode* buf = arg.as<CallNode>()) {
if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
- buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
+ buf->args[2].as<IntImmNode>()->value == intrinsic::kArrAddr) {
return true;
}
}
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
class WarpStoreCoeffFinder : private StmtVisitor {
public:
- WarpStoreCoeffFinder(const Variable* buffer,
+ WarpStoreCoeffFinder(const VarNode* buffer,
Var warp_index,
arith::Analyzer* analyzer)
: buffer_(buffer),
private:
/// Visitor implementation
- void VisitStmt_(const Store *op) final {
+ void VisitStmt_(const StoreNode *op) final {
if (op->buffer_var.get() == buffer_) {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
}
// The buffer variable
- const Variable* buffer_;
+ const VarNode* buffer_;
// the warp index
Var warp_index_;
// the coefficient
private:
/// Visitor implementation
- void VisitStmt_(const AttrStmt *op) final {
+ void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
: warp_size_(warp_size), analyzer_(analyzer) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
- Stmt Rewrite(const Allocate* op) {
+ Stmt Rewrite(const AllocateNode* op) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0)
CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
<< "Warp memory must be multiple of warp size";
warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
- return Allocate::make(
+ return AllocateNode::make(
op->buffer_var,
op->dtype,
{make_const(DataType::Int(32), alloc_size / warp_size_)},
}
protected:
- Expr Mutate_(const Variable* op) {
+ Expr Mutate_(const VarNode* op) {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
}
- Stmt VisitStmt_(const Store* op) {
+ Stmt VisitStmt_(const StoreNode* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
- return Store::make(op->buffer_var, op->value, local_index, op->predicate);
+ return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Expr Mutate_(const Load* op) {
+ Expr Mutate_(const LoadNode* op) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
CHECK(!ExprUseVar(local_index, {warp_index_.get()}))
<< "LowerWarpMemory failed to rewrite load to shuffle for index "
<< op->index << " local_index=" << local_index;
- Expr load_value = Load::make(
+ Expr load_value = LoadNode::make(
op->dtype, op->buffer_var, local_index, op->predicate);
- return Call::make(load_value.dtype(),
+ return CallNode::make(load_value.dtype(),
intrinsic::tvm_warp_shuffle,
{load_value, group},
- Call::Intrinsic);
+ CallNode::Intrinsic);
} else {
return StmtExprMutator::VisitExpr_(op);
}
CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base);
local_index =
- Ramp::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
+ RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
return std::make_pair(local_index, group);
}
Expr m = make_const(index.dtype(), warp_coeff_);
// the warp size
int warp_size_{0};
// The buffer variable
- const Variable* buffer_;
+ const VarNode* buffer_;
// Warp index
Var warp_index_;
// the coefficient m
explicit BindVarBoundInfo(arith::Analyzer* analyzer)
: analyzer_(analyzer) {}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
const Var& loop_var = op->loop_var;
analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
StmtVisitor::VisitStmt_(op);
}
- void VisitStmt_(const AttrStmt* op) {
+ void VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
// internal analyzer.
arith::Analyzer* analyzer_;
// variable domain
- std::unordered_map<const Variable*, Range> var_dom_;
+ std::unordered_map<const VarNode*, Range> var_dom_;
};
// Mutator to change the read pattern
}
private:
- Stmt VisitStmt_(const Allocate* op) {
+ Stmt VisitStmt_(const AllocateNode* op) {
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_, &analyzer_);
return rewriter.Rewrite(op);
}
}
- Stmt VisitStmt_(const AttrStmt* op) {
+ Stmt VisitStmt_(const AttrStmtNode* op) {
using runtime::StorageScope;
if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
- StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
+ const VarNode* buf = op->node.as<VarNode>();
+ StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
if (scope.rank == runtime::StorageRank::kWarp) {
warp_buffer_.insert(buf);
Stmt ret = StmtMutator::VisitStmt_(op);
- op = ret.as<AttrStmt>();
- return AttrStmt::make(
- op->node, op->attr_key, StringImm::make("local"), op->body);
+ op = ret.as<AttrStmtNode>();
+ return AttrStmtNode::make(
+ op->node, op->attr_key, StringImmNode::make("local"), op->body);
}
}
return StmtMutator::VisitStmt_(op);
}
int warp_size_{0};
- std::unordered_set<const Variable*> warp_buffer_;
+ std::unordered_set<const VarNode*> warp_buffer_;
arith::Analyzer analyzer_;
// variable domain
- std::unordered_map<const Variable*, Range> var_dom_;
+ std::unordered_map<const VarNode*, Range> var_dom_;
};
LoweredFunc
namespace ir {
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
- return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
+ return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
}
LoweredFunc MakeAPI(Stmt body,
Array<ObjectRef> api_args,
int num_unpacked_args,
bool is_restricted) {
- const Stmt nop = Evaluate::make(0);
+ const Stmt nop = EvaluateNode::make(0);
int num_args = static_cast<int>(api_args.size());
CHECK_LE(num_unpacked_args, num_args);
int num_packed_args = num_args - num_unpacked_args;
// seq_init gives sequence of initialization
// seq_check gives sequence of later checks after init
std::vector<Stmt> seq_init, seq_check;
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
ArgBinder binder(&vmap);
// ---------------------------
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
Array<Expr> call_args{v_packed_args,
- IntImm::make(DataType::Int(32), i),
- IntImm::make(DataType::Int(32), intrinsic::kTVMValueContent)};
+ IntImmNode::make(DataType::Int(32), i),
+ IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
- Expr res = Call::make(
+ Expr res = CallNode::make(
api_type, intrinsic::tvm_struct_get, call_args,
- Call::PureIntrinsic);
+ CallNode::PureIntrinsic);
// cast to the target version.
if (api_type != t) {
- res = Cast::make(t, res);
+ res = CastNode::make(t, res);
}
return res;
};
auto f_arg_decl = [&](int i) {
std::ostringstream os;
os << "arg" << i;
- const Variable* v = api_args[i].as<Variable>();
+ const VarNode* v = api_args[i].as<VarNode>();
return Var(os.str(), v ? v->dtype: DataType::Handle());
};
// ---------------------------
Var v_arg = f_arg_decl(i);
if (i < num_packed_args) {
// Value loads
- seq_init.emplace_back(LetStmt::make(
+ seq_init.emplace_back(LetStmtNode::make(
v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
- seq_init.emplace_back(LetStmt::make(
- tcode, Load::make(
+ seq_init.emplace_back(LetStmtNode::make(
+ tcode, LoadNode::make(
DataType::Int(32), v_packed_arg_type_ids,
- IntImm::make(DataType::Int(32), i), const_true(1)),
+ IntImmNode::make(DataType::Int(32), i), const_true(1)),
nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(
- AssertStmt::make(tcode == kHandle ||
+ AssertStmtNode::make(tcode == kHandle ||
tcode == kNDArrayContainer ||
tcode == kArrayHandle ||
tcode == kNull, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be int";
- seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
+ seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(
- AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
+ AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
}
} else {
args.push_back(v_arg);
}
// add checks for functions.
- if (api_args[i].as<Variable>()) {
+ if (api_args[i].as<VarNode>()) {
var_defs.emplace_back(std::make_pair(Downcast<Var>(api_args[i]), v_arg));
} else {
// Buffer checks
n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted;
- body = AttrStmt::make(
+ body = AttrStmtNode::make(
make_zero(DataType::Int(32)), attr::compute_scope,
- StringImm::make(name + "_compute_"), body);
+ StringImmNode::make(name + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
- Expr node = StringImm::make("default");
+ Expr node = StringImmNode::make("default");
CHECK(vmap.count(device_type.get()));
- seq_check.push_back(AttrStmt::make(
+ seq_check.push_back(AttrStmtNode::make(
node, attr::device_context_id, device_id, nop));
- seq_check.push_back(AttrStmt::make(
+ seq_check.push_back(AttrStmtNode::make(
node, attr::device_context_type, device_type, nop));
- Stmt set_device = IfThenElse::make(
- device_type != kDLCPU, Evaluate::make(Call::make(
+ Stmt set_device = IfThenElseNode::make(
+ device_type != kDLCPU, EvaluateNode::make(CallNode::make(
DataType::Int(32), intrinsic::tvm_call_packed,
- {StringImm::make(runtime::symbol::tvm_set_device),
- device_type, device_id}, Call::Intrinsic)));
+ {StringImmNode::make(runtime::symbol::tvm_set_device),
+ device_type, device_id}, CallNode::Intrinsic)));
body = SeqStmt({set_device, body});
}
n->body = MergeNest(
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_context_type) {
- if (const Variable* var = op->value.as<Variable>()) {
+ if (const VarNode* var = op->value.as<VarNode>()) {
var_ = var;
Expr value = make_const(op->value.dtype(), device_type_);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
- return AssertStmt::make(op->value == value, os.str(), body);
+ return AssertStmtNode::make(op->value == value, os.str(), body);
}
}
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const IfThenElse* op) final {
+ Stmt VisitStmt_(const IfThenElseNode* op) final {
// eager simplify if guard.
Stmt res = StmtExprMutator::VisitStmt_(op);
- op = res.as<IfThenElse>();
+ op = res.as<IfThenElseNode>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
- return Evaluate::make(0);
+ return EvaluateNode::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
return res;
}
- Expr VisitExpr_(const NE* op) final {
+ Expr VisitExpr_(const NENode* op) final {
// eager check NE for device check
Expr res = StmtExprMutator::VisitExpr_(op);
- op = res.as<NE>();
+ op = res.as<NENode>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->dtype, false);
}
return res;
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
}
public:
- const Variable* var_{nullptr};
+ const VarNode* var_{nullptr};
int device_type_;
};
}
private:
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
auto it = tmap_.find(iv->thread_tag);
if (it != tmap_.end()) {
const IterVar& new_iv = it->second;
- const Variable* v = iv->var.get();
+ const VarNode* v = iv->var.get();
if (!vmap_.count(v)) {
vmap_[v] = new_iv->var;
} else {
CHECK(vmap_[v].same_as(new_iv->var));
}
Stmt body = this->VisitStmt(op->body);
- return AttrStmt::make(
+ return AttrStmtNode::make(
new_iv, op->attr_key, op->value, body);
}
}
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = vmap_.find(op);
if (it != vmap_.end()) return it->second;
return StmtExprMutator::VisitExpr_(op);
// The thread map
const std::unordered_map<std::string, IterVar>& tmap_;
// variable map
- std::unordered_map<const Variable*, Var> vmap_;
+ std::unordered_map<const VarNode*, Var> vmap_;
};
LoweredFunc
RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
- const StringImm* str = kv.first.as<StringImm>();
+ const StringImmNode* str = kv.first.as<StringImmNode>();
CHECK(str != nullptr);
tmap[str->value] = kv.second;
}
// Mark the statment of each stage.
class NoOpRemover : public StmtMutator {
public:
- Stmt VisitStmt_(const LetStmt* op) final {
+ Stmt VisitStmt_(const LetStmtNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<LetStmt>();
+ op = stmt.as<LetStmtNode>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_debug_skip_region") {
return MakeEvaluate(0);
}
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<AttrStmt>();
+ op = stmt.as<AttrStmtNode>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
- Stmt VisitStmt_(const IfThenElse* op) final {
+ Stmt VisitStmt_(const IfThenElseNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<IfThenElse>();
+ op = stmt.as<IfThenElseNode>();
if (op->else_case.defined()) {
if (is_no_op(op->else_case)) {
if (is_no_op(op->then_case)) {
return MakeEvaluate(op->condition);
} else {
- return IfThenElse::make(op->condition, op->then_case);
+ return IfThenElseNode::make(op->condition, op->then_case);
}
} else {
return stmt;
}
}
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<For>();
+ op = stmt.as<ForNode>();
if (is_zero(op->extent)) {
- return Evaluate::make(0);
+ return EvaluateNode::make(0);
}
return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
}
- Stmt VisitStmt_(const ProducerConsumer* op) final {
+ Stmt VisitStmt_(const ProducerConsumerNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<ProducerConsumer>();
+ op = stmt.as<ProducerConsumerNode>();
return is_no_op(op->body) ? op->body : stmt;
}
- Stmt VisitStmt_(const Realize* op) final {
+ Stmt VisitStmt_(const RealizeNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<Realize>();
+ op = stmt.as<RealizeNode>();
return is_no_op(op->body) ? op->body : stmt;
}
- Stmt VisitStmt_(const Evaluate* op) final {
+ Stmt VisitStmt_(const EvaluateNode* op) final {
if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
- return Evaluate::make(0);
+ return EvaluateNode::make(0);
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
private:
Stmt MakeEvaluate(Expr value) {
if (HasSideEffect(value)) {
- return Evaluate::make(value);
+ return EvaluateNode::make(value);
} else {
- return Evaluate::make(0);
+ return EvaluateNode::make(0);
}
}
Stmt MakeEvaluate(const Array<Expr>& values) {
for (Expr e : values) {
if (HasSideEffect(e)) {
if (stmt.defined()) {
- stmt = SeqStmt({stmt, Evaluate::make(e)});
+ stmt = SeqStmt({stmt, EvaluateNode::make(e)});
} else {
- stmt = Evaluate::make(e);
+ stmt = EvaluateNode::make(e);
}
}
}
- return stmt.defined() ? stmt : Evaluate::make(0);
+ return stmt.defined() ? stmt : EvaluateNode::make(0);
}
};
public:
// select itself is always considered safe if condition is safe
// Because we will issue guard to make sure it is.
- bool VisitExpr_(const Select* op) {
+ bool VisitExpr_(const SelectNode* op) {
return VisitExpr(op->condition);
}
- bool VisitExpr_(const Call* op) {
+ bool VisitExpr_(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
return VisitExpr(op->args[0]);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const Load* l = op->args[0].as<Load>();
+ const LoadNode* l = op->args[0].as<LoadNode>();
return this->VisitExpr(l->index);
} else if (op->is_pure()) {
for (Expr e : op->args) {
return true;
}
}
- bool VisitExpr_(const Load* op) {
+ bool VisitExpr_(const LoadNode* op) {
// Load is considered unsafe.
return true;
}
- bool VisitExpr_(const Add* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Sub* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
- bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
- bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
- bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
- bool VisitExpr_(const NE* op) final { return BinaryOp(op); }
- bool VisitExpr_(const LT* op) final { return BinaryOp(op); }
- bool VisitExpr_(const LE* op) final { return BinaryOp(op); }
- bool VisitExpr_(const GT* op) final { return BinaryOp(op); }
- bool VisitExpr_(const GE* op) final { return BinaryOp(op); }
- bool VisitExpr_(const And* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Or* op) final { return BinaryOp(op); }
- bool VisitExpr_(const Not* op) final {
+ bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const NENode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const LENode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const GENode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); }
+ bool VisitExpr_(const NotNode* op) final {
return VisitExpr(op->a);
}
- bool VisitExpr_(const Let* op) final {
+ bool VisitExpr_(const LetNode* op) final {
return VisitExpr(op->body) || VisitExpr(op->value);
}
- bool VisitExpr_(const Cast* op) final {
+ bool VisitExpr_(const CastNode* op) final {
return VisitExpr(op->value);
}
- bool VisitExpr_(const Broadcast* op) final {
+ bool VisitExpr_(const BroadcastNode* op) final {
return VisitExpr(op->value);
}
- bool VisitExpr_(const Ramp* op) final {
+ bool VisitExpr_(const RampNode* op) final {
return VisitExpr(op->base) && VisitExpr(op->stride);
}
- bool VisitExpr_(const Shuffle* op) final {
+ bool VisitExpr_(const ShuffleNode* op) final {
for (Expr e : op->vectors) {
if (VisitExpr(e)) return true;
}
return false;
}
- bool VisitExpr_(const Variable* op) final { return false; }
- bool VisitExpr_(const UIntImm* op) final { return false; }
- bool VisitExpr_(const IntImm* op) final { return false; }
- bool VisitExpr_(const FloatImm* op) final { return false; }
- bool VisitExpr_(const StringImm* op) final { return false; }
+ bool VisitExpr_(const VarNode* op) final { return false; }
+ bool VisitExpr_(const UIntImmNode* op) final { return false; }
+ bool VisitExpr_(const IntImmNode* op) final { return false; }
+ bool VisitExpr_(const FloatImmNode* op) final { return false; }
+ bool VisitExpr_(const StringImmNode* op) final { return false; }
private:
template<typename T>
class UnsafeSelectRewriter : public StmtExprMutator {
public:
- Expr VisitExpr_(const Select* op) {
+ Expr VisitExpr_(const SelectNode* op) {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Select>();
+ op = expr.as<SelectNode>();
UnsafeExprDetector unsafe;
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
if ((unsafe.VisitExpr(op->true_value) ||
unsafe.VisitExpr(op->false_value)) &&
cond_is_scalar_bool) {
- return Call::make(
+ return CallNode::make(
op->dtype,
intrinsic::tvm_if_then_else,
{op->condition, op->true_value, op->false_value},
- Call::Intrinsic);
+ CallNode::Intrinsic);
} else {
return expr;
}
ExprVisitor::VisitExpr(e);
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
if (!op->is_pure()) {
has_side_effect_ = true; return;
} else {
class IRSubstitue : public StmtExprMutator {
public:
explicit IRSubstitue(
- const std::unordered_map<const Variable*, Expr>& smap)
+ const std::unordered_map<const VarNode*, Expr>& smap)
: smap_(smap) {
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = smap_.find(op);
if (it != smap_.end()) {
return it->second;
}
private:
- const std::unordered_map<const Variable*, Expr>& smap_;
+ const std::unordered_map<const VarNode*, Expr>& smap_;
};
Stmt Substitute(Stmt stmt,
- const std::unordered_map<const Variable*, Expr>& value_map) {
+ const std::unordered_map<const VarNode*, Expr>& value_map) {
if (value_map.size() == 0) return stmt;
return IRSubstitue(value_map)(std::move(stmt));
}
Expr Substitute(Expr expr,
- const std::unordered_map<const Variable*, Expr>& value_map) {
+ const std::unordered_map<const VarNode*, Expr>& value_map) {
if (value_map.size() == 0) return expr;
return IRSubstitue(value_map)(std::move(expr));
}
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
}
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
- std::unordered_map<const Variable*, Expr> vmap;
+ std::unordered_map<const VarNode*, Expr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
ExprVisitor::VisitExpr(e);
}
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
Handle(op);
}
- void VisitExpr_(const Load* op) final {
+ void VisitExpr_(const LoadNode* op) final {
Handle(op->buffer_var.get());
ExprVisitor::VisitExpr_(op);
}
- virtual void Handle(const Variable* var) = 0;
+ virtual void Handle(const VarNode* var) = 0;
bool use_var_{false};
};
class ExprUseVarVisitor : public VarTouchVisitor {
public:
- explicit ExprUseVarVisitor(const Variable* var)
+ explicit ExprUseVarVisitor(const VarNode* var)
: var_(var) {}
- void Handle(const Variable* var) final {
+ void Handle(const VarNode* var) final {
if (var == var_) use_var_ = true;
}
private:
- const Variable* var_;
+ const VarNode* var_;
};
class ExprUseVSetVisitor : public VarTouchVisitor {
public:
explicit ExprUseVSetVisitor(
- const std::unordered_set<const Variable*>& vset)
+ const std::unordered_set<const VarNode*>& vset)
: vset_(vset) {}
- void Handle(const Variable* var) final {
+ void Handle(const VarNode* var) final {
if (vset_.count(var)) use_var_ = true;
}
private:
- const std::unordered_set<const Variable*>& vset_;
+ const std::unordered_set<const VarNode*>& vset_;
};
bool ExprUseVar(const Expr& e, const Var& v) {
}
bool ExprUseVar(const Expr& e,
- const std::unordered_set<const Variable*>& vset) {
+ const std::unordered_set<const VarNode*>& vset) {
ExprUseVSetVisitor visitor(vset);
visitor(e);
return visitor.use_var_;
class AssertSkipper : public StmtMutator {
public:
- Stmt VisitStmt_(const AssertStmt* op) final {
+ Stmt VisitStmt_(const AssertStmtNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<AssertStmt>();
+ op = stmt.as<AssertStmtNode>();
return op->body;
}
};
// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public StmtExprMutator {
public:
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
- return AttrStmt::make(op->node, op->attr_key, value, body);
+ return AttrStmtNode::make(op->node, op->attr_key, value, body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const LetStmt* op) final {
+ Stmt VisitStmt_(const LetStmtNode* op) final {
this->HandleDef(op->var.get());
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
- return LetStmt::make(op->var, value, body);
+ return LetStmtNode::make(op->var, value, body);
}
}
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
this->HandleDef(op->loop_var.get());
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
this->HandleDef(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
this->HandleUse(op->buffer_var);
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const Let* op) final {
+ Expr VisitExpr_(const LetNode* op) final {
this->HandleDef(op->var.get());
Expr body = this->VisitExpr(op->body);
// eliminate unreferenced let
value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
- return Let::make(op->var, value, body);
+ return LetNode::make(op->var, value, body);
}
}
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
this->HandleUse(GetRef<Expr>(op));
return StmtExprMutator::VisitExpr_(op);
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
this->HandleUse(op->buffer_var);
return StmtExprMutator::VisitExpr_(op);
}
- void HandleDef(const Variable* v) {
+ void HandleDef(const VarNode* v) {
CHECK(!def_count_.count(v))
<< "variable " << v->name_hint
<< " has already been defined, the Stmt is not SSA";
}
void HandleUse(const Expr& v) {
- CHECK(v.as<Variable>());
+ CHECK(v.as<VarNode>());
Var var = Downcast<Var>(v);
auto it = use_count_.find(var.get());
if (it != use_count_.end()) {
Array<Var> undefined_;
Array<IterVar> thread_axis_;
Array<Expr> thread_extent_;
- std::unordered_map<const Variable*, int> use_count_;
- std::unordered_map<const Variable*, int> def_count_;
+ std::unordered_map<const VarNode*, int> use_count_;
+ std::unordered_map<const VarNode*, int> def_count_;
};
class HostDeviceSplitter : public StmtMutator {
public:
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return StmtMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
}
LoweredFunc f_device(n);
Array<Expr> call_args;
- call_args.push_back(StringImm::make(f_device->name));
+ call_args.push_back(StringImmNode::make(f_device->name));
for (Var arg : n->args) {
call_args.push_back(arg);
}
call_args.push_back(ext);
}
device_funcs_.emplace_back(f_device);
- return Evaluate::make(Call::make(
+ return EvaluateNode::make(CallNode::make(
DataType::Int(32), intrinsic::tvm_call_packed,
- call_args, Call::Intrinsic));
+ call_args, CallNode::Intrinsic));
}
// function name
std::string name_;
// the device functions
std::vector<LoweredFunc> device_funcs_;
- std::unordered_map<const Variable*, Expr> handle_data_type_;
+ std::unordered_map<const VarNode*, Expr> handle_data_type_;
};
if (!is_ssa) return;
StmtExprVisitor::VisitStmt(n);
}
- void VisitExpr_(const Let* op) final {
+ void VisitExpr_(const LetNode* op) final {
MarkDef(op->var.get());
StmtExprVisitor::VisitExpr_(op);
}
- void VisitStmt_(const LetStmt* op) final {
+ void VisitStmt_(const LetStmtNode* op) final {
MarkDef(op->var.get());
StmtExprVisitor::VisitStmt_(op);
}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
MarkDef(op->loop_var.get());
StmtExprVisitor::VisitStmt_(op);
}
- void VisitStmt_(const Allocate* op) final {
+ void VisitStmt_(const AllocateNode* op) final {
MarkDef(op->buffer_var.get());
StmtExprVisitor::VisitStmt_(op);
}
private:
- void MarkDef(const Variable* v) {
+ void MarkDef(const VarNode* v) {
if (defined_.count(v) != 0) {
is_ssa = false; return;
} else {
defined_[v] = 1;
}
}
- std::unordered_map<const Variable*, int> defined_;
+ std::unordered_map<const VarNode*, int> defined_;
};
class IRConvertSSA final : public StmtExprMutator {
public:
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
if (scope_.count(op)) {
return scope_[op].back();
} else {
return GetRef<Expr>(op);
}
}
- Expr VisitExpr_(const Let* op) final {
+ Expr VisitExpr_(const LetNode* op) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = this->VisitExpr(op->value);
- VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+ VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Expr body = this->VisitExpr(op->body);
scope_[v.get()].pop_back();
- return Let::make(new_var, value, body);
+ return LetNode::make(new_var, value, body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Load>();
+ op = expr.as<LoadNode>();
if (scope_.count(op->buffer_var.get())) {
- return Load::make(
+ return LoadNode::make(
op->dtype, scope_[op->buffer_var.get()].back(),
op->index, op->predicate);
} else {
return expr;
}
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Store>();
+ op = stmt.as<StoreNode>();
if (scope_.count(op->buffer_var.get())) {
- return Store::make(
+ return StoreNode::make(
scope_[op->buffer_var.get()].back(), op->value,
op->index, op->predicate);
} else {
return stmt;
}
}
- Stmt VisitStmt_(const LetStmt* op) final {
+ Stmt VisitStmt_(const LetStmtNode* op) final {
const VarExpr& v = op->var;
if (defined_.count(v.get())) {
Expr value = this->VisitExpr(op->value);
- VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+ VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back();
- return LetStmt::make(new_var, value, body);
+ return LetStmtNode::make(new_var, value, body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
const VarExpr& v = op->loop_var;
if (defined_.count(v.get())) {
- VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+ VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
- op = stmt.as<For>();
- return For::make(
+ op = stmt.as<ForNode>();
+ return ForNode::make(
new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
const VarExpr& v = op->buffer_var;
if (defined_.count(v.get())) {
- VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+ VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
- op = stmt.as<Allocate>();
- return Allocate::make(
+ op = stmt.as<AllocateNode>();
+ return AllocateNode::make(
new_var, op->dtype, op->extents, op->condition,
op->body, op->new_expr, op->free_function);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
- if (const Variable* v = op->node.as<Variable>()) {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (const VarNode* v = op->node.as<VarNode>()) {
if (op->attr_key == attr::storage_scope) {
- const Allocate* alloc = op->body.as<Allocate>();
+ const AllocateNode* alloc = op->body.as<AllocateNode>();
if (alloc && op->node.same_as(alloc->buffer_var)) {
Stmt new_alloc = this->VisitStmt(op->body);
if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
- alloc = new_alloc.as<Allocate>();
+ alloc = new_alloc.as<AllocateNode>();
CHECK(alloc);
- return AttrStmt::make(
+ return AttrStmtNode::make(
alloc->buffer_var, op->attr_key, op->value, new_alloc);
}
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AttrStmt>();
+ op = stmt.as<AttrStmtNode>();
if (scope_.count(v) && scope_[v].size() != 0) {
- return AttrStmt::make(
+ return AttrStmtNode::make(
scope_[v].back(), op->attr_key, op->value, op->body);
} else {
return stmt;
}
private:
- std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
- std::unordered_set<const Variable*> defined_;
+ std::unordered_map<const VarNode*, std::vector<VarExpr> > scope_;
+ std::unordered_set<const VarNode*> defined_;
};
} // namespace
namespace tvm {
namespace ir {
-void StorageAccessVisitor::VisitExpr_(const Load* op) {
- const Variable* buf = op->buffer_var.as<Variable>();
+void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
+ const VarNode* buf = op->buffer_var.as<VarNode>();
StorageScope scope = GetScope(buf);
if (Enabled(buf, scope)) {
CHECK(allow_append_);
StmtExprVisitor::VisitExpr_(op);
}
-void StorageAccessVisitor::VisitStmt_(const Store* op) {
+void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
- const Variable* buf = op->buffer_var.as<Variable>();
+ const VarNode* buf = op->buffer_var.as<VarNode>();
StorageScope scope = GetScope(buf);
if (Enabled(buf, scope)) {
AccessEntry e;
allow_append_ = false;
}
-void StorageAccessVisitor::VisitStmt_(const Evaluate* op) {
+void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
allow_append_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
allow_append_ = false;
}
-void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) {
+void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
+ const VarNode* buf = op->node.as<VarNode>();
storage_scope_[buf] =
- StorageScope::make(op->value.as<StringImm>()->value);
+ StorageScope::make(op->value.as<StringImmNode>()->value);
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::double_buffer_write) {
CHECK(double_buffer_write_ == nullptr);
- double_buffer_write_ = op->node.as<Variable>();
+ double_buffer_write_ = op->node.as<VarNode>();
scope_.push_back(std::vector<StmtEntry>());
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
}
}
-void StorageAccessVisitor::VisitStmt_(const For* op) {
+void StorageAccessVisitor::VisitStmt_(const ForNode* op) {
scope_.push_back(std::vector<StmtEntry>());
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
scope_.pop_back();
if (s.access.size() != 0) {
// relax the touched set to contain all ranges in the loop.
- std::unordered_map<const Variable*, arith::IntSet> relax_map;
+ std::unordered_map<const VarNode*, arith::IntSet> relax_map;
relax_map[op->loop_var.get()] = arith::IntSet::range(
Range::make_by_min_extent(op->min, op->extent));
for (AccessEntry& e : s.access) {
}
}
-void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) {
+void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
++condition_counter_;
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
--condition_counter_;
}
-void StorageAccessVisitor::VisitExpr_(const Call* op) {
+void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const Load *l = op->args[0].as<Load>();
+ const LoadNode *l = op->args[0].as<LoadNode>();
StmtExprVisitor::VisitExpr_(l);
} else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
- const Variable* buffer = op->args[1].as<Variable>();
+ const VarNode* buffer = op->args[1].as<VarNode>();
Expr offset = op->args[2];
Expr extent = op->args[3];
- const IntImm* flag = op->args[4].as<IntImm>();
+ const IntImmNode* flag = op->args[4].as<IntImmNode>();
StorageScope scope = GetScope(buffer);
// The buffer scope.
if (Enabled(buffer, scope)) {
StmtExprVisitor::VisitExpr_(op);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
CHECK(allow_append_);
- const std::string& s = op->args[0].as<StringImm>()->value;
+ const std::string& s = op->args[0].as<StringImmNode>()->value;
if (s != "warp") {
StorageScope scope = StorageScope::make(s);
AccessEntry e;
}
}
-StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
+StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s;
s.rank = StorageRank::kGlobal;
class StorageAccessInfoLower : public StmtExprMutator {
public:
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
// Lower allocate to device allocate when needed.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
// For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.info.defined()) {
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
- return Allocate::make(
+ return AllocateNode::make(
op->buffer_var, op->dtype, op->extents, op->condition,
op->body, info->head_address, "nop");
}
return stmt;
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
- StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
+ const VarNode* buf = op->node.as<VarNode>();
+ StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
- e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
+ e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
}
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op);
} else {
private:
// tvm_access_ptr
- Expr MakeAccessPtr(const Call* op) {
+ Expr MakeAccessPtr(const CallNode* op) {
// Specially handle the buffer packed intrinsic
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
- const Variable* buffer = op->args[1].as<Variable>();
+ const VarNode* buffer = op->args[1].as<VarNode>();
Var buffer_var = Downcast<Var>(op->args[1]);
Expr offset = op->args[2];
auto it = storage_info_.find(buffer);
int alloc_count{0};
};
// The storage scope of each buffer
- std::unordered_map<const Variable*, StorageEntry> storage_info_;
+ std::unordered_map<const VarNode*, StorageEntry> storage_info_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
std::vector<AccessEntry> access;
};
// override visitor pattern
- void VisitExpr_(const Load* op) final;
- void VisitStmt_(const Store* op) final;
- void VisitStmt_(const Evaluate* op) final;
- void VisitStmt_(const AttrStmt* op) final;
- void VisitStmt_(const For* op) final;
- void VisitStmt_(const IfThenElse* op) final;
- void VisitExpr_(const Call* op) final;
+ void VisitExpr_(const LoadNode* op) final;
+ void VisitStmt_(const StoreNode* op) final;
+ void VisitStmt_(const EvaluateNode* op) final;
+ void VisitStmt_(const AttrStmtNode* op) final;
+ void VisitStmt_(const ForNode* op) final;
+ void VisitStmt_(const IfThenElseNode* op) final;
+ void VisitExpr_(const CallNode* op) final;
protected:
StorageAccessVisitor() {
* \param scope The scope of the buffer.
* \return Whether the analysis of buffer is enabled.
*/
- virtual bool Enabled(const Variable* buffer,
+ virtual bool Enabled(const VarNode* buffer,
const StorageScope& scope) const {
return true;
}
* the parent should taken care of to synchronize.
*/
virtual std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const For* loop) = 0;
+ std::vector<StmtEntry> seq, const ForNode* loop) = 0;
/*!
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
*/
- StorageScope GetScope(const Variable* buf) const;
+ StorageScope GetScope(const VarNode* buf) const;
// access scope
std::vector<std::vector<StmtEntry> > scope_;
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
- const Variable* double_buffer_write_{nullptr};
+ const VarNode* double_buffer_write_{nullptr};
// the current free stmt entry.
StmtEntry curr_stmt_;
// The involving threads
Array<IterVar> env_threads_;
// The storage scope of each buffer
- std::unordered_map<const Variable*, StorageScope> storage_scope_;
+ std::unordered_map<const VarNode*, StorageScope> storage_scope_;
};
} // namespace ir
cache_line_size_ = cache_line_size;
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Store>();
+ op = stmt.as<StoreNode>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
!it->second.same_as(op->buffer_var)) {
- CHECK(it->second.as<Variable>());
+ CHECK(it->second.as<VarNode>());
VarExpr buf_var = Downcast<VarExpr>(it->second);
- return Store::make(buf_var, op->value, op->index, op->predicate);
+ return StoreNode::make(buf_var, op->value, op->index, op->predicate);
} else {
return stmt;
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::realize_scope) {
- storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+ storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope &&
op->node->IsInstance<OperationNode>()) {
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
- body = AttrStmt::make(
+ body = AttrStmtNode::make(
it->second.buffer->data, op->attr_key, op->value, body);
}
return body;
return HandleBufferBindScope(op);
} else if (op->attr_key == attr::buffer_dim_align) {
Tensor tensor = Downcast<Tensor>(op->node);
- const Call* tuple = op->value.as<Call>();
+ const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
auto& vinfo = dim_align_[key];
- int dim = tuple->args[0].as<IntImm>()->value;
+ int dim = tuple->args[0].as<IntImmNode>()->value;
if (static_cast<size_t>(dim) >= vinfo.size()) {
vinfo.resize(dim + 1);
}
- vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
- vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+ vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+ vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::opengl_stage_scope) {
is_opengl_ = true;
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const Provide* op) final {
+ Stmt VisitStmt_(const ProvideNode* op) final {
if (create_bound_attributes_)
shape_collector_.clear();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Provide>();
+ op = stmt.as<ProvideNode>();
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
if (is_opengl_) {
- return Evaluate::make(Call::make(
+ return EvaluateNode::make(CallNode::make(
DataType(),
- Call::glsl_texture_store,
+ CallNode::glsl_texture_store,
{e.buffer->data, op->value},
- Call::Intrinsic));
+ CallNode::Intrinsic));
} else {
Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
// To create bound attribute collector should has at least one item.
if (create_bound_attributes_ && shape_collector_.size()) {
for (size_t i = 0; i < shape_collector_.size(); ++i) {
- body = AttrStmt::make(
+ body = AttrStmtNode::make(
shape_collector_[i].first, ir::attr::buffer_bound,
MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
}
}
}
- Stmt VisitStmt_(const Realize* op) final {
+ Stmt VisitStmt_(const RealizeNode* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
}
// use small alignment for small arrays
- int32_t const_size = Allocate::constant_allocation_size(shape);
+ int32_t const_size = AllocateNode::constant_allocation_size(shape);
int align = GetTempAllocaAlignment(op->dtype, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
}
if (strides.size() != 0) {
int first_dim = 0;
- ret = Allocate::make(
+ ret = AllocateNode::make(
e.buffer->data, storage_type,
{e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
if (shape.size() == 0) {
shape.push_back(make_const(DataType::Int(32), 1));
}
- ret = Allocate::make(
+ ret = AllocateNode::make(
e.buffer->data, storage_type, shape,
make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
- ret = AttrStmt::make(
+ ret = AttrStmtNode::make(
e.buffer->data, attr::storage_scope,
- StringImm::make(e.buffer->scope), ret);
+ StringImmNode::make(e.buffer->scope), ret);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
- ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
+ ret = AttrStmtNode::make(e.buffer->data, ir::attr::buffer_bound,
MakeBound(e.buffer->dtype, e.buffer->shape), ret);
}
return ret;
}
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Load>();
+ op = expr.as<LoadNode>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
!it->second.same_as(op->buffer_var)) {
- CHECK(it->second.as<Variable>());
+ CHECK(it->second.as<VarNode>());
VarExpr buf_var = Downcast<VarExpr>(it->second);
- return Load::make(op->dtype, buf_var, op->index, op->predicate);
+ return LoadNode::make(op->dtype, buf_var, op->index, op->predicate);
} else {
return expr;
}
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
}
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
- if (op != nullptr && op->call_type == Call::Halide) {
+ op = expr.as<CallNode>();
+ if (op != nullptr && op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
}
}
- Stmt VisitStmt_(const Prefetch *op) final {
+ Stmt VisitStmt_(const PrefetchNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Prefetch>();
+ op = stmt.as<PrefetchNode>();
CHECK(op != nullptr);
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
}
for (int i = starts; i >= 0; --i) {
if (i < starts) {
- stmt = For::make(
+ stmt = ForNode::make(
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
- Expr address = Call::make(DataType::Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
- Expr prefetch = Call::make(op->dtype, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
- stmt = Evaluate::make(prefetch);
+ Expr address = CallNode::make(
+ DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
+ Expr prefetch = CallNode::make(
+ op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
+ stmt = EvaluateNode::make(prefetch);
Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
- stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
+ stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
}
return stmt;
//
// We do support a few relaxed cases, such as binding a
// region with shape [1, 1, n, m] to a buffer with shape [n, m]
- Stmt HandleBufferBindScope(const AttrStmt* op) {
+ Stmt HandleBufferBindScope(const AttrStmtNode* op) {
Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
CHECK_EQ(arr.size(), 2U);
const BufferNode* buffer = arr[0].as<BufferNode>();
const TensorNode* tensor = arr[1].as<TensorNode>();
- const Call* tuple = op->value.as<Call>();
+ const CallNode* tuple = op->value.as<CallNode>();
CHECK(buffer && tensor);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
TensorKey key{tensor->op, tensor->value_index};
Expr MakeBound(const DataType &type, const Array<Expr> &shape) {
// We have already checked the shape size to be greater than 0.
- Expr bound = Mul::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
+ Expr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
for (size_t i = 1; i < shape.size(); ++i) {
- bound = Mul::make(
- bound, Mul::make(make_const(bound.dtype(), type.lanes()), shape[i]));
+ bound = MulNode::make(
+ bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
}
return bound;
}
// The buffer assignment map
// Variable remap
- std::unordered_map<const Variable*, Expr> var_remap_;
+ std::unordered_map<const VarNode*, Expr> var_remap_;
// Buffer map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
// Dimension alignment
// If offset < 0, this entry marks the end of the scope; the matching begin entry is at current_index + offset.
int64_t scope_pair_offset{0};
// The buffer variables this statement touched.
- std::vector<const Variable*> touched;
+ std::vector<const VarNode*> touched;
};
// The scope of each allocation
struct AllocEntry {
// scope level
size_t level{0};
// allocation stmt
- const Allocate* alloc{nullptr};
+ const AllocateNode* alloc{nullptr};
};
- void VisitStmt_(const Allocate* op) final {
+ void VisitStmt_(const AllocateNode* op) final {
size_t level = scope_.size();
- const Variable* buf = op->buffer_var.get();
+ const VarNode* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
CHECK(it != alloc_info_.end());
CHECK(it->second.alloc == nullptr);
it->second.level = level;
StmtExprVisitor::VisitStmt_(op);
}
- void VisitStmt_(const Store* op) final {
+ void VisitStmt_(const StoreNode* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
// Add write access.
- const Variable* buf = op->buffer_var.get();
+ const VarNode* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
CHECK_LT(it->second.level, scope_.size());
linear_seq_.push_back(e);
}
}
- void VisitStmt_(const Evaluate* op) final {
+ void VisitStmt_(const EvaluateNode* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
linear_seq_.push_back(e);
}
}
- void VisitExpr_(const Load* op) final {
+ void VisitExpr_(const LoadNode* op) final {
// Add read access.
StmtExprVisitor::VisitExpr_(op);
- const Variable* buf = op->buffer_var.get();
+ const VarNode* buf = op->buffer_var.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
CHECK_LT(it->second.level, scope_.size())
scope_[it->second.level].touched.push_back(buf);
}
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
- const Load* l = op->args[0].as<Load>();
+ const LoadNode* l = op->args[0].as<LoadNode>();
this->VisitExpr(l->index);
} else {
StmtExprVisitor::VisitExpr_(op);
}
}
- void VisitExpr_(const Variable* buf) final {
+ void VisitExpr_(const VarNode* buf) final {
// A direct reference to the variable counts as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
CHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
// Only record the outermost thread extent.
if (op->attr_key == attr::thread_extent && !in_thread_env_) {
in_thread_env_ = true;
} else if (op->attr_key == attr::virtual_thread) {
VisitNewScope(op);
} else if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
+ const VarNode* buf = op->node.as<VarNode>();
alloc_info_[buf].storage_scope =
- StorageScope::make(op->value.as<StringImm>()->value);
+ StorageScope::make(op->value.as<StringImmNode>()->value);
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
- void VisitStmt_(const IfThenElse* op) final {
+ void VisitStmt_(const IfThenElseNode* op) final {
VisitNewScope(op);
}
- void VisitStmt_(const For* op) final {
+ void VisitStmt_(const ForNode* op) final {
VisitNewScope(op);
}
- void VisitStmt_(const AssertStmt* op) final {
+ void VisitStmt_(const AssertStmtNode* op) final {
VisitNewScope(op);
}
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
- std::unordered_map<const Variable*, AllocEntry> alloc_info_;
+ std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
private:
// Whether already in thread env.
class InplaceOpVerifier : public StmtExprVisitor {
public:
bool Check(const Object* stmt,
- const Variable* dst,
- const Variable* src) {
+ const VarNode* dst,
+ const VarNode* src) {
dst_ = dst;
src_ = src;
result_ = true;
- if (stmt->IsInstance<AttrStmt>()) {
- VisitStmt_(static_cast<const AttrStmt*>(stmt));
- } else if (stmt->IsInstance<For>()) {
- VisitStmt_(static_cast<const For*>(stmt));
- } else if (stmt->IsInstance<IfThenElse>()) {
- VisitStmt_(static_cast<const IfThenElse*>(stmt));
- } else if (stmt->IsInstance<Store>()) {
- VisitStmt_(static_cast<const Store*>(stmt));
+ if (stmt->IsInstance<AttrStmtNode>()) {
+ VisitStmt_(static_cast<const AttrStmtNode*>(stmt));
+ } else if (stmt->IsInstance<ForNode>()) {
+ VisitStmt_(static_cast<const ForNode*>(stmt));
+ } else if (stmt->IsInstance<IfThenElseNode>()) {
+ VisitStmt_(static_cast<const IfThenElseNode*>(stmt));
+ } else if (stmt->IsInstance<StoreNode>()) {
+ VisitStmt_(static_cast<const StoreNode*>(stmt));
} else {
return false;
}
StmtExprVisitor::VisitExpr(n);
}
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
// assume all opaque access is unsafe
if (op == dst_ || op == src_) {
result_ = false; return;
}
}
- void VisitStmt_(const Store* op) final {
+ void VisitStmt_(const StoreNode* op) final {
++mem_nest_;
this->VisitExpr(op->index);
--mem_nest_;
}
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
// always reject extern code
if (op->attr_key == attr::extern_scope ||
op->attr_key == attr::volatile_scope) {
StmtExprVisitor::VisitStmt_(op);
}
- void VisitExpr_(const Load* op) final {
- const Variable* buf = op->buffer_var.get();
+ void VisitExpr_(const LoadNode* op) final {
+ const VarNode* buf = op->buffer_var.get();
// cannot read from dst_ (no reduction)
if (buf == dst_) {
result_ = false; return;
// result of the check
bool result_{true};
// destination memory
- const Variable* dst_;
+ const VarNode* dst_;
// source variable
- const Variable* src_;
+ const VarNode* src_;
// counter of load,
// it is not safe to inplace when there is nested load like A[B[i]]
int mem_nest_{0};
// The current store to be inspected
- const Store* store_{nullptr};
+ const StoreNode* store_{nullptr};
};
// Planner to plan and rewrite memory allocation.
for (StorageEntry* e : attach_map_.at(nullptr)) {
// CHECK_EQ(e->scope.rank, 0);
if (e->new_alloc.defined()) {
- nest.emplace_back(AttrStmt::make(
+ nest.emplace_back(AttrStmtNode::make(
e->alloc_var, attr::storage_scope,
- StringImm::make(e->scope.to_string()),
- Evaluate::make(0)));
+ StringImmNode::make(e->scope.to_string()),
+ EvaluateNode::make(0)));
nest.push_back(e->new_alloc);
}
}
}
return stmt;
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Store>();
+ op = stmt.as<StoreNode>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
- return Store::make(it->second->alloc_var,
+ return StoreNode::make(it->second->alloc_var,
op->value,
RemapIndex(op->value.dtype(), op->index, it->second),
op->predicate);
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Load>();
+ op = expr.as<LoadNode>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
- return Load::make(op->dtype,
+ return LoadNode::make(op->dtype,
it->second->alloc_var,
RemapIndex(op->dtype, op->index, it->second),
op->predicate);
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
if (it->second->bits_offset != 0) {
return GetRef<Expr>(op);
}
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
- const Variable* buffer = op->args[1].as<Variable>();
+ const VarNode* buffer = op->args[1].as<VarNode>();
auto it = alloc_map_.find(buffer);
if (it == alloc_map_.end()) {
return StmtExprMutator::VisitExpr_(op);
if (se->bits_offset != 0) {
offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
- return Call::make(
+ return CallNode::make(
op->dtype, op->name,
{op->args[0], se->alloc_var, offset, extent, op->args[4]},
op->call_type);
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::thread_extent ||
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AttrStmt>();
- return AttrStmt::make(
+ op = stmt.as<AttrStmtNode>();
+ return AttrStmtNode::make(
op->node, op->attr_key, op->value,
MakeAttach(svec, op->body));
} else {
}
} else if (op->attr_key == attr::volatile_scope) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AttrStmt>();
- auto it = alloc_map_.find(op->node.as<Variable>());
+ op = stmt.as<AttrStmtNode>();
+ auto it = alloc_map_.find(op->node.as<VarNode>());
if (it == alloc_map_.end()) return stmt;
- return AttrStmt::make(
+ return AttrStmtNode::make(
it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
CHECK(op->for_type != ForType::Vectorized)
<< "VectorizeLoop before LiftStorageAlloc";
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<For>();
- return For::make(
+ op = stmt.as<ForNode>();
+ return ForNode::make(
op->loop_var, op->min, op->extent, op->for_type, op->device_api,
MakeAttach(svec, op->body));
} else {
}
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
return this->VisitStmt(op->body);
}
// The storage scope.
StorageScope scope;
// Allocs that shares this entry.
- std::vector<const Allocate*> allocs;
+ std::vector<const AllocateNode*> allocs;
// The children of this entry, not including itself.
std::vector<StorageEntry*> merged_children;
// The replacement allocation, if any.
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
- std::vector<const Variable*> gen;
+ std::vector<const VarNode*> gen;
// variables we kill
- std::vector<const Variable*> kill;
+ std::vector<const VarNode*> kill;
};
Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
std::vector<Stmt> nest;
for (StorageEntry* e : svec) {
if (e->new_alloc.defined()) {
- nest.emplace_back(AttrStmt::make(
+ nest.emplace_back(AttrStmtNode::make(
e->alloc_var, attr::storage_scope,
- StringImm::make(e->scope.to_string()),
- Evaluate::make(0)));
+ StringImmNode::make(e->scope.to_string()),
+ EvaluateNode::make(0)));
nest.push_back(e->new_alloc);
}
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
DataType alloc_type = e->allocs[0]->dtype;
- for (const Allocate* op : e->allocs) {
+ for (const AllocateNode* op : e->allocs) {
if (op->dtype.lanes() > alloc_type.lanes()) {
alloc_type = op->dtype;
}
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
- Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
+ Expr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
make_const(DataType::Int(32), 1));
- e->new_alloc = Allocate::make(
+ e->new_alloc = AllocateNode::make(
e->alloc_var, alloc_type, {sz},
- e->allocs[0]->condition, Evaluate::make(0));
+ e->allocs[0]->condition, EvaluateNode::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
} else {
// Build a merged allocation
Expr combo_size;
- for (const Allocate* op : e->allocs) {
- Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(DataType::Int(32), 1));
+ for (const AllocateNode* op : e->allocs) {
+ Expr sz = arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
auto nbits = op->dtype.bits() * op->dtype.lanes();
- if (const auto* imm = sz.as<IntImm>()) {
+ if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
LOG(WARNING) << "The allocation requires : " << imm->value
<< " * " << nbits
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = ir::Simplify(combo_size);
- e->new_alloc = Allocate::make(
+ e->new_alloc = AllocateNode::make(
e->alloc_var, alloc_type, {combo_size}, const_true(),
- Evaluate::make(0));
+ EvaluateNode::make(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
Expr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
(total_bits + type_bits - 1) / type_bits);
- e->new_alloc = Allocate::make(
+ e->new_alloc = AllocateNode::make(
e->alloc_var, e->elem_type, {alloc_size}, const_true(),
- Evaluate::make(0));
+ EvaluateNode::make(0));
if (info.defined()) {
CHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
// Liveness analysis to find gen and kill point of each variable.
void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
// find kill point, do a reverse linear scan.
- std::unordered_set<const Variable*> touched;
+ std::unordered_set<const VarNode*> touched;
for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry& s = seq[i - 1];
- for (const Variable* buffer : s.touched) {
+ for (const VarNode* buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].kill.push_back(buffer);
int64_t offset = seq[i].scope_pair_offset;
if (offset < 0) continue;
const StmtEntry& s = seq[i + offset];
- for (const Variable* buffer : s.touched) {
+ for (const VarNode* buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].gen.push_back(buffer);
// Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq,
- const std::unordered_map<const Variable*, AllocEntry>& alloc_info) {
- std::unordered_set<const Variable*> inplace_flag;
+ const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
+ std::unordered_set<const VarNode*> inplace_flag;
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
// specially handle this
bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
- for (const Variable* var : it->second.gen) {
+ for (const VarNode* var : it->second.gen) {
CHECK(alloc_info.count(var));
const AllocEntry& ae = alloc_info.at(var);
StorageEntry* dst_entry = nullptr;
if (detect_inplace) {
// only one inplace var for s.stmt
bool inplace_found = false;
- for (const Variable* src : it->second.kill) {
+ for (const VarNode* src : it->second.kill) {
if (!inplace_flag.count(src) && alloc_map_.count(src)) {
InplaceOpVerifier visitor;
StorageEntry* src_entry = alloc_map_.at(src);
}
}
// enter/exit new scope
- if (s.stmt->IsInstance<AttrStmt>()) {
- const auto* op = static_cast<const AttrStmt*>(s.stmt);
+ if (s.stmt->IsInstance<AttrStmtNode>()) {
+ const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread ||
attr::IsPragmaKey(op->attr_key)) {
} else {
CHECK(op->attr_key == attr::extern_scope);
}
- } else if (s.stmt->IsInstance<For>()) {
- const auto* op = static_cast<const For*>(s.stmt);
+ } else if (s.stmt->IsInstance<ForNode>()) {
+ const auto* op = static_cast<const ForNode*>(s.stmt);
if (op->for_type == ForType::Parallel) {
if (thread_scope_ == nullptr || thread_scope_ == op) {
PlanNewScope(op);
// - end of scope(offset < 0)
// In both cases, we need to handle the kill event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
- for (const Variable* var : it->second.kill) {
+ for (const VarNode* var : it->second.kill) {
// skip space which are already replaced by inplace
if (!inplace_flag.count(var)) {
this->Free(var);
}
}
// Allocate new storage entry.
- StorageEntry* NewAlloc(const Allocate* op,
+ StorageEntry* NewAlloc(const AllocateNode* op,
const Object* attach_scope,
const StorageScope& scope,
size_t const_nbits) {
return e;
}
- StorageEntry* FindAlloc(const Allocate* op,
+ StorageEntry* FindAlloc(const AllocateNode* op,
const Object* attach_scope,
const StorageScope& scope) {
CHECK(op != nullptr);
return NewAlloc(op, attach_scope, scope, const_nbits);
}
// simulated free.
- void Free(const Variable* var) {
+ void Free(const VarNode* var) {
auto it = alloc_map_.find(var);
CHECK(it != alloc_map_.end());
StorageEntry* e = it->second;
// The allocation attach map
std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map
- std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
+ std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
// The allocations
std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
// analyzer
// if all its access is the same vector type.
class VectorAllocRewriter : public StmtExprMutator {
public:
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
UpdateTypeMap(op->buffer_var.get(), op->dtype);
return StmtExprMutator::VisitExpr_(op);
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
DataType dtype = op->args[0].dtype();
- const Variable* buffer = op->args[1].as<Variable>();
+ const VarNode* buffer = op->args[1].as<VarNode>();
UpdateTypeMap(buffer, dtype);
}
return StmtExprMutator::VisitExpr_(op);
}
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Allocate>();
+ op = stmt.as<AllocateNode>();
const auto& tvec = acc_map_[op->buffer_var.get()];
if (tvec.size() == 1 &&
if (me->base % factor == 0 && me->coeff % factor == 0) {
extents.Set(extents.size() - 1,
extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
- return Allocate::make(
+ return AllocateNode::make(
op->buffer_var, tvec[0], extents,
op->condition, op->body);
}
return stmt;
}
- void UpdateTypeMap(const Variable* buffer, DataType t) {
+ void UpdateTypeMap(const VarNode* buffer, DataType t) {
auto& tvec = acc_map_[buffer];
if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
tvec.push_back(t);
}
// Internal access map
- std::unordered_map<const Variable*, std::vector<DataType> > acc_map_;
+ std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
// internal analyzer
arith::Analyzer analyzer_;
};
std::unordered_set<const Object*> syncs_inserted_;
protected:
- bool Enabled(const Variable* buf,
+ bool Enabled(const VarNode* buf,
const StorageScope& scope) const final {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
std::vector<AccessEntry> Summarize(
- std::vector<StmtEntry> seq, const For* loop) final {
+ std::vector<StmtEntry> seq, const ForNode* loop) final {
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else {
- barrier = Evaluate::make(
- Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImm::make(sync_scope_.to_string())},
- Call::Intrinsic));
+ barrier = EvaluateNode::make(
+ CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImmNode::make(sync_scope_.to_string())},
+ CallNode::Intrinsic));
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
return StmtExprMutator::VisitStmt(stmt);
}
}
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].read_count;
}
return StmtExprMutator::VisitExpr_(op);
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].write_count;
}
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
bool temp = true;
std::swap(temp, in_thread_env_);
std::swap(temp, in_thread_env_);
// first thread scope.
if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
- ret = InitGlobalBarrier(ret.as<AttrStmt>());
+ ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
num_blocks_ = Expr();
is_lead_ = Expr();
}
return ret;
} else if (op->attr_key == attr::storage_scope) {
- const Variable* buf = op->node.as<Variable>();
+ const VarNode* buf = op->node.as<VarNode>();
storage_scope_[buf] =
- StorageScope::make(op->value.as<StringImm>()->value);
+ StorageScope::make(op->value.as<StringImmNode>()->value);
return StmtExprMutator::VisitStmt_(op);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Call>();
+ op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
- const Variable* buffer_var = op->args[1].as<Variable>();
+ const VarNode* buffer_var = op->args[1].as<VarNode>();
Var var(GetRef<Var>(buffer_var));
- const IntImm* flag = op->args[4].as<IntImm>();
+ const IntImmNode* flag = op->args[4].as<IntImmNode>();
if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[var].read_count;
int write_count{0};
};
// Get current storage scope.
- StorageScope GetScope(const Variable* buf) const {
+ StorageScope GetScope(const VarNode* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s;
s.rank = StorageRank::kGlobal;
return it->second;
}
// private functions.
- Stmt InitGlobalBarrier(const AttrStmt* op) {
+ Stmt InitGlobalBarrier(const AttrStmtNode* op) {
CHECK(op != nullptr);
- Array<Expr> pargs = {StringImm::make(runtime::symbol::tvm_prepare_global_barrier)};
- Stmt prep = Evaluate::make(
- Call::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, Call::Intrinsic));
+ Array<Expr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
+ Stmt prep = EvaluateNode::make(
+ CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
Stmt body = op->body;
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
if (e.read_count != 0 && e.write_count != 0) {
- body = AttrStmt::make(kv.first, attr::volatile_scope, 1, body);
+ body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body);
}
}
rw_stats_.clear();
- Stmt kinit = Evaluate::make(
- Call::make(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic));
+ Stmt kinit = EvaluateNode::make(
+ CallNode::make(
+ DataType::Int(32),
+ intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
body = SeqStmt({kinit, body});
- body = AttrStmt::make(
+ body = AttrStmtNode::make(
op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
}
if (!num_blocks_.defined()) {
CHECK(!is_lead_.defined());
num_work_dim_ = thread_extents_.size();
- for (const AttrStmt* attr : thread_extents_) {
+ for (const AttrStmtNode* attr : thread_extents_) {
IterVar iv = Downcast<IterVar>(attr->node);
runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag);
if (s.rank == 0) {
} else {
CHECK_EQ(num_work_dim_, thread_extents_.size());
}
- return Evaluate::make(
- Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImm::make(sync_scope_.to_string()),
+ return EvaluateNode::make(
+ CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImmNode::make(sync_scope_.to_string()),
is_lead_, num_blocks_},
- Call::Intrinsic));
+ CallNode::Intrinsic));
}
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_;
// The storage scope of each buffer
- std::unordered_map<const Variable*, StorageScope> storage_scope_;
+ std::unordered_map<const VarNode*, StorageScope> storage_scope_;
// The read write statistics of storage
std::unordered_map<VarExpr, Entry, ObjectHash, ObjectEqual> rw_stats_;
// The statistics for global barrier
bool in_thread_env_{false};
// memorized results
- std::vector<const AttrStmt*> thread_extents_;
+ std::vector<const AttrStmtNode*> thread_extents_;
size_t num_work_dim_{0};
Expr num_blocks_;
Expr is_lead_;
}
Expr unpack_type_cast(const Expr &input, const DataType &target_type) {
- auto cast = input.as<Cast>();
+ auto cast = input.as<CastNode>();
if (cast == nullptr) {
return input;
} else if (cast->dtype == target_type) {
}
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::pragma_tensor_core) {
tensor_core_on_ = true;
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
- storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+ storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else {
StmtVisitor::VisitStmt_(op);
}
}
- void VisitStmt_(const Provide* op) final {
+ void VisitStmt_(const ProvideNode* op) final {
StmtVisitor::VisitStmt_(op);
auto it = buf_map_.find(TensorKey{op->func, op->value_index});
if (it == buf_map_.end()) {
}
}
- void VisitStmt_(const Realize* op) final {
+ void VisitStmt_(const RealizeNode* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
if (!buf_map_.at(key).external) {
};
// Check whether the storage scope is local
- bool check_local_buffer_(const Call* op, BufferInfo* bi) {
- if (op->call_type == Call::Halide) {
+ bool check_local_buffer_(const CallNode* op, BufferInfo* bi) {
+ if (op->call_type == CallNode::Halide) {
auto it = storage_scope_.find(op->func.get());
if (it == storage_scope_.end()) {
return false;
}
// Do the pattern matching
- bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) {
- auto* add = op->value.as<Add>();
+ bool mma_sync_match_(const ProvideNode* op, BufferInfo store_buffer) {
+ auto* add = op->value.as<AddNode>();
if (add == nullptr) {
return false;
}
- auto* load_c = add->a.as<Call>();
+ auto* load_c = add->a.as<CallNode>();
BufferInfo buffer_c;
if (!check_local_buffer_(load_c, &buffer_c)
|| !buffer_c.same_as(store_buffer)
return false;
}
- auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<Mul>();
+ auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<MulNode>();
if (mul == nullptr) {
return false;
}
auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
- auto load_a = load_a_expr.as<Call>();
+ auto load_a = load_a_expr.as<CallNode>();
BufferInfo buffer_a;
if (!check_local_buffer_(load_a, &buffer_a)
|| !(buffer_a.dtype == DataType::Float(16) ||
}
auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
- auto load_b = load_b_expr.as<Call>();
+ auto load_b = load_b_expr.as<CallNode>();
BufferInfo buffer_b;
if (!check_local_buffer_(load_b, &buffer_b)
|| !(buffer_b.dtype == DataType::Float(16) ||
std::unordered_map<TensorKey, BufferInfo> buf_map_;
std::unordered_map<const Object*, std::string> storage_scope_;
- std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+ std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
std::unordered_map<const Object*, std::string> buf_name_;
std::unordered_set<std::string> frag_reg_;
bool matched_{false};
public:
BodyVisitor() {}
- void VisitExpr_(const Reduce* op) final {
- auto* comm_add = op->combiner->result[0].as<Add>();
+ void VisitExpr_(const ReduceNode* op) final {
+ auto* comm_add = op->combiner->result[0].as<AddNode>();
if (comm_add == nullptr || op->combiner->result.size() > 1) {
return;
}
for (Expr source : op->source) {
- auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<Mul>();
- auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<Mul>();
+ auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<MulNode>();
+ auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<MulNode>();
if (mul_0 == nullptr && mul_1 == nullptr) {
continue;
}
}
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
args_.insert(std::make_pair(op->name, op->args));
}
if (axis.size() < 2 || reduce_axis.size() != 1) {
continue;
}
- const Variable* axis_var[2];
- const Variable* reduce_axis_var;
- axis_var[0] = axis[axis.size()-2]->var.as<Variable>();
- axis_var[1] = axis[axis.size()-1]->var.as<Variable>();
- reduce_axis_var = reduce_axis[0]->var.as<Variable>();
+ const VarNode* axis_var[2];
+ const VarNode* reduce_axis_var;
+ axis_var[0] = axis[axis.size()-2]->var.as<VarNode>();
+ axis_var[1] = axis[axis.size()-1]->var.as<VarNode>();
+ reduce_axis_var = reduce_axis[0]->var.as<VarNode>();
BodyVisitor body_visitor;
for (Expr expr : compute->body) {
if (args.size() < 2) {
continue;
}
- const Variable* var0 = args[args.size() - 2].as<Variable>();
- const Variable* var1 = args[args.size() - 1].as<Variable>();
+ const VarNode* var0 = args[args.size() - 2].as<VarNode>();
+ const VarNode* var1 = args[args.size() - 1].as<VarNode>();
if (var0 == nullptr || var1 == nullptr) {
continue;
}
for (auto &mma_sync : mma_sync_) {
auto &operands = mma_sync.second;
- auto* load_a = operands[0].as<Call>();
- auto* load_b = operands[1].as<Call>();
+ auto* load_a = operands[0].as<CallNode>();
+ auto* load_b = operands[1].as<CallNode>();
auto input0 = simplify_name(buf_name_.find(load_a)->second);
auto input1 = simplify_name(buf_name_.find(load_b)->second);
auto it0 = matrix_abc_.find(input0);
private:
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
- std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+ std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
std::unordered_map<const Object*, std::string> buf_name_;
};
public:
IndexVisitor() {}
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
loop_scaling_.insert(std::make_pair(op, scaling_factor_));
}
friend class TensorCoreIRMutator;
private:
- std::unordered_map<const Variable*, unsigned> loop_scaling_;
+ std::unordered_map<const VarNode*, unsigned> loop_scaling_;
unsigned scaling_factor_{0};
};
}
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
- if (const IntImm* value = op->value.as<IntImm>()) {
+ if (const IntImmNode* value = op->value.as<IntImmNode>()) {
thread_extent_.insert(
std::make_pair(
op->node.as<IterVarNode>()->var->name_hint,
}
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::realize_scope) {
- storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+ storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else if (op->attr_key == attr::buffer_dim_align) {
Tensor tensor = Downcast<Tensor>(op->node);
- const Call* tuple = op->value.as<Call>();
+ const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}];
- size_t dim = tuple->args[0].as<IntImm>()->value;
+ size_t dim = tuple->args[0].as<IntImmNode>()->value;
if (dim >= vinfo.size()) {
vinfo.resize(dim + 1);
}
- vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
- vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+ vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+ vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
this->VisitStmt(op->body);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
- void VisitStmt_(const Provide* op) final {
+ void VisitStmt_(const ProvideNode* op) final {
StmtExprVisitor::VisitStmt_(op);
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
return;
}
for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
- const IntImm* shape = bi.shape[i].as<IntImm>();
+ const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
if (shape == nullptr || shape->value % 16 != 0) {
invalid_ = true;
return;
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
- Expr stride = IntImm::make(DataType::Int(32), 1);
+ Expr stride = IntImmNode::make(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
- stride = Mul::make(stride, bi.shape[j]);
+ stride = MulNode::make(stride, bi.shape[j]);
}
strides.push_back(stride);
}
strides_.insert(std::make_pair(key.GetName(), strides));
if (frag_reg_.count(bi.name)) {
- Expr dst = Call::make(bi.dtype,
+ Expr dst = CallNode::make(bi.dtype,
bi.name,
op->args,
- Call::Halide,
+ CallNode::Halide,
op->func,
0);
frag_load_.insert(std::make_pair(op, dst));
std::vector<int> tile_size;
for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
index_visitor.scaling_factor_ = 16;
- if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
+ if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
tile_size.push_back(shape->value);
index_visitor.scaling_factor_ = shape->value;
} else {
}
}
- const Call* value = op->value.as<Call>();
+ const CallNode* value = op->value.as<CallNode>();
if (value != nullptr && frag_reg_.count(value->name)) {
- Expr dst = Call::make(bi.dtype,
+ Expr dst = CallNode::make(bi.dtype,
bi.name,
op->args,
- Call::Halide,
+ CallNode::Halide,
op->func,
0);
frag_store_.insert(std::make_pair(op, dst));
}
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
- if (op->call_type == Call::Halide) {
+ if (op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
return;
}
for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
- const IntImm* shape = bi.shape[i].as<IntImm>();
+ const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
if (shape == nullptr || shape->value % 16 != 0) {
invalid_ = true;
return;
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
- Expr stride = IntImm::make(DataType::Int(32), 1);
+ Expr stride = IntImmNode::make(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
- stride = Mul::make(stride, bi.shape[j]);
+ stride = MulNode::make(stride, bi.shape[j]);
}
strides.push_back(stride);
}
}
for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
index_visitor.scaling_factor_ = 16;
- if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
+ if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
index_visitor.scaling_factor_ = shape->value;
}
auto index = rel_index[i];
}
}
- void VisitStmt_(const Realize* op) final {
+ void VisitStmt_(const RealizeNode* op) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_set<std::string> frag_reg_;
std::unordered_map<std::string, Array<Expr>> strides_;
- std::unordered_map<const Provide*, Expr> frag_load_;
- std::unordered_map<const Provide*, Expr> frag_store_;
+ std::unordered_map<const ProvideNode*, Expr> frag_load_;
+ std::unordered_map<const ProvideNode*, Expr> frag_store_;
std::unordered_map<std::string, int> thread_extent_;
IndexVisitor index_visitor;
Tile warp_tile_;
public:
explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Variable>();
+ op = expr.as<VarNode>();
if (op != nullptr) {
if (op->name_hint == "threadIdx.x") {
- Expr zero = IntImm::make(DataType::Int(32), 0);
+ Expr zero = IntImmNode::make(DataType::Int(32), 0);
return zero;
}
if (op->name_hint == "threadIdx.y") {
- Expr div = Div::make(expr, warp_y_);
- Expr mul = Mul::make(div, warp_y_);
+ Expr div = DivNode::make(expr, warp_y_);
+ Expr mul = MulNode::make(div, warp_y_);
return mul;
}
}
warp_tile_(buffer_analyser.warp_tile_),
warp_threads_y_(buffer_analyser.warp_threads_y_) {}
- Stmt VisitStmt_(const Realize* op) final {
+ Stmt VisitStmt_(const RealizeNode* op) final {
TensorKey key{op->func, op->value_index};
bounds_[key] = op->bounds;
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Realize>();
+ op = stmt.as<RealizeNode>();
if (op != nullptr) {
if (!frag_reg_.count(key.GetName())) {
return stmt;
new_bounds.push_back(Range::make_by_min_extent(
op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
- return Realize::make(op->func, op->value_index,
+ return RealizeNode::make(op->func, op->value_index,
op->dtype, new_bounds,
op->condition, op->body);
}
return stmt;
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
if (op->attr_key == attr::realize_scope) {
auto node = op->node.as<OperationNode>();
<< "Cannot find matrix info for " << node->name;
auto matrix_abc = "wmma." + it->second;
Stmt body = this->VisitStmt(op->body);
- return AttrStmt::make(op->node,
+ return AttrStmtNode::make(op->node,
op->attr_key,
matrix_abc,
body);
return stmt;
}
- Stmt VisitStmt_(const Provide* op) final {
+ Stmt VisitStmt_(const ProvideNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
auto it = mma_sync_.find(op);
if (it != mma_sync_.end()) {
const auto &operands = it->second;
Expr a = operands[0];
- auto ca = a.as<Call>();
+ auto ca = a.as<CallNode>();
Expr b = operands[1];
- auto cb = b.as<Call>();
+ auto cb = b.as<CallNode>();
Expr c = operands[2];
- auto cc = c.as<Call>();
+ auto cc = c.as<CallNode>();
ObjectPtr<BufferNode> buffer_node_a = make_object<BufferNode>();
ObjectPtr<BufferNode> buffer_node_b = make_object<BufferNode>();
(const Buffer &buffer) {
Buffer buffer_a(buffer_node_a);
Buffer buffer_b(buffer_node_b);
- return Evaluate::make(
- Call::make(DataType::Handle(),
+ return EvaluateNode::make(
+ CallNode::make(DataType::Handle(),
intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
- Call::Intrinsic));
+ CallNode::Intrinsic));
};
auto call_add_c =
auto it2 = frag_load_.find(op);
if (it2 != frag_load_.end()) {
Expr dst = it2->second;
- if (op->value.as<FloatImm>() != nullptr ||
- op->value.as<IntImm>() != nullptr) {
- auto call = dst.as<Call>();
+ if (op->value.as<FloatImmNode>() != nullptr ||
+ op->value.as<IntImmNode>() != nullptr) {
+ auto call = dst.as<CallNode>();
auto fill_fragment_call =
[this, &op](const Buffer &buffer) {
- return Evaluate::make(
- Call::make(DataType::Handle(),
+ return EvaluateNode::make(
+ CallNode::make(DataType::Handle(),
intrinsic::tvm_fill_fragment,
{buffer->data,
warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, op->value},
- Call::Intrinsic));
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
fill_fragment_call, call->dtype);
}
- const Call* value = op->value.as<Call>();
+ const CallNode* value = op->value.as<CallNode>();
CHECK(value != nullptr)
<< "Can only load fragment from a buffer";
Expr stride = strides[strides.size()-2];
// thread index unification inside a warp
- Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_);
+ Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
Expr mutated_value = thread_idx_mutator(op->value);
- Expr src = Call::make(value->dtype,
+ Expr src = CallNode::make(value->dtype,
"&",
{mutated_value},
- Call::Extern);
+ CallNode::Extern);
- auto call = dst.as<Call>();
+ auto call = dst.as<CallNode>();
Expr matrix_major;
auto iter2 = matrix_major_.find(simplify_name(call->name));
CHECK(iter2 != matrix_major_.end())
<< "Can not determine matrix major for " << call->name;
if (iter2->second == "col_major") {
- matrix_major = StringImm::make("col_major");
+ matrix_major = StringImmNode::make("col_major");
} else if (iter2->second == "row_major") {
- matrix_major = StringImm::make("row_major");
+ matrix_major = StringImmNode::make("row_major");
} else {
LOG(FATAL) << "invalid matrix major for " << call->name;
}
auto load_matrix_call =
[this, &src, &stride, &matrix_major](const Buffer &buffer) {
- return Evaluate::make(
- Call::make(DataType::Handle(),
+ return EvaluateNode::make(
+ CallNode::make(DataType::Handle(),
intrinsic::tvm_load_matrix_sync,
{buffer->data,
warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, src, stride, matrix_major},
- Call::Intrinsic));
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
Expr dst = it3->second;
// thread index unification inside a warp
- Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_);
+ Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
- dst = Call::make(DataType::Handle(),
+ dst = CallNode::make(DataType::Handle(),
"&",
{dst},
- Call::Extern);
+ CallNode::Extern);
- auto call = op->value.as<Call>();
+ auto call = op->value.as<CallNode>();
auto store_matrix_call =
[this, &dst, &stride](const Buffer &buffer) {
- return Evaluate::make(
- Call::make(DataType::Handle(),
+ return EvaluateNode::make(
+ CallNode::make(DataType::Handle(),
intrinsic::tvm_store_matrix_sync,
{buffer->data,
warp_tile_.m, warp_tile_.n, warp_tile_.k,
buffer->elem_offset, dst, stride,
- StringImm::make("col_major")},
- Call::Intrinsic));
+ StringImmNode::make("col_major")},
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
return stmt;
}
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<For>();
+ op = stmt.as<ForNode>();
if (op != nullptr) {
auto it = loop_scaling_.find(op->loop_var.get());
if (it != loop_scaling_.end()) {
int scale_factor = it->second;
int scaled_extent_value = 1;
- if (const IntImm *ori_extent = op->extent.as<IntImm>()) {
+ if (const IntImmNode *ori_extent = op->extent.as<IntImmNode>()) {
int ori_extent_value = ori_extent->value;
scaled_extent_value = ori_extent_value / scale_factor;
}
Expr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
- stmt = For::make(op->loop_var, op->min, scaled_extent, op->for_type,
+ stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type,
op->device_api, op->body);
}
}
return tile_size;
}
- Stmt add_buffer_bind_scope_(const Call* call,
+ Stmt add_buffer_bind_scope_(const CallNode* call,
const ObjectPtr<BufferNode> &buffer_node, const TensorKey &key,
const std::function<Stmt(const Buffer &buffer)> &call_back,
DataType datatype) {
Array<Expr> strides;
for (size_t i = 1; i < shape.size(); ++i) {
- Expr stride = IntImm::make(DataType::Int(32), 1);
+ Expr stride = IntImmNode::make(DataType::Int(32), 1);
for (size_t j = shape.size() - 1; j >= i; --j) {
- stride = Mul::make(stride, shape[j]);
+ stride = MulNode::make(stride, shape[j]);
}
strides.push_back(stride);
}
strides.push_back(make_const(DataType::Int(32), 1));
- Expr elem_offset = IntImm::make(DataType::Int(32), 0);
+ Expr elem_offset = IntImmNode::make(DataType::Int(32), 0);
CHECK_EQ(call->args.size(), min_bound.size());
for (size_t i = 0; i < min_bound.size(); i++) {
- elem_offset = Add::make(
- elem_offset, Mul::make(
- strides[i], Sub::make(call->args[i], min_bound[i])));
+ elem_offset = AddNode::make(
+ elem_offset, MulNode::make(
+ strides[i], SubNode::make(call->args[i], min_bound[i])));
}
auto it2 = matrix_abc_.find(simplify_name(call->name));
CHECK(it2 != matrix_abc_.end())
<< "Cannot find matrix info for " << call->name;
- buffer_node->data = Variable::make(DataType::Handle(), call->name);
+ buffer_node->data = VarNode::make(DataType::Handle(), call->name);
buffer_node->name = call->name;
buffer_node->scope = "wmma." + it2->second;
buffer_node->dtype = datatype;
args.push_back(call->args[i]);
args.push_back(shape[i]);
}
- auto tuple = Call::make(DataType::Handle(),
+ auto tuple = CallNode::make(DataType::Handle(),
intrinsic::tvm_tuple,
args,
- Call::Intrinsic);
+ CallNode::Intrinsic);
Array<ObjectRef> node = {buffer, tensor};
- return AttrStmt::make(node,
+ return AttrStmtNode::make(node,
"buffer_bind_scope",
tuple,
call_back(buffer));
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
- std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+ std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
std::unordered_map<std::string, Array<Expr>> strides_;
std::unordered_set<std::string> frag_reg_;
- std::unordered_map<const Variable*, unsigned> loop_scaling_;
- std::unordered_map<const Provide*, Expr> frag_load_;
- std::unordered_map<const Provide*, Expr> frag_store_;
+ std::unordered_map<const VarNode*, unsigned> loop_scaling_;
+ std::unordered_map<const ProvideNode*, Expr> frag_load_;
+ std::unordered_map<const ProvideNode*, Expr> frag_store_;
std::unordered_map<TensorKey, Region> bounds_;
Tile warp_tile_;
int warp_threads_y_{-1};
explicit_unroll_(explicit_unroll) {
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value));
}
}
- Stmt VisitStmt_(const For* op) {
+ Stmt VisitStmt_(const ForNode* op) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<For>();
+ op = stmt.as<ForNode>();
int value = GetExtent(op);
// condition for auto unroll
bool auto_unroll = (
} else {
if (auto_unroll) {
if (op->for_type != ForType::Unrolled) {
- return For::make(
+ return ForNode::make(
op->loop_var, op->min, op->extent,
ForType::Unrolled, op->device_api, op->body);
}
}
}
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
++step_count_;
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const Evaluate* op) final {
+ Stmt VisitStmt_(const EvaluateNode* op) final {
++step_count_;
return StmtExprMutator::VisitStmt_(op);
}
return StmtMutator::VisitSeqStmt_(op, false, fmutate);
}
- Stmt Unroll(const For* op) {
+ Stmt Unroll(const ForNode* op) {
int value = GetExtent(op);
// For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
- if (value == 0) return Evaluate::make(0);
+ if (value == 0) return EvaluateNode::make(0);
Stmt body = op->body;
Map<Var, Expr> vmap;
Array<Stmt> unrolled;
private:
// returns the extent of the loop if it's a constant integer, otherwise return -1
- int GetExtent(const For* op) {
+ int GetExtent(const ForNode* op) {
// constant folding.
Expr extent = ir::Simplify(op->extent);
- const IntImm *v1 = extent.as<IntImm>();
- const UIntImm *v2 = extent.as<UIntImm>();
+ const IntImmNode *v1 = extent.as<IntImmNode>();
+ const UIntImmNode *v2 = extent.as<UIntImmNode>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
Stmt UnrollLoopExplicitly(Stmt stmt) {
- const For* op = stmt.as<For>();
+ const ForNode* op = stmt.as<ForNode>();
if (!op) {
LOG(FATAL) << "attempted to unroll a non-loop statement";
}
inline Expr BroadcastTo(Expr e, int lanes) {
if (e.dtype().lanes() == lanes) return e;
- if (const Broadcast* op = e.as<Broadcast>()) {
+ if (const BroadcastNode* op = e.as<BroadcastNode>()) {
if (lanes % op->lanes == 0) {
- return Broadcast::make(op->value, lanes);
+ return BroadcastNode::make(op->value, lanes);
}
}
CHECK_EQ(e.dtype().lanes(), 1)
<< "Cannot broadcast lane=" << e.dtype().lanes()
<< " to " << lanes;
- return Broadcast::make(e, lanes);
+ return BroadcastNode::make(e, lanes);
}
// Rewrite vectorized allocation access
//
class VecAllocAccess : public StmtExprMutator {
public:
- VecAllocAccess(const Variable* buf, Var var, int var_lanes)
+ VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
// Load
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<Load>();
+ op = expr.as<LoadNode>();
if (op->buffer_var.get() == buf_) {
- return Load::make(op->dtype, op->buffer_var,
+ return LoadNode::make(op->dtype, op->buffer_var,
op->index * var_lanes_ + var_,
op->predicate);
} else {
}
}
// Store
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<Store>();
+ op = stmt.as<StoreNode>();
if (op->buffer_var.get() == buf_) {
- return Store::make(op->buffer_var,
+ return StoreNode::make(op->buffer_var,
op->value,
op->index * var_lanes_ + var_,
op->predicate);
private:
// buffer var
- const Variable* buf_;
+ const VarNode* buf_;
// variable to be replaced
Var var_;
// the lanes.
public:
Vectorizer(Var var, int var_lanes)
: var_(var), var_lanes_(var_lanes) {
- ramp_ = Ramp::make(0, 1, var_lanes);
+ ramp_ = RampNode::make(0, 1, var_lanes);
}
Stmt VisitStmt(const Stmt& stmt) final {
}
}
- Expr VisitExpr_(const Add* op) final {
+ Expr VisitExpr_(const AddNode* op) final {
return AddSubVec(op);
}
- Expr VisitExpr_(const Sub* op) final {
+ Expr VisitExpr_(const SubNode* op) final {
return AddSubVec(op);
}
- Expr VisitExpr_(const Mul* op) final {
+ Expr VisitExpr_(const MulNode* op) final {
Expr a = this->VisitExpr(op->a);
Expr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
- const Ramp* b_ramp = b.as<Ramp>();
- const Ramp* a_ramp = a.as<Ramp>();
+ const RampNode* b_ramp = b.as<RampNode>();
+ const RampNode* a_ramp = a.as<RampNode>();
if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
- return Ramp::make(
+ return RampNode::make(
a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
}
if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
- return Ramp::make(
+ return RampNode::make(
b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
}
}
- return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+ return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec(op);
}
- Expr VisitExpr_(const Div* op) final {
+ Expr VisitExpr_(const DivNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const Mod* op) final {
+ Expr VisitExpr_(const ModNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const FloorDiv* op) final {
+ Expr VisitExpr_(const FloorDivNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const FloorMod* op) final {
+ Expr VisitExpr_(const FloorModNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const Min* op) final {
+ Expr VisitExpr_(const MinNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const Max* op) final {
+ Expr VisitExpr_(const MaxNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const EQ* op) final {
+ Expr VisitExpr_(const EQNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const NE* op) final {
+ Expr VisitExpr_(const NENode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const LT* op) final {
+ Expr VisitExpr_(const LTNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const LE* op) final {
+ Expr VisitExpr_(const LENode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const GT* op) final {
+ Expr VisitExpr_(const GTNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const GE* op) final {
+ Expr VisitExpr_(const GENode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const And* op) final {
+ Expr VisitExpr_(const AndNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const Or* op) final {
+ Expr VisitExpr_(const OrNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const Ramp* op) final {
+ Expr VisitExpr_(const RampNode* op) final {
Expr base = this->VisitExpr(op->base);
Expr stride = this->VisitExpr(op->stride);
if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
- const Ramp* base_ramp = base.as<Ramp>();
+ const RampNode* base_ramp = base.as<RampNode>();
if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
- return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
+ return RampNode::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
}
}
int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
Array<Expr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(
- Ramp::make(Shuffle::make_extract_element(base, i),
- Shuffle::make_extract_element(stride, i),
+ RampNode::make(ShuffleNode::make_extract_element(base, i),
+ ShuffleNode::make_extract_element(stride, i),
op->lanes));
}
- return Shuffle::make_concat(elems);
+ return ShuffleNode::make_concat(elems);
}
- Expr VisitExpr_(const Select *op) final {
+ Expr VisitExpr_(const SelectNode *op) final {
Expr cond = this->VisitExpr(op->condition);
Expr t = this->VisitExpr(op->true_value);
Expr f = this->VisitExpr(op->false_value);
int lanes = std::max(std::max(
cond.dtype().lanes(),
t.dtype().lanes()), f.dtype().lanes());
- return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
+ return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
- Expr VisitExpr_(const Cast *op) final {
+ Expr VisitExpr_(const CastNode *op) final {
Expr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
- return Cast::make(op->dtype.with_lanes(value.dtype().lanes()), value);
+ return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
// Variable
- Expr VisitExpr_(const Variable* v) final {
+ Expr VisitExpr_(const VarNode* v) final {
if (v == var_.get()) {
return ramp_;
} else if (lets_.count(v)) {
}
}
// IfThenElse expr
- Expr MutateIfThenElseExpr_(const Call *op) {
+ Expr MutateIfThenElseExpr_(const CallNode *op) {
Expr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_vector()) {
need_scalarize_ = true;
int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
t = BroadcastTo(t, lanes);
f = BroadcastTo(f, lanes);
- return Call::make(
+ return CallNode::make(
op->dtype.with_lanes(lanes), op->name,
{cond, t, f}, op->call_type, op->func, op->value_index);
}
}
// Call
- Expr VisitExpr_(const Call* op) final {
+ Expr VisitExpr_(const CallNode* op) final {
if (op->name == intrinsic::tvm_if_then_else) {
return MutateIfThenElseExpr_(op);
}
if (op->args.same_as(new_args)) {
return GetRef<Expr>(op);
} else {
- return Call::make(
+ return CallNode::make(
op->dtype, op->name, new_args, op->call_type, op->func, op->value_index);
}
} else {
if (op->args.same_as(new_args)) {
return GetRef<Expr>(op);
} else {
- return Call::make(
+ return CallNode::make(
op->dtype.with_lanes(lane), op->name, new_args,
op->call_type, op->func, op->value_index);
}
}
}
// Load
- Expr VisitExpr_(const Load* op) final {
+ Expr VisitExpr_(const LoadNode* op) final {
Expr index = this->VisitExpr(op->index);
Expr pred = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
return GetRef<Expr>(op);
} else {
int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
- return Load::make(
+ return LoadNode::make(
op->dtype.with_lanes(lanes),
op->buffer_var,
BroadcastTo(index, lanes),
}
}
// Let
- Expr VisitExpr_(const Let* op) final {
+ Expr VisitExpr_(const LetNode* op) final {
Expr value = this->VisitExpr(op->value);
CHECK(!lets_.count(op->var.get())) << "not SSA";
if (value.dtype().lanes() != op->value.dtype().lanes()) {
Var v(op->var->name_hint, value.dtype());
lets_[op->var.get()] = v;
- return Let::make(v, value, this->VisitExpr(op->body));
+ return LetNode::make(v, value, this->VisitExpr(op->body));
} else {
Expr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
- return Let::make(op->var, value, body);
+ return LetNode::make(op->var, value, body);
}
}
}
// Provide
- Stmt VisitStmt_(const Provide* op) final {
+ Stmt VisitStmt_(const ProvideNode* op) final {
Expr new_value = this->VisitExpr(op->value);
int lane = new_value.dtype().lanes();
Array<Expr> new_args = MutateArray(op->args, &lane);
return GetRef<Stmt>(op);
} else {
new_value = BroadcastTo(new_value, lane);
- return Provide::make(op->func, op->value_index, new_value, new_args);
+ return ProvideNode::make(op->func, op->value_index, new_value, new_args);
}
}
// Store
- Stmt VisitStmt_(const Store* op) final {
+ Stmt VisitStmt_(const StoreNode* op) final {
Expr value = this->VisitExpr(op->value);
Expr index = this->VisitExpr(op->index);
Expr pred = this->VisitExpr(op->predicate);
} else {
int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
lanes = std::max(lanes, pred.dtype().lanes());
- return Store::make(op->buffer_var,
+ return StoreNode::make(op->buffer_var,
BroadcastTo(value, lanes),
BroadcastTo(index, lanes),
BroadcastTo(pred, lanes));
}
}
// For
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Vectorized) {
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return For::make(
+ return ForNode::make(
op->loop_var, op->min, extent,
op->for_type, op->device_api, body);
}
}
// IfThenElse
- Stmt VisitStmt_(const IfThenElse* op) final {
+ Stmt VisitStmt_(const IfThenElseNode* op) final {
CHECK(!op->condition.dtype().is_vector());
Expr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_vector()) {
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElse::make(condition, then_case, else_case);
+ return IfThenElseNode::make(condition, then_case, else_case);
}
}
// LetStmt
- Stmt VisitStmt_(const LetStmt* op) final {
+ Stmt VisitStmt_(const LetStmtNode* op) final {
LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
return Scalarize(GetRef<Stmt>(op));
}
// Allocate
- Stmt VisitStmt_(const Allocate* op) final {
+ Stmt VisitStmt_(const AllocateNode* op) final {
if (op->new_expr.defined()) {
LOG(WARNING) << "Cannot vectorize with new expr";
return Scalarize(GetRef<Stmt>(op));
Stmt body = VecAllocAccess(
op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
- return Allocate::make(
+ return AllocateNode::make(
op->buffer_var, op->dtype,
extents, condition, body,
op->new_expr, op->free_function);
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, Expr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
+ return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
private:
// flag to mark requirment of scalarization.
bool need_scalarize_{false};
// The lets
- std::unordered_map<const Variable*, Expr> lets_;
+ std::unordered_map<const VarNode*, Expr> lets_;
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array<Expr> MutateArray(Array<Expr> arr, int* p_lanes) {
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
- const Ramp* b_ramp = b.as<Ramp>();
- const Ramp* a_ramp = a.as<Ramp>();
+ const RampNode* b_ramp = b.as<RampNode>();
+ const RampNode* a_ramp = a.as<RampNode>();
if (a.dtype().lanes() == 1 && b_ramp) {
- return Ramp::make(
+ return RampNode::make(
arith::Compute<T>(a, b_ramp->base),
arith::Compute<T>(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
b_ramp->lanes);
}
if (b.dtype().lanes() == 1 && a_ramp) {
- return Ramp::make(
+ return RampNode::make(
arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
class LoopVectorizer : public StmtMutator {
public:
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min));
int lanes = 0;
class VectorizeSkipper : public StmtMutator {
public:
- Stmt VisitStmt_(const For* op) final {
+ Stmt VisitStmt_(const ForNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
- op = stmt.as<For>();
+ op = stmt.as<ForNode>();
if (op->for_type == ForType::Vectorized) {
- return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
+ return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
op->body);
} else {
return stmt;
return is_compact_;
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
StmtVisitor::VisitStmt_(op);
if (op->attr_key == attr::buffer_bind_scope) {
is_compact_ = true;
return valid_;
}
- void VisitStmt_(const ProducerConsumer* op) final {
+ void VisitStmt_(const ProducerConsumerNode* op) final {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
}
}
- void VisitStmt_(const Allocate* op) final {
+ void VisitStmt_(const AllocateNode* op) final {
StmtVisitor::VisitStmt_(op);
// visit an allocation of a buffer in shared memory, record its size
if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
}
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
- std::string op_value = op->value.as<StringImm>()->value;
+ std::string op_value = op->value.as<StringImmNode>()->value;
if (op_value == "local") {
- visited_local_buffers_.insert(op->node.as<tvm::Variable>());
+ visited_local_buffers_.insert(op->node.as<tvm::VarNode>());
} else if (op_value == "shared") {
- visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
+ visited_shared_buffers_.insert(op->node.as<tvm::VarNode>());
}
} else if (op->attr_key == attr::thread_extent) {
VarExpr var = op->node.as<tvm::IterVarNode>()->var;
- const auto *extent = op->value.as<IntImm>();
+ const auto *extent = op->value.as<IntImmNode>();
CHECK(extent);
// record the number of threads in a block
private:
int nest_level_{0};
- std::unordered_set<const tvm::Variable *> visited_local_buffers_;
- std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
+ std::unordered_set<const tvm::VarNode *> visited_local_buffers_;
+ std::unordered_set<const tvm::VarNode *> visited_shared_buffers_;
std::unordered_set<std::string> visited_threads_;
size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
int64_t max_thread_z = INT64_MAX;
for (auto iter : constraints) {
- const IntImm* val = iter.second.as<IntImm>();
+ const IntImmNode* val = iter.second.as<IntImmNode>();
if (iter.first == "max_local_memory_per_block")
max_local_memory_per_block = val->value;
else if (iter.first == "max_shared_memory_per_block")
StmtExprVisitor::VisitStmt(n);
}
- void VisitStmt_(const LetStmt* op) final {
+ void VisitStmt_(const LetStmtNode* op) final {
// Book keep definitions
defs_[op->var.get()] = op->value;
return StmtExprVisitor::VisitStmt_(op);
}
- void VisitStmt_(const AttrStmt* op) final {
+ void VisitStmt_(const AttrStmtNode* op) final {
if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope)) {
EnterThreadEnv();
}
}
- void VisitStmt_(const ProducerConsumer* op) final {
+ void VisitStmt_(const ProducerConsumerNode* op) final {
EnterProducerConsumer(op);
StmtExprVisitor::VisitStmt_(op);
ExitProducerConsumer();
}
- void VisitExpr_(const Load* op) final {
+ void VisitExpr_(const LoadNode* op) final {
HandleLoadStoreToVariable(op->buffer_var);
return StmtExprVisitor::VisitExpr_(op);
}
- void VisitStmt_(const Store* op) final {
+ void VisitStmt_(const StoreNode* op) final {
HandleLoadStoreToVariable(op->buffer_var);
return StmtExprVisitor::VisitStmt_(op);
}
//@}
/// Check if the value of a Variable comes from function argument.
- bool IsFromFunctionArgs(const Variable *var) const {
- const Variable *V = var;
+ bool IsFromFunctionArgs(const VarNode *var) const {
+ const VarNode *V = var;
while (true) {
CHECK(V) << "Invalid Variable\n";
// Get the first argument of tvm_struct_get, and continue.
const auto &iter = defs_.find(V);
if (iter == defs_.end()) return false;
- const Call *C = iter->second.as<const Call>();
+ const CallNode *C = iter->second.as<const CallNode>();
if (!C || C->name != intrinsic::tvm_struct_get) return false;
- V = C->args[0].as<Variable>();
+ V = C->args[0].as<VarNode>();
}
return false;
}
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
bool InProducerConsumer() const { return pc_ != nullptr; }
- const ProducerConsumer *GetCurrentProducerConsumer() const { return pc_; }
- void EnterProducerConsumer(const ProducerConsumer *pc) { this->pc_ = pc; }
+ const ProducerConsumerNode *GetCurrentProducerConsumer() const { return pc_; }
+ void EnterProducerConsumer(const ProducerConsumerNode *pc) { this->pc_ = pc; }
void ExitProducerConsumer() { pc_ = nullptr; }
void SetFailure() { failure_ = true; }
//@}
/// Status of visitor
//@{
bool in_thread_env_{false};
- const ProducerConsumer *pc_{nullptr};
+ const ProducerConsumerNode *pc_{nullptr};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@}
LoweredFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
- std::unordered_map<const Variable *, Expr> defs_; ///< Variable definitions
+ std::unordered_map<const VarNode *, Expr> defs_; ///< Variable definitions
};
} // namespace
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
for (auto expr : names) {
- auto key = expr.as<ir::StringImm>()->value;
+ auto key = expr.as<ir::StringImmNode>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
}
return ret;
Array<tvm::Expr> ListParamNames() {
Array<tvm::Expr> ret;
for (const auto& kv : params_) {
- ret.push_back(ir::StringImm::make(kv.first));
+ ret.push_back(ir::StringImmNode::make(kv.first));
}
return ret;
}
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
- res.push_back(ir::IntImm::make(DataType::Int(32), *pval));
- } else if (val->IsInstance<ir::Any>()) {
- res.push_back(val.as<ir::Any>()->ToVar());
+ res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval));
+ } else if (val->IsInstance<ir::AnyNode>()) {
+ res.push_back(val.as<ir::AnyNode>()->ToVar());
} else {
res.push_back(val);
}
// set inputs
for (auto param : prim_func->params) {
int state = param_states_[param];
- cache_node->shape_func_param_states.push_back(IntImm::make(DataType::Int(32), state));
+ cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state));
if (state & kNeedInputData) {
for (auto t : param_data_[param]) {
cache_node->inputs.push_back(t);
auto ret_type = call_node->checked_type();
Array<IndexExpr> out_ndims;
if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
- out_ndims.push_back(IntImm::make(DataType::Int(32), ttype->shape.size()));
+ out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
} else {
auto rtype = ret_type.as<TupleTypeNode>();
// TODO(@icemelon): Allow recursive tuple
for (size_t i = 0; i < rtype->fields.size(); ++i) {
auto ttype = rtype->fields[i].as<TensorTypeNode>();
CHECK(ttype);
- out_ndims.push_back(IntImm::make(DataType::Int(32), ttype->shape.size()));
+ out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
}
}
// Call shape function
CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
- const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
+ const tvm::ir::StringImmNode* code_gen = compiler.as<tvm::ir::StringImmNode>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
- const tvm::ir::StringImm* symbol_name = ext_symbol.as<tvm::ir::StringImm>();
+ const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
- FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
+ FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImmNode>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
* \return An external symbol.
*/
std::string GetExtSymbol(const Function& func) const {
- const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
+ const auto name_node =
+ FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImmNode>();
CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
CHECK(ttype) << "Expect TensorTypeNode";
std::vector<int> shape;
for (size_t i = 0; i < ttype->shape.size(); ++i) {
- auto* val = ttype->shape[i].as<IntImm>();
+ auto* val = ttype->shape[i].as<IntImmNode>();
CHECK(val);
shape.push_back(val->value);
}
// Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
args.push_back(std::to_string(wshape[0]));
args.push_back(std::to_string(conv2d_attr->groups));
- args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImm>()->value));
- args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImm>()->value));
+ args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
+ args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
args.push_back(std::to_string(wshape[2]));
args.push_back(std::to_string(wshape[3]));
- args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImm>()->value));
- args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImm>()->value));
+ args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
+ args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
return args;
}
Map<Integer, tvm::Target> tmp = args[1];
TargetsMap targets;
for (const auto& it : tmp) {
- auto dev_type = it.first.as<ir::IntImm>();
+ auto dev_type = it.first.as<ir::IntImmNode>();
CHECK(dev_type);
targets[dev_type->value] = it.second;
}
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Array<tvm::Expr> ret;
for (const auto &kv : this->output_.params) {
- tvm::Expr name = ir::StringImm::make(kv.first);
+ tvm::Expr name = ir::StringImmNode::make(kv.first);
ret.push_back(name);
}
*rv = ret;
bool IsClosure(const Function& func) {
ObjectRef res = FunctionGetAttr(func, attr::kClosure);
- const ir::IntImm* pval = res.as<ir::IntImm>();
+ const ir::IntImmNode* pval = res.as<ir::IntImmNode>();
return pval && pval->value != 0;
}
Array<tvm::Expr> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
- auto* str_name = entry.as<ir::StringImm>();
+ auto* str_name = entry.as<ir::StringImmNode>();
auto funcs = CallTracer(module).Trace(str_name->value);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
}
}
using AttrsEqualHandler::VisitAttr_;
- bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final {
+ bool VisitAttr_(const tvm::VarNode* lhs, const ObjectRef& other) final {
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
}
CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
shape.push_back(
- tvm::ir::IntImm::make(DataType::Int(32), data->shape[i]));
+ tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i]));
}
return TensorTypeNode::make(shape, dtype);
bool FunctionNode::IsPrimitive() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
- const ir::IntImm* pval = res.as<ir::IntImm>();
+ const ir::IntImmNode* pval = res.as<ir::IntImmNode>();
return pval && pval->value != 0;
}
bool FunctionNode::UseDefaultCompiler() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
- const ir::StringImm* pval = res.as<ir::StringImm>();
+ const ir::StringImmNode* pval = res.as<ir::StringImmNode>();
return pval == nullptr || pval->value == "default";
}
}
using AttrsHashHandler::VisitAttr_;
- size_t VisitAttr_(const Variable* var) final {
- size_t hash = std::hash<std::string>()(Variable::_type_key);
+ size_t VisitAttr_(const tvm::VarNode* var) final {
+ size_t hash = std::hash<std::string>()(VarNode::_type_key);
auto it = hash_map_.find(GetRef<VarExpr>(var));
if (it != hash_map_.end()) {
return it->second;
Doc PrintAttr(const ObjectRef& value, bool meta = false) {
if (value.defined()) {
Doc printed_attr;
- if (value.as<tvm::ir::Any>()) {
+ if (value.as<tvm::ir::AnyNode>()) {
printed_attr << "?";
} else if (meta) {
printed_attr = meta_.GetMetaNode(Downcast<ObjectRef>(value));
return doc;
}
- Doc VisitAttr_(const ir::IntImm* op) final {
+ Doc VisitAttr_(const ir::IntImmNode* op) final {
return PrintConstScalar(op->dtype, &(op->value));
}
- Doc VisitAttr_(const ir::UIntImm* op) final {
+ Doc VisitAttr_(const ir::UIntImmNode* op) final {
return PrintConstScalar(op->dtype, &(op->value));
}
- Doc VisitAttr_(const ir::FloatImm* op) final {
+ Doc VisitAttr_(const ir::FloatImmNode* op) final {
return PrintConstScalar(op->dtype, &(op->value));
}
- Doc VisitAttr_(const ir::StringImm* op) final {
+ Doc VisitAttr_(const ir::StringImmNode* op) final {
return PrintString(op->value);
}
// Second argument should be shape tensor.
auto tt = types[1].as<TensorTypeNode>();
CHECK(tt != nullptr) << "must be tensor type";
- auto rank = tt->shape[0].as<tvm::IntImm>();
+ auto rank = tt->shape[0].as<tvm::IntImmNode>();
CHECK(rank != nullptr);
auto dims = rank->value;
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
- if (!dshape_nchw[2].as<ir::Any>()) {
+ if (!dshape_nchw[2].as<ir::AnyNode>()) {
oshape.Set(2, (dshape_nchw[2] + pad_h
- dilated_ksize_y) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
- if (!dshape_nchw[3].as<ir::Any>()) {
+ if (!dshape_nchw[3].as<ir::AnyNode>()) {
oshape.Set(3, (dshape_nchw[3] + pad_w
- dilated_ksize_x) / param->strides[1] + 1);
} else {
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
- if (!dshape_nchw[2].as<ir::Any>()) {
+ if (!dshape_nchw[2].as<ir::AnyNode>()) {
oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
- if (!dshape_nchw[3].as<ir::Any>()) {
+ if (!dshape_nchw[3].as<ir::AnyNode>()) {
oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
param->strides[1]) + 1);
} else {
IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
- if (!dshape_ncdhw[2].as<ir::Any>()) {
+ if (!dshape_ncdhw[2].as<ir::AnyNode>()) {
oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
- if (!dshape_ncdhw[3].as<ir::Any>()) {
+ if (!dshape_ncdhw[3].as<ir::AnyNode>()) {
oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
- if (!dshape_ncdhw[4].as<ir::Any>()) {
+ if (!dshape_ncdhw[4].as<ir::AnyNode>()) {
oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
param->strides[2]) + 1);
} else {
auto target_dim = make_const(DataType::Int(32), 1);
for (uint32_t i = 1; i < data->shape.size(); ++i) {
- if (!data->shape[i].as<ir::Any>()) {
+ if (!data->shape[i].as<ir::AnyNode>()) {
target_dim = target_dim * data->shape[i];
} else {
target_dim = data->shape[i];
// If any pad_width element is not zero, do not change the layout.
for (auto width : axis_pad_width.at(dual_axis_name)) {
- if (auto* width_imm = width.as<IntImm>()) {
+ if (auto* width_imm = width.as<IntImmNode>()) {
if (width_imm->value != 0) {
is_layout_modified = false;
}
<< "Param width elements should be positive but first pad width at "
<< "index " << i << " is " << *width2 << ".";
- if (!data->shape[i].as<ir::Any>()) {
+ if (!data->shape[i].as<ir::AnyNode>()) {
auto padding = make_const(data->shape[i].dtype(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
} else {
oshape.push_back(e);
}
- if (dshape[hidx].as<ir::Any>()) {
+ if (dshape[hidx].as<ir::AnyNode>()) {
oshape[hidx] = dshape[hidx];
} else {
if (param->ceil_mode) {
oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
}
}
- if (dshape[widx].as<ir::Any>()) {
+ if (dshape[widx].as<ir::AnyNode>()) {
oshape[widx] = dshape[widx];
} else {
if (param->ceil_mode) {
std::vector<int> idxes = {didx, hidx, widx};
for (int i = 0; i < 3; i++) {
int ii = idxes[i];
- if (dshape[ii].as<ir::Any>()) {
+ if (dshape[ii].as<ir::AnyNode>()) {
oshape[ii] = dshape[ii];
} else {
if (param->ceil_mode) {
<< " But got " << in_layout;
auto oshape = layout_converter.ForwardShape(data->shape);
- oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
- oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
+ oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
+ oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
// assign output type
reporter->Assign(types[1],
<< " But got " << in_layout;
auto oshape = layout_converter.ForwardShape(data->shape);
- oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
- oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
- oshape.Set(4, ir::Cast::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
+ oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
+ oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
+ oshape.Set(4, ir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
// assign output type
reporter->Assign(types[1],
auto max_shape = make_const(DataType::Int(64), 1);
bool is_dynamic_input = false;
for (int64_t axis : r_axes) {
- if (in_shape[axis].as<IntImm>()) {
+ if (in_shape[axis].as<IntImmNode>()) {
max_shape *= in_shape[axis];
} else {
is_dynamic_input = true;
namespace tvm {
namespace relay {
-using ir::IntImm;
+using ir::IntImmNode;
// relay.cast
TVM_REGISTER_NODE_TYPE(CastAttrs);
CHECK(out_ttype != nullptr);
Array<IndexExpr> newshape;
for (auto val : out_ttype->shape) {
- if (val->IsInstance<ir::Any>()) {
- newshape.push_back(val.as<ir::Any>()->ToVar());
+ if (val->IsInstance<ir::AnyNode>()) {
+ newshape.push_back(val.as<ir::AnyNode>()->ToVar());
} else {
newshape.push_back(val);
}
// Only check When input data has static shape.
bool is_static_shape = true;
for (size_t i = 0; i < data->shape.size(); ++i) {
- if (!data->shape[i].as<IntImm>()) {
+ if (!data->shape[i].as<IntImmNode>()) {
is_static_shape = false;
break;
}
const auto& input_rank = input_shape.size();
std::vector<IndexExpr> result_shape;
result_shape.push_back(Any::make());
- result_shape.push_back(IntImm::make(DataType::Int(32), input_rank));
+ result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank));
reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
return true;
}
<< "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = reps.size();
for (size_t i = 0; i < rndim; ++i) {
- if (const tvm::ir::IntImm* val = reps[i].as<tvm::ir::IntImm>()) {
+ if (const tvm::ir::IntImmNode* val = reps[i].as<tvm::ir::IntImmNode>()) {
CHECK_GT(val->value, 0)
<< "Tile reps value should always be larger than 0, but get: " << val->value;
}
oshape.reserve(tndim);
for (size_t i = 0; i < tndim; ++i) {
// Save Any if it is dynamic shape
- if (!data_shape[i].as<IntImm>()) {
+ if (!data_shape[i].as<IntImmNode>()) {
oshape.emplace_back(Any::make());
} else {
oshape.emplace_back(data_shape[i] * reps_shape[i]);
// if axes is None, squeeze all axes of dimension 1
if (!param->axis.defined()) {
for (const auto& e : data->shape) {
- if (!e.as<IntImm>()) {
+ if (!e.as<IntImmNode>()) {
LOG(FATAL) << "axis needs to be defined for dynamic input.";
}
const int64_t* axis_ptr = as_const_int(e);
// Adapter function to make int array.
Array<Integer> GetIntArray(Array<IndexExpr> arr) {
for (size_t i = 0; i < arr.size(); ++i) {
- CHECK(!arr[i].defined() || arr[i].as<IntImm>())
+ CHECK(!arr[i].defined() || arr[i].as<IntImmNode>())
<< "Expect an int array";
}
return Downcast<Array<Integer> >(arr);
}
int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
int64_t end = params->end[i].defined() ? params->end[i]->value :
- shape[i].as<IntImm>()->value;
+ shape[i].as<IntImmNode>()->value;
if (begin % factor || end % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
CHECK_GE(axis, 0)
<< "axis should be within the input dimension range.";
- if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
+ if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
CHECK(reporter->Assert(indexmod(data->shape[axis],
sections->value) == make_zero(DataType::Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr);
- if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
+ if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
int64_t num_sections = sections->value;
return Array<Tensor>{
topi::split_sections(inputs[0], num_sections, param->axis) };
return false;
}
const size_t ndim = data->shape.size();
- const IntImm* mdim = indices->shape[0].as<IntImm>();
+ const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
const size_t kdim = indices->shape.size() - 1;
CHECK(size_t(mdim->value) <= ndim)
<< "GatherND: indices shape does satisfy.";
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
return Downcast<Function>(CanonicalizeCast(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
return Downcast<Function>(CanonicalizeOps(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
min_num_branches));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
};
return CreateFunctionPass(
pass_func, 3, "ConvertLayout",
- {ir::StringImm::make("InferType"),
- ir::StringImm::make("CanonicalizeOps")});
+ {ir::StringImmNode::make("InferType"),
+ ir::StringImmNode::make("CanonicalizeOps")});
}
TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
};
return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
value = runtime::NDArray::Empty(cshape, cdtype, ctx);
int32_t* dims = static_cast<int32_t*>(value->data);
- using ::tvm::ir::IntImm;
+ using ::tvm::ir::IntImmNode;
for (size_t i = 0; i < ishape.size(); ++i) {
- if (const IntImm* dim = ishape[i].as<IntImm>()) {
+ if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
dims[i] = dim->value;
} else {
return expr;
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
return Downcast<Function>(FuseOps(f, opt_level, m));
};
return CreateFunctionPass(pass_func, 1, "FuseOps",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// 4) a) Check if this shape element is 1.
bool is_shape_one = false;
- if (auto* shape_int = shape_val.as<IntImm>()) {
+ if (auto* shape_int = shape_val.as<IntImmNode>()) {
if (shape_int->value == 1) {
new_layout += "1";
is_shape_one = true;
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
- return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImm::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
int64_t ret = 1;
for (size_t i = 0; i < arr.size(); i++) {
- const auto* intImm = arr[i].as<IntImm>();
+ const auto* intImm = arr[i].as<IntImmNode>();
ret *= static_cast<int64_t>(intImm->value);
}
return ret;
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK_NE(C_ind, -1)
<< "There is no input channel dimension.";
- int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
+ int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
if (c_ind != -1)
- input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
+ input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
CHECK_EQ(kernel_size.size(), 2)
<< "The dimension of the kernel in Conv 2D should be 2.";
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK_NE(C_ind, -1)
<< "There is no input channel dimension.";
- int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
+ int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
if (c_ind != -1)
- input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
+ input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
CHECK_EQ(kernel_size.size(), 2)
<< "The dimension of the kernel in Conv 2D Transpose should be 2.";
Array<IndexExpr> weight_shape = weight_type->shape;
CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
<< "The dimension of an input tensor to Dense node should be 2.";
- int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImm>()->value);
- int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
- int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
- int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
+ int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
+ int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
+ int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
+ int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
CHECK_EQ(d2, d4)
<< "The dimensions of input arguments do not match.";
int64_t count = d1 * d2 * d3;
CHECK_EQ(args.size(), 2);
Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
- int64_t batch = x_shape[0].as<IntImm>()->value;
- int64_t m = x_shape[1].as<IntImm>()->value;
- int64_t k = x_shape[2].as<IntImm>()->value;
- int64_t n = y_shape[1].as<IntImm>()->value;
+ int64_t batch = x_shape[0].as<IntImmNode>()->value;
+ int64_t m = x_shape[1].as<IntImmNode>()->value;
+ int64_t k = x_shape[2].as<IntImmNode>()->value;
+ int64_t n = y_shape[1].as<IntImmNode>()->value;
return batch * m * k * n;
}
bool FunctionPassNode::SkipFunction(const Function& func) const {
ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
- const ir::IntImm* pval = skip_opt.as<ir::IntImm>();
+ const ir::IntImmNode* pval = skip_opt.as<ir::IntImmNode>();
return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
}
inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
const std::string& pass_name) {
for (auto x : pass_array) {
- auto* str_name = x.as<ir::StringImm>();
+ auto* str_name = x.as<ir::StringImmNode>();
CHECK(str_name) << "pass name must be str";
if (str_name->value == pass_name) return true;
}
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
- const auto* name = it.as<tvm::ir::StringImm>();
+ const auto* name = it.as<tvm::ir::StringImmNode>();
CHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
p->stream << "opt_level: " << node->opt_level;
p->stream << "required passes: [" << "\n";
for (const auto& it : node->required) {
- const auto* str = it.as<tvm::ir::StringImm>();
+ const auto* str = it.as<tvm::ir::StringImmNode>();
p->stream << str->value << ", ";
}
p->stream << "]\n";
inline bool IsScalar(const Expr& expr) {
if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
for (auto dim_index_expr : tensor_type->shape) {
- if (auto dim_index = dim_index_expr.as<IntImm>()) {
+ if (auto dim_index = dim_index_expr.as<IntImmNode>()) {
if (dim_index->value != 1) {
return false;
}
return Downcast<Function>(SimplifyInference(f));
};
return CreateFunctionPass(pass_func, 0, "SimplifyInference",
- {ir::StringImm::make("InferType")});
+ {ir::StringImmNode::make("InferType")});
}
TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
return Any::make();
}
- auto left_index0 = ulhs.as<tvm::Variable>();
- auto right_index0 = urhs.as<tvm::IntImm>();
+ auto left_index0 = ulhs.as<tvm::VarNode>();
+ auto right_index0 = urhs.as<tvm::IntImmNode>();
if (left_index0 && right_index0) {
solver_->shape_uf_.Set(ulhs, urhs);
return urhs;
}
- auto left_index1 = ulhs.as<tvm::IntImm>();
- auto right_index1 = urhs.as<tvm::Variable>();
+ auto left_index1 = ulhs.as<tvm::IntImmNode>();
+ auto right_index1 = urhs.as<tvm::VarNode>();
if (left_index1 && right_index1) {
solver_->shape_uf_.Set(urhs, ulhs);
return ulhs;
}
- auto left_index2 = ulhs.as<tvm::IntImm>();
- auto right_index2 = urhs.as<tvm::IntImm>();
+ auto left_index2 = ulhs.as<tvm::IntImmNode>();
+ auto right_index2 = urhs.as<tvm::IntImmNode>();
if (left_index2 && right_index2 && left_index2->value == right_index2->value) {
return ulhs;
}
ExprVisitor::VisitExpr(e);
}
- void VisitExpr_(const Call* op) final {
+ void VisitExpr_(const CallNode* op) final {
Array<Expr> axis = op->args;
if (axis_.size() != axis.size()) {
is_elem_wise_ = false;
Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
// The parent set.
for (const Operation& op : consumers) {
- std::unordered_map<const Variable*, IntSet> relax_set;
+ std::unordered_map<const VarNode*, IntSet> relax_set;
std::unordered_map<IterVar, IntSet> up_state;
bool found_attach = false;
CHECK(ctx.op2stage_.count(op.get()));
// Get the domain of the consumer
PassUpDomain(op_stage, *rmap, &up_state);
// Relax if needed.
- std::unordered_map<const Variable*, IntSet> dom_map;
+ std::unordered_map<const VarNode*, IntSet> dom_map;
arith::Analyzer analyzer;
for (auto iv : op->root_iter_vars()) {
Range r;
int value_index;
int dim;
TensorDimKey() {}
- TensorDimKey(const ir::Call* op, int dim)
+ TensorDimKey(const ir::CallNode* op, int dim)
: f(op->func), value_index(op->value_index), dim(dim) {
}
TensorDimKey(const Tensor& t, int dim)
reach[TensorDimKey(t, i)] = {};
}
auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
- const ir::Call *call = n.as<ir::Call>();
+ const ir::CallNode *call = n.as<ir::CallNode>();
if (call != nullptr && call->func.defined()) {
if (!bset.count(call->func.get())) return;
for (size_t i = 0; i < call->args.size(); ++i) {
TensorDimKey dkey(call, static_cast<int>(i));
auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
- const Variable *v = node.as<Variable>();
+ const VarNode *v = node.as<VarNode>();
auto it = vmap.find(v);
if (it != vmap.end()) {
reach[it->second].push_back(dkey);
}
auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
const ObjectRef& n) {
- const ir::Call *call = n.as<ir::Call>();
+ const ir::CallNode *call = n.as<ir::CallNode>();
if (call != nullptr && call->func.defined()) {
for (size_t i = 0; i < call->args.size(); ++i) {
auto it = vmap.find(call->args[i].get());
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<Expr> preds;
- std::unordered_map<const Variable*, IntSet> iset_dmap;
+ std::unordered_map<const VarNode*, IntSet> iset_dmap;
// setup domain map for set analysis
for (const auto& kv : dom_map) {
class VarReplacer : public ir::StmtExprMutator {
public:
explicit VarReplacer(
- const std::unordered_map<const Variable*, Expr>& vsub)
+ const std::unordered_map<const VarNode*, Expr>& vsub)
: vsub_(vsub) {}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = vsub_.find(op);
if (it != vsub_.end()) return it->second;
return GetRef<Expr>(op);
}
}
- Expr VisitExpr_(const ir::Reduce* op) final {
+ Expr VisitExpr_(const ir::ReduceNode* op) final {
Expr new_e = StmtExprMutator::VisitExpr_(op);
- const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
+ const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
if (op->combiner.same_as(new_combiner)) {
return new_e;
} else {
- return ir::Reduce::make(
+ return ir::ReduceNode::make(
new_combiner,
new_reduce->source,
new_reduce->axis,
}
private:
- const std::unordered_map<const Variable*, Expr>& vsub_;
+ const std::unordered_map<const VarNode*, Expr>& vsub_;
};
Expr InjectPredicate(const Array<Expr>& predicates,
Expr body) {
- using ir::Reduce;
- using ir::Select;
+ using ir::ReduceNode;
+ using ir::SelectNode;
if (predicates.size() == 0) return body;
- const Reduce* reduce = body.as<Reduce>();
+ const ReduceNode* reduce = body.as<ReduceNode>();
if (reduce) {
- auto n = make_object<Reduce>(*reduce);
- n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
+ auto n = make_object<ReduceNode>(*reduce);
+ n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, Expr());
return Expr(n);
}
- return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
+ return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, Expr()),
body,
make_zero(body.dtype()));
}
}
}
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
return (a->combiner.same_as(b->combiner)) &&
(a->source.same_as(b->source)) &&
(a->axis.same_as(b->axis)) &&
std::unordered_set<IterVar>* p_red_axis,
Array<IterVar>* p_new_axis,
std::unordered_map<IterVar, Range>* p_dom_map,
- std::unordered_map<const Variable*, Expr>* p_vsub,
- std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
+ std::unordered_map<const VarNode*, Expr>* p_vsub,
+ std::unordered_map<const VarNode*, Expr>* p_vsub2newvar,
std::vector<Expr>* p_predicates) {
auto& red_axis = *p_red_axis;
auto& new_axis = *p_new_axis;
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
- std::unordered_map<const Variable*, Expr> vsub;
- std::unordered_map<const Variable*, Expr> vsub2newvar;
+ std::unordered_map<const VarNode*, Expr> vsub;
+ std::unordered_map<const VarNode*, Expr> vsub2newvar;
std::vector<Expr> predicates;
PrepareAxisMapping(orig_stage, compute,
Expr body;
Array<Expr> body_list;
- const ir::Reduce* first_reduce = nullptr;
+ const ir::ReduceNode* first_reduce = nullptr;
for (auto cbody : compute->body) {
body = VarReplacer(vsub)(cbody);
body = InjectPredicate(predicates, body);
body = VarReplacer(vsub2newvar)(body);
// Reduce nodes in ONE computeOp must be the same except value_index
// This is right only if the original body ensures Reduce nodes are the same
- if (body->IsInstance<ir::Reduce>()) {
- const ir::Reduce* reduce_body = body.as<ir::Reduce>();
+ if (body->IsInstance<ir::ReduceNode>()) {
+ const ir::ReduceNode* reduce_body = body.as<ir::ReduceNode>();
if (first_reduce != nullptr) {
CHECK(ReduceEqual(reduce_body, first_reduce));
- body = ir::Reduce::make(first_reduce->combiner,
+ body = ir::ReduceNode::make(first_reduce->combiner,
first_reduce->source,
first_reduce->axis,
first_reduce->condition,
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
- std::unordered_map<const Variable*, Expr> vsub;
- std::unordered_map<const Variable*, Expr> vsub2newvar;
+ std::unordered_map<const VarNode*, Expr> vsub;
+ std::unordered_map<const VarNode*, Expr> vsub2newvar;
std::vector<Expr> predicates;
PrepareAxisMapping(orig_stage, tensor_op,
if (!new_body[j].size()) {
new_body[j] = compute->body;
}
- if (new_body[j][0]->IsInstance<ir::Reduce>()) {
+ if (new_body[j][0]->IsInstance<ir::ReduceNode>()) {
// specially handle reduction inline for multiplre reductions.
- const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
+ const ir::ReduceNode* reduce = new_body[j][0].as<ir::ReduceNode>();
for (size_t k = 1; k < new_body[j].size(); ++k) {
- const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
+ const ir::ReduceNode* reduce_ = new_body[j][k].as<ir::ReduceNode>();
CHECK(reduce_);
CHECK(ReduceEqual(reduce_, reduce))
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
- Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
- stage->op, args, body).as<ir::Evaluate>()->value;
+ Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
+ stage->op, args, body).as<ir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
- const ir::Reduce* r = new_value.as<ir::Reduce>();
+ const ir::ReduceNode* r = new_value.as<ir::ReduceNode>();
CHECK_EQ(new_body[j].size(), r->source.size());
CHECK(r != nullptr);
for (size_t k = 0; k < new_body[j].size(); ++k) {
- auto n = make_object<ir::Reduce>(*r);
+ auto n = make_object<ir::ReduceNode>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
new_body[j].Set(k, Expr(n));
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
- Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
- stage->op, args, body).as<ir::Evaluate>()->value;
+ Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
+ stage->op, args, body).as<ir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
changed[j] = true;
const IterVar& axis,
int factor_axis) {
(*this)->InvalidateCache();
- using ir::Reduce;
+ using ir::ReduceNode;
CHECK_EQ(axis->iter_type, kCommReduce)
<< "Can only factor reduction axis";
Stage reduce_stage = operator[](tensor->op);
}
// predicate generation, copy not touched axis.
int idx = tensor->value_index;
- const Reduce* reduce = compute_op->body[idx].as<Reduce>();
+ const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition);
- Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));
+ Expr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, Expr()));
- std::unordered_map<const Variable*, Expr> vsub;
+ std::unordered_map<const VarNode*, Expr> vsub;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
std::vector<Expr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
- body.emplace_back(Reduce::make(reduce->combiner,
+ body.emplace_back(ReduceNode::make(reduce->combiner,
new_source,
n->reduce_axis,
new_pred,
Array<IterVar> axis = {repl_red_axis};
Expr cond = const_true();
for (int idx = 0; idx < size; ++idx) {
- reductions.push_back(Reduce::make(reduce->combiner,
+ reductions.push_back(ReduceNode::make(reduce->combiner,
factor_exprs, axis, cond, idx));
}
return reductions;
} else {
UpdateIterVarAttr(
operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
- n->pragma_keys.push_back(ir::StringImm::make(pragma_type));
+ n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type));
n->pragma_values.push_back(pragma_value);
});
}
bool debug_keep_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
if (producer.defined()) {
- producer = ProducerConsumer::make(s->op, true, producer);
+ producer = ProducerConsumerNode::make(s->op, true, producer);
}
if (s->double_buffer) {
- producer = AttrStmt::make(
+ producer = AttrStmtNode::make(
s->op, ir::attr::double_buffer_scope, 1, producer);
}
Stmt pipeline = producer;
if (consumer.defined() && !is_no_op(consumer)) {
- consumer = ProducerConsumer::make(s->op, false, consumer);
+ consumer = ProducerConsumerNode::make(s->op, false, consumer);
pipeline = SeqStmt({producer, consumer});
}
pipeline = s->op->BuildRealize(s, dom_map, pipeline);
// use attribute to mark scope of the operation.
- pipeline = AttrStmt::make(
+ pipeline = AttrStmtNode::make(
s->op, ir::attr::realize_scope,
- StringImm::make(s->scope),
+ StringImmNode::make(s->scope),
pipeline);
if (s->is_opengl) {
- pipeline = AttrStmt::make(
- s->op, ir::attr::opengl_stage_scope, StringImm::make(""), pipeline);
+ pipeline = AttrStmtNode::make(
+ s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
}
return pipeline;
}
Stmt VisitStmt(const Stmt& input_stmt) final {
CHECK(input_stmt.defined());
auto stmt = StmtMutator::VisitStmt(input_stmt);
- const AttrStmt* op = stmt.as<AttrStmt>();
+ const AttrStmtNode* op = stmt.as<AttrStmtNode>();
if (op != nullptr &&
op->attr_key == attr::loop_scope) {
if (attach_spec_->attach_type == kScope &&
<< "Find IterVar" << attach_spec_->attach_ivar
<< " in multiple places in the IR";
found_attach = true;
- stmt = AttrStmt::make(
+ stmt = AttrStmtNode::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
CHECK(input_stmt.defined());
auto stmt = StmtMutator::VisitStmt(input_stmt);
// update
- const AttrStmt* op = stmt.as<AttrStmt>();
+ const AttrStmtNode* op = stmt.as<AttrStmtNode>();
if (op != nullptr &&
((op->attr_key == attr::scan_update_scope && !is_init_) ||
(op->attr_key == attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
- stmt = AttrStmt::make(
+ stmt = AttrStmtNode::make(
op->node, op->attr_key, op->value,
MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
// Replace the init and update's expression by scan's buffer.
class SchedulePostProc : public StmtExprMutator {
public:
- Stmt VisitStmt_(const ProducerConsumer* op) final {
+ Stmt VisitStmt_(const ProducerConsumerNode* op) final {
auto it = replace_op_.find(op->func.get());
if (it != replace_op_.end()) {
Stmt body = this->VisitStmt(op->body);
if (it->second.defined()) {
- return ProducerConsumer::make(
+ return ProducerConsumerNode::make(
it->second, op->is_producer, body);
} else {
return body;
return StmtExprMutator::VisitStmt_(op);
}
}
- Stmt VisitStmt_(const LetStmt* op) final {
+ Stmt VisitStmt_(const LetStmtNode* op) final {
if (!HasSideEffect(op->value)) {
var_value_[op->var.get()] = this->VisitExpr(op->value);
return this->VisitStmt(op->body);
}
}
- Stmt VisitStmt_(const AttrStmt* op) final {
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::loop_scope ||
op->attr_key == attr::scan_init_scope) {
return this->VisitStmt(op->body);
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- Stmt ret = AttrStmt::make(
+ Stmt ret = AttrStmtNode::make(
it->second, op->attr_key, op->value, op->body);
return this->VisitStmt(ret);
} else {
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- return AttrStmt::make(
+ return AttrStmtNode::make(
Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
op->attr_key, op->value, this->VisitStmt(op->body));
} else {
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- return AttrStmt::make(
+ return AttrStmtNode::make(
it->second.output(tensor->value_index),
op->attr_key, op->value, this->VisitStmt(op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const Realize* op) final {
+ Stmt VisitStmt_(const RealizeNode* op) final {
TensorKey key{op->func, op->value_index};
auto it = replace_realize_.find(key);
if (it != replace_realize_.end()) {
if (it->second.defined()) {
- Stmt ret = Realize::make(
+ Stmt ret = RealizeNode::make(
it->second->op, it->second->value_index,
op->dtype, op->bounds, op->condition, op->body);
return this->VisitStmt(ret);
}
}
- Stmt VisitStmt_(const Provide* op) final {
+ Stmt VisitStmt_(const ProvideNode* op) final {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
- Stmt ret = Provide::make(
+ Stmt ret = ProvideNode::make(
dst->op, dst->value_index, op->value, op->args);
return this->VisitStmt(ret);
} else {
}
}
- Expr VisitExpr_(const Call* op) final {
- if (op->call_type == Call::Halide) {
+ Expr VisitExpr_(const CallNode* op) final {
+ if (op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
- Expr ret = Call::make(
+ Expr ret = CallNode::make(
op->dtype, dst->op->name, op->args,
op->call_type, dst->op, dst->value_index);
return this->VisitExpr(ret);
return StmtExprMutator::VisitExpr_(op);
}
- Expr VisitExpr_(const Variable* op) final {
+ Expr VisitExpr_(const VarNode* op) final {
auto it = var_value_.find(op);
if (it != var_value_.end()) {
return it->second;
// The thread extent scope.
std::unordered_map<const Object*, Expr> thread_extent_scope_;
// The scan value
- std::unordered_map<const Variable*, Expr> var_value_;
+ std::unordered_map<const VarNode*, Expr> var_value_;
// buffer replacement
std::unordered_map<TensorKey, Tensor> replace_buffer_;
// buffere realization to be replaced
n->InitBySeq("name", "xxx", "expr", 128);
CHECK_EQ(n->name, "xxx");
CHECK_EQ(n->axis, 10);
- CHECK_EQ(n->expr.as<tvm::ir::IntImm>()->value, 128);
+ CHECK_EQ(n->expr.as<tvm::ir::IntImmNode>()->value, 128);
// Check docstring
std::ostringstream os;
n->PrintDocString(os);
using namespace tvm;
Array<Expr> array{1, 2, 3};
std::vector<Expr> vector(array.begin(), array.end());
- CHECK(vector[1].as<IntImm>()->value == 2);
+ CHECK(vector[1].as<IntImmNode>()->value == 2);
}
TEST(Map, Expr) {
Map<Expr, Expr> map1{{a, b}};
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>
map2(map1.begin(), map1.end());
- CHECK(map2[a].as<IntImm>()->value == 2);
+ CHECK(map2[a].as<IntImmNode>()->value == 2);
}
int main(int argc, char** argv) {
using namespace tvm;
Var x("x");
Expr z = max(x + 1 + 2, 100);
- const ir::Max* op = z.as<ir::Max>();
+ const ir::MaxNode* op = z.as<ir::MaxNode>();
CHECK(GetRef<ObjectRef>(op).same_as(z));
}
auto z = x + 1;
NodeFunctor<int(const ObjectRef& n, int b)> f;
- f.set_dispatch<Variable>([](const ObjectRef& n, int b) {
+ f.set_dispatch<VarNode>([](const ObjectRef& n, int b) {
return b;
});
- f.set_dispatch<Add>([](const ObjectRef& n, int b) {
+ f.set_dispatch<AddNode>([](const ObjectRef& n, int b) {
return b + 2;
});
CHECK_EQ(f(x, 2), 2);
auto z = x + 1 + y + y;
ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
- if (n.as<Variable>()) ++n_var;
+ if (n.as<VarNode>()) ++n_var;
});
CHECK_EQ(n_var, 2);
}
class MyExprFunctor
: public ir::ExprFunctor<int(const Expr&, int)> {
public:
- int VisitExpr_(const Variable* op, int b) final {
+ int VisitExpr_(const VarNode* op, int b) final {
return b;
}
- int VisitExpr_(const IntImm* op, int b) final {
+ int VisitExpr_(const IntImmNode* op, int b) final {
return op->value;
}
- int VisitExpr_(const Add* op, int b) final {
+ int VisitExpr_(const AddNode* op, int b) final {
return VisitExpr(op->a, b) + VisitExpr(op->b, b);
}
};
public:
int count = 0;
// implementation
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
++count;
}
- void VisitExpr_(const IntImm* op) final {
+ void VisitExpr_(const IntImmNode* op) final {
}
- void VisitExpr_(const Add* op) final {
+ void VisitExpr_(const AddNode* op) final {
VisitExpr(op->a);
VisitExpr(op->b);
}
- void VisitStmt_(const Evaluate* op) final {
+ void VisitStmt_(const EvaluateNode* op) final {
VisitExpr(op->value);
}
};
MyVisitor v;
- v.VisitStmt(Evaluate::make(z));
+ v.VisitStmt(EvaluateNode::make(z));
CHECK_EQ(v.count, 1);
}
public:
int count = 0;
// implementation
- void VisitExpr_(const Variable* op) final {
+ void VisitExpr_(const VarNode* op) final {
++count;
}
};
MyVisitor v;
auto fmaketest = [&]() {
auto z = x + 1;
- Stmt body = Evaluate::make(z);
+ Stmt body = EvaluateNode::make(z);
Var buffer("b", DataType::Handle());
- return Allocate::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
+ return AllocateNode::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
};
v(fmaketest());
CHECK_EQ(v.count, 3);
protected:
// implementation
- Expr VisitExpr_(const Add* op) final {
+ Expr VisitExpr_(const AddNode* op) final {
return op->a;
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
};
auto fmakealloc = [&]() {
auto z = x + 1;
- Stmt body = Evaluate::make(z);
+ Stmt body = EvaluateNode::make(z);
Var buffer("b", DataType::Handle());
- return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
+ return AllocateNode::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
};
auto fmakeif = [&]() {
auto z = x + 1;
- Stmt body = Evaluate::make(z);
- return IfThenElse::make(x < 0, Evaluate::make(0), body);
+ Stmt body = EvaluateNode::make(z);
+ return IfThenElseNode::make(x < 0, EvaluateNode::make(0), body);
};
MyVisitor v;
{
auto body = fmakealloc();
- Stmt body2 = Evaluate::make(1);
- Stmt bref = body.as<Allocate>()->body;
- auto* extentptr = body.as<Allocate>()->extents.get();
+ Stmt body2 = EvaluateNode::make(1);
+ Stmt bref = body.as<AllocateNode>()->body;
+ auto* extentptr = body.as<AllocateNode>()->extents.get();
Array<Stmt> arr{std::move(body), body2, body2};
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() == arrptr);
// inplace update body
- CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
- CHECK(arr[0].as<Allocate>()->extents.get() == extentptr);
+ CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
+ CHECK(arr[0].as<AllocateNode>()->extents.get() == extentptr);
// copy because there is additional refs
- CHECK(!arr[0].as<Allocate>()->body.same_as(bref));
- CHECK(arr[0].as<Allocate>()->body.as<Evaluate>()->value.same_as(x));
- CHECK(bref.as<Evaluate>()->value.as<Add>());
+ CHECK(!arr[0].as<AllocateNode>()->body.same_as(bref));
+ CHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x));
+ CHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
}
{
Array<Stmt> arr{fmakealloc()};
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
CHECK(arr.get() != arrptr);
- CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
- CHECK(!arr2[0].as<Allocate>()->extents[1].same_as(x));
+ CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
+ CHECK(!arr2[0].as<AllocateNode>()->extents[1].same_as(x));
// mutate but no content change.
arr2 = arr;
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
{
Array<Stmt> arr{fmakeif()};
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
- CHECK(arr[0].as<IfThenElse>()->else_case.as<Evaluate>()->value.same_as(x));
+ CHECK(arr[0].as<IfThenElseNode>()->else_case.as<EvaluateNode>()->value.same_as(x));
// mutate but no content change.
auto arr2 = arr;
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
}
{
- auto body = Evaluate::make(Call::make(DataType::Int(32), "xyz", {x + 1}, Call::Extern));
+ auto body = EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
auto res = v(std::move(body));
- CHECK(res.as<Evaluate>()->value.as<Call>()->args[0].same_as(x));
+ CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[0].same_as(x));
}
{
auto body = fmakealloc();
- Stmt body2 = Evaluate::make(1);
+ Stmt body2 = EvaluateNode::make(1);
auto* ref2 = body2.get();
- auto* extentptr = body.as<Allocate>()->extents.get();
+ auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
body = SeqStmt({body, body2});
body = v(std::move(body));
// the seq get flattened
CHECK(body.as<SeqStmtNode>()->size() == 3);
- CHECK(body.as<SeqStmtNode>()->seq[0].as<Allocate>()->extents.get() == extentptr);
+ CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() == extentptr);
CHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
}
{
// Cannot cow because of bref
auto body = fmakealloc();
- Stmt body2 = Evaluate::make(1);
- auto* extentptr = body.as<Allocate>()->extents.get();
+ Stmt body2 = EvaluateNode::make(1);
+ auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
auto bref = body;
body = SeqStmt({body, body2});
body = v(std::move(body));
// the seq get flattened
- CHECK(body.as<SeqStmtNode>()->seq[0].as<Allocate>()->extents.get() != extentptr);
+ CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
}
}
// Mod::make is used instead of % to avoid constant folding during
// calling operator%(x,y). Mod::make doesn't try constant folding,
// and therefore, the constant folding will be attempted in CanonicalSimplify
- auto mod = tvm::ir::CanonicalSimplify(tvm::ir::Mod::make(x, y));
+ auto mod = tvm::ir::CanonicalSimplify(tvm::ir::ModNode::make(x, y));
auto es = tvm::ir::CanonicalSimplify(mod - x);
CHECK(is_zero(es));
}
using namespace tvm;
using namespace tvm::ir;
Var x("x"), y;
- Expr let = Let::make(x, 1, x + 1);
+ Expr let = LetNode::make(x, 1, x + 1);
- auto z = Evaluate::make(let + let);
+ auto z = EvaluateNode::make(let + let);
CHECK(!ir::VerifySSA(z));
auto z_ssa = ir::ConvertSSA(z);
CHECK(ir::VerifySSA(z_ssa));
using namespace tvm::ir;
using namespace tvm;
Var x("x"), y;
- auto z = Evaluate::make(x + y);
+ auto z = EvaluateNode::make(x + y);
CHECK(ir::VerifySSA(z));
}
// automatic conversion of int to expr
PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
Expr x = args[0];
- *rv = x.as<tvm::ir::IntImm>()->value + 1;
+ *rv = x.as<tvm::ir::IntImmNode>()->value + 1;
});
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
{
CHECK(select(px >= pz, py, py + pz).Match(
- ir::Select::make((x + 1) >= 1, y, y + 1)));
+ ir::SelectNode::make((x + 1) >= 1, y, y + 1)));
CHECK(ir::Equal(px.Eval(), x + 1));
}
// bit intrinsics
// select
{
CHECK(select(px > pz, py, py + pz).Match(
- ir::Select::make(x > 1, y, y + 1)));
+ ir::SelectNode::make(x > 1, y, y + 1)));
CHECK(is_const_int(pz.Eval(), 1));
}
CHECK(!select(px > pz, py, py + pz).Match(
- ir::Select::make(x > 2, y, y + 1)));
+ ir::SelectNode::make(x > 2, y, y + 1)));
CHECK(!select(px > pz, py, py).Match(
- ir::Select::make(x > 2, y, y + 1)));
+ ir::SelectNode::make(x > 2, y, y + 1)));
{
CHECK(select(px, py, pz).Match(
- ir::Select::make(x > 2, y, y + 1)));
+ ir::SelectNode::make(x > 2, y, y + 1)));
CHECK(ir::Equal(pz.Eval(), y + 1));
}
// if_then_else
// cast pattern
{
CHECK(!cast(PConst<DataType>(
- DataType::Int(32)), px).Match(ir::Cast::make(DataType::Float(64), x)));
- CHECK(cast(pt, px).Match(ir::Cast::make(DataType::Float(64), x)));
+ DataType::Int(32)), px).Match(ir::CastNode::make(DataType::Float(64), x)));
+ CHECK(cast(pt, px).Match(ir::CastNode::make(DataType::Float(64), x)));
CHECK(pt.Eval() == DataType::Float(64));
auto zz = cast(pt, px).Eval();
CHECK((cast(pt, px) - cast(pt, py)).Match(
- ir::Cast::make(DataType::Float(64), x) - ir::Cast::make(DataType::Int(64), x)));
- auto expr = ir::Cast::make(DataType::Int(32), ir::Cast::make(DataType::Float(64), x));
+ ir::CastNode::make(DataType::Float(64), x) - ir::CastNode::make(DataType::Int(64), x)));
+ auto expr = ir::CastNode::make(DataType::Int(32), ir::CastNode::make(DataType::Float(64), x));
CHECK(!(cast(pt, cast(pt, px))).Match(expr));
}
// ramp pattern
{
CHECK(ramp(px, PConst<Expr>(1), planes).Match(
- ir::Ramp::make(x, 1, 10)));
+ ir::RampNode::make(x, 1, 10)));
CHECK(planes.Eval() == 10);
CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
- ir::Ramp::make(x, 2, 10)));
+ ir::RampNode::make(x, 2, 10)));
}
// broadcast pattern
{
CHECK(broadcast(px, planes).Match(
- ir::Broadcast::make(x, 10)));
+ ir::BroadcastNode::make(x, 10)));
CHECK(planes.Eval() == 10);
CHECK(broadcast(px * py , planes).Match(
- ir::Broadcast::make(x * 10, 10)));
+ ir::BroadcastNode::make(x * 10, 10)));
}
}
int i;
for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
// TODO(@icemelon9): Need to revisit this part
- const Variable* var1 = shape1[s1_size - i].as<Variable>();
- const Variable* var2 = shape2[s2_size - i].as<Variable>();
+ const VarNode* var1 = shape1[s1_size - i].as<VarNode>();
+ const VarNode* var2 = shape2[s2_size - i].as<VarNode>();
bh.all_vars.push_front(tvm::Var());
if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]);
*/
inline bool IsConstInt(Expr expr) {
return
- expr->IsInstance<tvm::ir::IntImm>() ||
- expr->IsInstance<tvm::ir::UIntImm>();
+ expr->IsInstance<tvm::ir::IntImmNode>() ||
+ expr->IsInstance<tvm::ir::UIntImmNode>();
}
/*!
* \return The integer value.
*/
inline int64_t GetConstInt(Expr expr) {
- if (expr->IsInstance<tvm::ir::IntImm>()) {
- return expr.as<tvm::ir::IntImm>()->value;
+ if (expr->IsInstance<tvm::ir::IntImmNode>()) {
+ return expr.as<tvm::ir::IntImmNode>()->value;
}
- if (expr->IsInstance<tvm::ir::UIntImm>()) {
- return expr.as<tvm::ir::UIntImm>()->value;
+ if (expr->IsInstance<tvm::ir::UIntImmNode>()) {
+ return expr.as<tvm::ir::UIntImmNode>()->value;
}
LOG(ERROR) << "expr must be a constant integer";
return -1;
}
auto body = fextern(input_placeholders, output_placeholders);
- auto body_stmt = tvm::ir::Evaluate::make(body);
+ auto body_stmt = tvm::ir::EvaluateNode::make(body);
auto op = ExternOpNode::make(
name, tag, attrs, inputs,
*/
inline Expr pack_buffer(Buffer buf) {
CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
- auto shape = tvm::ir::Call::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
- buf->shape, tvm::ir::Call::CallType::Intrinsic);
+ auto shape = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
+ buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
Expr strides;
if (buf->strides.size() > 0) {
- strides = tvm::ir::Call::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
- buf->shape, tvm::ir::Call::CallType::Intrinsic);
+ strides = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
+ buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
} else {
strides = 0;
}
make_const(buf->dtype, 0),
buf->elem_offset
};
- return tvm::ir::Call::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
- pack_args, tvm::ir::Call::CallType::Intrinsic);
+ return tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
+ pack_args, tvm::ir::CallNode::CallType::Intrinsic);
}
/*!
* \return An expression representing the invocation
*/
inline Expr call_packed(Array<Expr> args) {
- return tvm::ir::Call::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
- args, tvm::ir::Call::CallType::Intrinsic);
+ return tvm::ir::CallNode::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
+ args, tvm::ir::CallNode::CallType::Intrinsic);
}
} // namespace detail
inline bool is_empty_shape(const Array<Expr>& x) {
bool is_empty = false;
for (const auto& dim : x) {
- if (auto int_dim = dim.as<IntImm>()) {
+ if (auto int_dim = dim.as<IntImmNode>()) {
if (int_dim->value == 0) {
is_empty = true;
break;
Expr zero = make_zero(x->dtype);
Expr one = make_const(x->dtype, 1);
Expr minus_one = make_const(x->dtype, -1);
- auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero);
- auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1);
+ auto s1 = tvm::ir::SelectNode::make((x(i) < zero), minus_one, zero);
+ auto s2 = tvm::ir::SelectNode::make((x(i) > zero), one, s1);
return s2;
}, name, tag);
}
if (expr.dtype().lanes() == type.lanes()) {
return expr;
} else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
- return tvm::ir::Broadcast::make(expr, type.lanes());
+ return tvm::ir::BroadcastNode::make(expr, type.lanes());
}
}
std::string tag = kElementWise) {
return compute(x->shape,
[&](const Array<Var>& i) {
- return tvm::ir::Call::make(type, "reinterpret", {x(i)},
- tvm::ir::Call::PureIntrinsic);
+ return tvm::ir::CallNode::make(type, "reinterpret", {x(i)},
+ tvm::ir::CallNode::PureIntrinsic);
},
name, tag);
}
[&](const tvm::Array<tvm::Var>& i) {
auto value = t(i);
auto calpha = tvm::make_const(value.dtype(), alpha);
- return tvm::ir::Select::make(value > 0, value, value * calpha);
+ return tvm::ir::SelectNode::make(value > 0, value, value * calpha);
},
name,
tag);
return tvm::compute(x->shape,
[&](const tvm::Array<tvm::Var> &indices) {
auto xval = x(indices);
- return tvm::ir::Select::make(
+ return tvm::ir::SelectNode::make(
xval > 0,
xval,
xval * slope(indices[axis]));
if (sel.size() != 0) {
if (pad_mode == "constant") {
return tvm::if_then_else(
- detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
+ detail::Map(sel, tvm::ir::AndNode::make), t(indices), pad_value);
} else if (pad_mode == "edge" || pad_mode == "reflect") {
return tvm::if_then_else(
- detail::Map(sel, tvm::ir::And::make), t(indices), t(pad_idx));
+ detail::Map(sel, tvm::ir::AndNode::make), t(indices), t(pad_idx));
}
}
return t(indices);
} else {
Expr h_start = output[height_axis] * stride_height - pad_top;
Expr w_start = output[width_axis] * stride_width - pad_left;
- Expr h_end = ir::Min::make(h_start + kernel_height, height);
- Expr w_end = ir::Min::make(w_start + kernel_width, width);
- h_start = ir::Max::make(h_start, make_const(DataType::DataType::Int(32), 0));
- w_start = ir::Max::make(w_start, make_const(DataType::DataType::Int(32), 0));
- Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
+ Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
+ Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+ h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
+ w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
+ Expr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
make_const(DataType::DataType::Int(32), 1));
return div(pool_sum(indices), divide_factor);
}
out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
- Expr out_idx_lower_h = ir::Select::make(
+ Expr out_idx_lower_h = ir::SelectNode::make(
pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
(pad_inds[height_axis] - kernel_height) / stride_height + 1);
- Expr out_idx_lower_w = ir::Select::make(
+ Expr out_idx_lower_w = ir::SelectNode::make(
pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
(pad_inds[width_axis] - kernel_width) / stride_width + 1);
return tvm::sum(
- tvm::if_then_else(ir::And::make(
- ir::And::make(out_idx[height_axis] >= out_idx_lower_h,
+ tvm::if_then_else(ir::AndNode::make(
+ ir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
out_idx[width_axis] >= out_idx_lower_w),
mp_inds(out_idx) == idx),
out_grad(out_idx), make_const(x->dtype, 0)),
out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
- Expr out_idx_lower_h = ir::Select::make(
+ Expr out_idx_lower_h = ir::SelectNode::make(
pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
(pad_h_idx - kernel_height) / stride_height + 1);
- Expr out_idx_lower_w = ir::Select::make(
+ Expr out_idx_lower_w = ir::SelectNode::make(
pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
(pad_w_idx - kernel_width) / stride_width + 1);
} else {
Expr h_start = out_idx[height_axis] * stride_height - pad_top;
Expr w_start = out_idx[width_axis] * stride_width - pad_left;
- Expr h_end = ir::Min::make(h_start + kernel_height, height);
- Expr w_end = ir::Min::make(w_start + kernel_width, width);
- h_start = ir::Max::make(h_start, make_const(DataType::Int(32), 0));
- w_start = ir::Max::make(w_start, make_const(DataType::Int(32), 0));
+ Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
+ Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+ h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
+ w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
divide_factor =
- ir::Max::make((h_end - h_start) * (w_end - w_start),
+ ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
make_const(DataType::Int(32), 1));
}
return tvm::sum(tvm::if_then_else(
- ir::And::make(
- ir::And::make(out_idx[height_axis] >= out_idx_lower_h,
+ ir::AndNode::make(
+ ir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
out_idx[height_axis] < out_height),
- ir::And::make(out_idx[width_axis] >= out_idx_lower_w,
+ ir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w,
out_idx[width_axis] < out_width)),
out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
{windowh, windoww});
const Expr& odim,
const Expr& idim) {
Expr tmp = indexdiv((out_index + 1) * idim, odim);
- return tvm::ir::Select::make(indexmod((out_index + 1) * idim, odim) == 0,
+ return tvm::ir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
tmp, tmp + 1);
}
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
start[i] = output[ii] * stride[i] - pad_head[i];
- end[i] = ir::Min::make(start[i] + kernel[i], x->shape[ii]);
- start[i] = ir::Max::make(start[i], make_const(DataType::Int(32), 0));
+ end[i] = ir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
+ start[i] = ir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
kernel_size *= (end[i] - start[i]);
}
- Expr divide_factor = ir::Max::make(kernel_size, make_const(DataType::Int(32), 1));
+ Expr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
return div(pool_sum(indices), divide_factor);
}
}, "tensor", kElementWise);
auto combiner = tvm::ir::CommReducerNode::make(lhs, rhs, result, id_elem);
Array<Expr> outputs;
for (size_t i = 0; i < exprs.size(); ++i) {
- outputs.push_back(tvm::ir::Reduce::make(combiner, exprs, axis, cond, static_cast<int>(i)));
+ outputs.push_back(
+ tvm::ir::ReduceNode::make(combiner, exprs, axis, cond, static_cast<int>(i)));
}
return outputs;
};
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
- result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
- result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
+ result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
+ result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
inline FCommReduce MakeArgmaxReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
Array<Expr> result;
- result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
- result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
+ result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
+ result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
Array<Expr> target_shape;
for (const auto &ele : newshape) {
- if (ele.as<IntImm>()) {
+ if (ele.as<IntImmNode>()) {
target_shape.push_back(cast(DataType::Int(32), ele));
} else {
target_shape.push_back(ele);
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape, [&](const Array<Var>& indices) {
- return tvm::ir::Select::make(condition(indices) != 0, x(indices), y(indices));
+ return tvm::ir::SelectNode::make(condition(indices) != 0, x(indices), y(indices));
}, name, tag);
} else {
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
out = compute(
oshape, [&](const Array<Var>& indices) {
Array<Expr> condition_idx{indices[0]};
- return tvm::ir::Select::make(condition(condition_idx) != 0,
+ return tvm::ir::SelectNode::make(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
}
}
auto idx = iter_vars[true_axis];
- return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
+ return ir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
}, name, tag);
}