From: Tianqi Chen Date: Wed, 8 Jan 2020 17:01:00 +0000 (-0800) Subject: [REFACTOR][IR] Add Node suffix to low-level IR nodes (#4649) X-Git-Tag: upstream/0.7.0~1424 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f4c5f93b4a14d9da9166366cc1c4e867d838370d;p=platform%2Fupstream%2Ftvm.git [REFACTOR][IR] Add Node suffix to low-level IR nodes (#4649) * [REFACTOR][IR] Variable -> VarNode * [REFACTOR][IR] Add/Sub/Mul/Div -> AddNode/SubNode etc. * [REFACTOR][IR] Min/Max/FloorDiv/FloorMod -> MinNode/MaxNode etc. * [REFACTOR][IR] EQ/NE/LT/LE/GT/GE/Select -> EQNode/NENode etc. * [REFACTOR][IR] Add Node suffix to Select/Call/Load/Ramp/Shuffle/Let * [REFACTOR][IR] Add node suffix to IntImm/UIntImm/FloatImm/StringImm * [REFACTOR][IR] Add Node suffix to Any, AttrStmt, AssertStmt * [REFACTOR][IR] Add Node suffix to Store/Provide/Allocate/Free * [REFACTOR][IR] Add Node suffix to ProducerConsumer * Fix lint * style updates, test fixes --- diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index e5f7567..d135d30 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -564,7 +564,7 @@ IntSet EvalSet(Expr e, * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Expr e, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over @@ -586,7 +586,7 @@ IntSet EvalSet(Range r, * \return An integer set that can cover all the possible values. */ IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s, * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Range r, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; @@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map; */ ExprIntSetMap EvalSetForEachSubExpr( Expr e, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Create an union set of all sets @@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond, * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map); + const std::unordered_map& hint_map, + const std::unordered_map& relax_map); /*! * \brief Infer a regular domain that covers all the calls or provides within the given statement. diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 0178eab..13c8b30 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -488,9 +488,9 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } else { Expr expr = val; CHECK(expr.defined()); - if (const ir::IntImm* op = expr.as()) { + if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImm* op = expr.as()) { + } else if (const ir::UIntImmNode* op = expr.as()) { *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); @@ -503,7 +503,7 @@ inline void SetValue(std::string* ptr, const TVMArgValue& val) { *ptr = val.operator std::string(); } else { Expr expr = val; - const ir::StringImm* op = expr.as(); + const ir::StringImmNode* op = expr.as(); CHECK(op != nullptr); *ptr = op->value; } @@ -519,11 +519,11 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } else { Expr expr = val; CHECK(expr.defined()); - if (const ir::IntImm* op = expr.as()) { + if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::IntImm* op = expr.as()) { + } else if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImm* op = expr.as()) { + } else if (const ir::UIntImmNode* op = expr.as()) { *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); diff --git a/include/tvm/expr.h b/include/tvm/expr.h index aee565d..64d7547 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -102,7 +102,7 @@ class Var; * - Let * - LetStmt */ -class Variable : public ExprNode { +class VarNode : public ExprNode { public: /*! * \brief The hint to the variable name. @@ -118,7 +118,7 @@ class Variable : public ExprNode { } static constexpr const char* _type_key = "Variable"; - TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); }; /*! \brief a named variable in TVM */ @@ -139,18 +139,18 @@ class Var : public Expr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const Variable* operator->() const { + const VarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const Variable* get() const { - return static_cast(data_.get()); + const VarNode* get() const { + return static_cast(data_.get()); } /*! \brief type indicate the container type */ - using ContainerType = Variable; + using ContainerType = VarNode; }; // Backward compatibility, will be removed later. @@ -161,7 +161,7 @@ using ExprEqual = ObjectEqual; class Integer; /*! \brief ExprNode: constant integer. */ -class IntImm : public ExprNode { +class IntImmNode : public ExprNode { public: /*! \brief the Internal value. */ int64_t value; @@ -174,7 +174,7 @@ class IntImm : public ExprNode { TVM_DLL static Integer make(DataType t, int64_t value); static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode); }; /*! @@ -206,8 +206,8 @@ class Integer : public Expr { * \brief Get pointer to the internal value. * \return the content of the integer. */ - const IntImm* operator->() const { - return static_cast(get()); + const IntImmNode* operator->() const { + return static_cast(get()); } /*! * \brief convert to int64_t @@ -218,7 +218,7 @@ class Integer : public Expr { return (*this)->value; } /*! \brief type indicate the container type */ - using ContainerType = IntImm; + using ContainerType = IntImmNode; }; /*! \brief range over one dimension */ diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index a73edb4..bf8b1a3 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -75,7 +75,7 @@ inline Expr const_false(int lanes = 1) { */ inline const int64_t* as_const_int(const Expr& x) { if (!x.defined()) return nullptr; - if (const ir::IntImm* op = x.as()) { + if (const ir::IntImmNode* op = x.as()) { return &(op->value); } else { return nullptr; @@ -90,7 +90,7 @@ inline const int64_t* as_const_int(const Expr& x) { */ inline const uint64_t* as_const_uint(const Expr& x) { if (!x.defined()) return nullptr; - if (const ir::UIntImm* op = x.as()) { + if (const ir::UIntImmNode* op = x.as()) { return &(op->value); } else { return nullptr; @@ -600,7 +600,7 @@ TVM_DLL Expr trunc(Expr x); // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ - return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \ + return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \ } \ TVM_DECLARE_INTRIN_UNARY(exp); @@ -617,11 +617,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this inline bool is_const(const Expr& x) { - if (x.as() || x.as()) { + if (x.as() || x.as()) { return true; - } else if (const auto* op = x.as()) { + } else if (const auto* op = x.as()) { const Expr& val = op->value; - if (val.as() || val.as()) { + if (val.as() || val.as()) { return true; } } @@ -629,9 +629,9 @@ inline bool is_const(const Expr& x) { } inline bool is_positive_const(const Expr& a) { - if (const ir::IntImm* op = a.as()) { + if (const ir::IntImmNode* op = a.as()) { return op->value > 0; - } else if (const ir::UIntImm* op = a.as()) { + } else if (const ir::UIntImmNode* op = a.as()) { return op->value > 0; } else { return false; @@ -639,7 +639,7 @@ inline bool is_positive_const(const Expr& a) { } inline bool is_negative_const(const Expr& a) { - if (const ir::IntImm* op = a.as()) { + if (const ir::IntImmNode* op = a.as()) { return op->value < 0; } else { return false; @@ -647,15 +647,15 @@ inline bool is_negative_const(const Expr& a) { } inline bool is_const_int(const Expr& x, int64_t value) { - if (const auto* op = x.as()) { + if (const auto* op = x.as()) { return op->value == value; - } else if (const auto* op = x.as()) { + } else if (const auto* op = x.as()) { return op->value == static_cast(value); - } else if (const auto* op = x.as()) { + } else if (const auto* op = x.as()) { const Expr& val = op->value; - if (const auto* opv = val.as()) { + if (const auto* opv = val.as()) { return opv->value == value; - } else if (const auto* opv = val.as()) { + } else if (const auto* opv = val.as()) { return opv->value == static_cast(value); } } @@ -664,7 +664,7 @@ inline bool is_const_int(const Expr& x, int64_t value) { inline bool is_no_op(const Stmt& stmt) { if (!stmt.defined()) return true; - if (const auto* op = stmt.as()) { + if (const auto* op = stmt.as()) { return is_const(op->value); } if (const auto* op = stmt.as()) { @@ -675,15 +675,15 @@ inline bool is_no_op(const Stmt& stmt) { template inline Expr MakeConstScalar(DataType t, ValueType value) { - if (t.is_int()) return ir::IntImm::make(t, static_cast(value)); - if (t.is_uint()) return ir::UIntImm::make(t, static_cast(value)); - if (t.is_float()) return ir::FloatImm::make(t, static_cast(value)); + if (t.is_int()) return ir::IntImmNode::make(t, static_cast(value)); + if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast(value)); + if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? if (static_cast(t.code()) >= static_cast(kCustomBegin)) - return ir::FloatImm::make(t, static_cast(value)); + return ir::FloatImmNode::make(t, static_cast(value)); LOG(FATAL) << "cannot make const for type " << t; return Expr(); } @@ -693,7 +693,7 @@ inline Expr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { - return ir::Broadcast::make( + return ir::BroadcastNode::make( MakeConstScalar(t.element_of(), value), t.lanes()); } } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index b1cefff..11ce09d 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -36,11 +36,11 @@ namespace tvm { namespace ir { -using IntImm = tvm::IntImm; -using Variable = tvm::Variable; +using IntImmNode = tvm::IntImmNode; +using VarNode = tvm::VarNode; /*! \brief constant unsigned integer. */ -class UIntImm : public ExprNode { +class UIntImmNode : public ExprNode { public: /*! \brief The constant value content. */ uint64_t value; @@ -53,11 +53,11 @@ class UIntImm : public ExprNode { TVM_DLL static Expr make(DataType t, uint64_t value); static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, ExprNode); }; /*! \brief Floating point constants. */ -class FloatImm : public ExprNode { +class FloatImmNode : public ExprNode { public: /*! \brief The constant value content. */ double value; @@ -70,11 +70,11 @@ class FloatImm : public ExprNode { TVM_DLL static Expr make(DataType t, double value); static constexpr const char* _type_key = "FloatImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(FloatImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, ExprNode); }; /*! \brief String constants, only used in asserts. */ -class StringImm : public ExprNode { +class StringImmNode : public ExprNode { public: /*! \brief The constant value content. */ std::string value; @@ -87,14 +87,14 @@ class StringImm : public ExprNode { TVM_DLL Expr static make(std::string value); static constexpr const char* _type_key = "StringImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(StringImm, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, ExprNode); }; /*! * \brief Cast value from one data type to another. * \note The lanes of value should keep fixed. */ -class Cast : public ExprNode { +class CastNode : public ExprNode { public: /*! \brief Original data type. */ Expr value; @@ -107,7 +107,7 @@ class Cast : public ExprNode { TVM_DLL static Expr make(DataType t, Expr v); static constexpr const char* _type_key = "Cast"; - TVM_DECLARE_FINAL_OBJECT_INFO(Cast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode); }; /*! @@ -143,19 +143,19 @@ class BinaryOpNode : public ExprNode { }; /*! \brief a + b */ -class Add : public BinaryOpNode { +class AddNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Add"; }; /*! \brief a - b */ -class Sub : public BinaryOpNode { +class SubNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Sub"; }; /*! \brief a * b */ -class Mul : public BinaryOpNode { +class MulNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Mul"; }; @@ -164,7 +164,7 @@ class Mul : public BinaryOpNode { * \brief a / b in the C semnatics. * \note For integer division, C standard uses trunc div. */ -class Div : public BinaryOpNode
{ +class DivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Div"; }; @@ -173,31 +173,31 @@ class Div : public BinaryOpNode
{ * \brief a % b in the C semnatics. * \note For integer division, C standard uses trunc div. */ -class Mod : public BinaryOpNode { +class ModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Mod"; }; /*! \brief Floor division, floor(a/b) */ -class FloorDiv : public BinaryOpNode { +class FloorDivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "FloorDiv"; }; /*! \brief The remainder of the floordiv */ -class FloorMod : public BinaryOpNode { +class FloorModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "FloorMod"; }; /*! \brief min(a, b) */ -class Min : public BinaryOpNode { +class MinNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Min"; }; /*! \brief max(a, b) */ -class Max : public BinaryOpNode { +class MaxNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Max"; }; @@ -235,43 +235,43 @@ class CmpOpNode : public ExprNode { }; /*! \brief a == b */ -class EQ : public CmpOpNode { +class EQNode : public CmpOpNode { public: static constexpr const char* _type_key = "EQ"; }; /*! \brief a != b */ -class NE : public CmpOpNode { +class NENode : public CmpOpNode { public: static constexpr const char* _type_key = "NE"; }; /*! \brief a < b */ -class LT : public CmpOpNode { +class LTNode : public CmpOpNode { public: static constexpr const char* _type_key = "LT"; }; /*! \brief a <= b */ -struct LE : public CmpOpNode { +struct LENode : public CmpOpNode { public: static constexpr const char* _type_key = "LE"; }; /*! \brief a > b */ -class GT : public CmpOpNode { +class GTNode : public CmpOpNode { public: static constexpr const char* _type_key = "GT"; }; /*! \brief a >= b */ -class GE : public CmpOpNode { +class GENode : public CmpOpNode { public: static constexpr const char* _type_key = "GE"; }; /*! \brief a && b */ -class And : public ExprNode { +class AndNode : public ExprNode { public: /*! \brief The left operand. */ Expr a; @@ -287,11 +287,11 @@ class And : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "And"; - TVM_DECLARE_FINAL_OBJECT_INFO(And, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode); }; /*! \brief a || b */ -class Or : public ExprNode { +class OrNode : public ExprNode { public: /*! \brief The left operand. */ Expr a; @@ -307,11 +307,11 @@ class Or : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "Or"; - TVM_DECLARE_FINAL_OBJECT_INFO(Or, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode); }; /*! \brief !a */ -class Not : public ExprNode { +class NotNode : public ExprNode { public: /*! \brief The input operand. */ Expr a; @@ -324,7 +324,7 @@ class Not : public ExprNode { TVM_DLL static Expr make(Expr a); static constexpr const char* _type_key = "Not"; - TVM_DECLARE_FINAL_OBJECT_INFO(Not, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode); }; /*! @@ -334,7 +334,7 @@ class Not : public ExprNode { * Do not use it to guard against out of bound access, * please use if_then_else instead. */ -class Select : public ExprNode { +class SelectNode : public ExprNode { public: /*! \brief The condition */ Expr condition; @@ -353,7 +353,7 @@ class Select : public ExprNode { TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value); static constexpr const char* _type_key = "Select"; - TVM_DECLARE_FINAL_OBJECT_INFO(Select, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode); }; /*! @@ -371,7 +371,7 @@ class Select : public ExprNode { * * \endcode */ -class Load : public ExprNode { +class LoadNode : public ExprNode { public: /*! \brief The buffer variable. */ Var buffer_var; @@ -390,7 +390,7 @@ class Load : public ExprNode { TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate); static constexpr const char* _type_key = "Load"; - TVM_DECLARE_FINAL_OBJECT_INFO(Load, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, ExprNode); }; /*! @@ -402,7 +402,7 @@ class Load : public ExprNode { * - ramp(0, 1, 3) = [0, 1, 2] * - ramp(1, 2, 4) = [1, 3, 5, 7] */ -class Ramp : public ExprNode { +class RampNode : public ExprNode { public: /*! \brief The base value. */ Expr base; @@ -421,11 +421,11 @@ class Ramp : public ExprNode { TVM_DLL static Expr make(Expr base, Expr stride, int lanes); static constexpr const char* _type_key = "Ramp"; - TVM_DECLARE_FINAL_OBJECT_INFO(Ramp, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, ExprNode); }; /*! \brief Create a vector where all the elements are value. */ -class Broadcast : public ExprNode { +class BroadcastNode : public ExprNode { public: /*! \brief The base value. */ Expr value; @@ -441,13 +441,13 @@ class Broadcast : public ExprNode { TVM_DLL static Expr make(Expr value, int lanes); static constexpr const char* _type_key = "Broadcast"; - TVM_DECLARE_FINAL_OBJECT_INFO(Broadcast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, ExprNode); }; /*! * \brief Let binding. Bind var to value then evaluate body. */ -class Let : public ExprNode { +class LetNode : public ExprNode { public: /*! \brief The variable. */ Var var; @@ -466,7 +466,7 @@ class Let : public ExprNode { TVM_DLL static Expr make(Var var, Expr value, Expr body); static constexpr const char* _type_key = "Let"; - TVM_DECLARE_FINAL_OBJECT_INFO(Let, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); }; // Call node, represent a function call or a multi-dimensional array load. @@ -494,7 +494,7 @@ class FunctionRef : public ObjectRef { /*! * \brief Call node. */ -class Call : public ExprNode { +class CallNode : public ExprNode { public: /*! \brief Possible types of calls. */ enum CallType : int { @@ -560,7 +560,7 @@ class Call : public ExprNode { bool is_vectorizable() const; static constexpr const char* _type_key = "Call"; - TVM_DECLARE_FINAL_OBJECT_INFO(Call, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); // Build-in intrinsics static constexpr const char* reinterpret = "reinterpret"; @@ -585,7 +585,7 @@ class Call : public ExprNode { * vec = concat(vectors) * result = (vec[indices[0]], vec[indices[1]] ...) */ -class Shuffle : public ExprNode { +class ShuffleNode : public ExprNode { public: /*! \brief the input vectors. */ Array vectors; @@ -602,7 +602,7 @@ class Shuffle : public ExprNode { TVM_DLL static Expr make_extract_element(Expr vector, int index); static constexpr const char* _type_key = "Shuffle"; - TVM_DECLARE_FINAL_OBJECT_INFO(Shuffle, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, ExprNode); }; // Reduce operator @@ -671,7 +671,7 @@ inline const CommReducerNode* CommReducer::operator->() const { } /*! \brief Reduction operator operator */ -class Reduce : public ExprNode { +class ReduceNode : public ExprNode { public: /*! \brief The commutative combiner */ CommReducer combiner; @@ -704,29 +704,29 @@ class Reduce : public ExprNode { } static constexpr const char* _type_key = "Reduce"; - TVM_DECLARE_FINAL_OBJECT_INFO(Reduce, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, ExprNode); }; /*! \brief Any shape. */ -class Any : public ExprNode { +class AnyNode : public ExprNode { public: void VisitAttrs(AttrVisitor* v) {} /*! \brief Convert to var. */ Var ToVar() const { - return Variable::make(DataType::Int(32), "any_dim"); + return VarNode::make(DataType::Int(32), "any_dim"); } TVM_DLL static Expr make(); static constexpr const char* _type_key = "Any"; - TVM_DECLARE_FINAL_OBJECT_INFO(Any, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, ExprNode); }; // Statements /*! * \brief Let binding, bind var to value, then run body. */ -class LetStmt : public StmtNode { +class LetStmtNode : public StmtNode { public: /*! \brief The variable. */ Var var; @@ -744,7 +744,7 @@ class LetStmt : public StmtNode { TVM_DLL static Stmt make(Var var, Expr value, Stmt body); static constexpr const char* _type_key = "LetStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(LetStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); }; /*! @@ -757,7 +757,7 @@ class LetStmt : public StmtNode { * - Bound of function, variables. * - Hint which block corresponds to a parallel region. */ -class AttrStmt : public StmtNode { +class AttrStmtNode : public StmtNode { public: /*! \brief this is attribute about certain node */ ObjectRef node; @@ -781,13 +781,13 @@ class AttrStmt : public StmtNode { Stmt body); static constexpr const char* _type_key = "AttrStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); }; /*! * \brief Assert condition, if an error occurs, return the error message. */ -class AssertStmt : public StmtNode { +class AssertStmtNode : public StmtNode { public: /*! \brief Condition to be checked. */ Expr condition; @@ -808,12 +808,12 @@ class AssertStmt : public StmtNode { TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body); static constexpr const char* _type_key = "AssertStmt"; - TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmt, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); }; // TODO(tvm-team): consider consolidate with AttrStmt. /*! \brief annotation node of producer/consumer relation. */ -class ProducerConsumer : public StmtNode { +class ProducerConsumerNode : public StmtNode { public: /*! \brief The corresponding tensor. */ FunctionRef func; @@ -831,7 +831,7 @@ class ProducerConsumer : public StmtNode { TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); static constexpr const char* _type_key = "ProducerConsumer"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumer, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode); }; /*! @@ -850,9 +850,9 @@ class ProducerConsumer : public StmtNode { * buffer[index.v2] = value.v2; * * \endcode - * \sa Load + * \sa LoadNode */ -class Store : public StmtNode { +class StoreNode : public StmtNode { public: /*! \brief The buffer variable. */ Var buffer_var; @@ -876,13 +876,13 @@ class Store : public StmtNode { Expr predicate); static constexpr const char* _type_key = "Store"; - TVM_DECLARE_FINAL_OBJECT_INFO(Store, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); }; /*! * \brief Store value into mult-dimensional array defined by func. */ -class Provide : public StmtNode { +class ProvideNode : public StmtNode { public: /*! \brief The function to be updated. */ FunctionRef func; @@ -906,13 +906,13 @@ class Provide : public StmtNode { Array args); static constexpr const char* _type_key = "Provide"; - TVM_DECLARE_FINAL_OBJECT_INFO(Provide, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); }; /*! * \brief Allocate a buffer that can be used in body. */ -class Allocate : public StmtNode { +class AllocateNode : public StmtNode { public: /*! \brief The buffer variable. */ Var buffer_var; @@ -963,11 +963,11 @@ class Allocate : public StmtNode { const Array& extents); static constexpr const char* _type_key = "Allocate"; - TVM_DECLARE_FINAL_OBJECT_INFO(Allocate, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); }; /*! \brief Free the resources in the buffer before the scope ends. */ -class Free : public StmtNode { +class FreeNode : public StmtNode { public: /*! \brief The buffer variable. */ Var buffer_var; @@ -979,14 +979,14 @@ class Free : public StmtNode { TVM_DLL static Stmt make(Var buffer_var); static constexpr const char* _type_key = "Free"; - TVM_DECLARE_FINAL_OBJECT_INFO(Free, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode); }; /*! * \brief Annotate the bounds where func need to be written and read in body. * We will need to allocate space for the corresponding regions. */ -class Realize : public StmtNode { +class RealizeNode : public StmtNode { public: /*! \brief The function to be realized. */ FunctionRef func; @@ -1018,7 +1018,7 @@ class Realize : public StmtNode { Stmt body); static constexpr const char* _type_key = "Realize"; - TVM_DECLARE_FINAL_OBJECT_INFO(Realize, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode); }; /*! @@ -1104,7 +1104,7 @@ class SeqStmt : public Stmt { if (!stmt.defined()) return; if (auto* op = stmt.as()) { operator()(0, op->seq); - } else if (auto* op = stmt.as()) { + } else if (auto* op = stmt.as()) { // NOTE: The consumer block annotation was not as useful and can be safely dropped. if (!op->is_producer) { operator()(0, op->body); @@ -1133,7 +1133,7 @@ class SeqStmt : public Stmt { /*! * \brief IfThenElse statment. */ -class IfThenElse : public StmtNode { +class IfThenElseNode : public StmtNode { public: /*! \brief The condition. */ Expr condition; @@ -1151,7 +1151,7 @@ class IfThenElse : public StmtNode { TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt()); static constexpr const char* _type_key = "IfThenElse"; - TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElse, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); }; /*! @@ -1160,7 +1160,7 @@ class IfThenElse : public StmtNode { * * If value do not have side-effect, this node can be safely removed. */ -class Evaluate : public StmtNode { +class EvaluateNode : public StmtNode { public: /*! \brief The expression to be evaluated. */ Expr value; @@ -1172,7 +1172,7 @@ class Evaluate : public StmtNode { TVM_DLL static Stmt make(Expr v); static constexpr const char* _type_key = "Evaluate"; - TVM_DECLARE_FINAL_OBJECT_INFO(Evaluate, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); }; /*! \brief Additional annotation of for loop. */ @@ -1204,7 +1204,7 @@ enum class DeviceAPI: int { * } * \endcode */ -class For : public StmtNode { +class ForNode : public StmtNode { public: /*! \brief The loop variable. */ Var loop_var; @@ -1239,13 +1239,13 @@ class For : public StmtNode { } static constexpr const char* _type_key = "For"; - TVM_DECLARE_FINAL_OBJECT_INFO(For, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; /*! * \brief A prefetch hint of func. */ -class Prefetch : public StmtNode { +class PrefetchNode : public StmtNode { public: /*! \brief The function to be prefetched. */ FunctionRef func; @@ -1269,7 +1269,7 @@ class Prefetch : public StmtNode { Region bounds); static constexpr const char* _type_key = "Prefetch"; - TVM_DECLARE_FINAL_OBJECT_INFO(Prefetch, StmtNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); }; /*! @@ -1708,9 +1708,9 @@ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; * \return Expr a expression with dtype. */ inline Expr TypeAnnotation(DataType dtype) { - return ir::Call::make(dtype, + return ir::CallNode::make(dtype, "type_annotation", {}, - ir::Call::PureIntrinsic); + ir::CallNode::PureIntrinsic); } // overload printing of for type. diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 6cc6d70..d70c8de 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -132,38 +132,38 @@ class ExprFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -174,38 +174,38 @@ class ExprFunctor { static FType InitVTable() { FType vtable; // Set dispatch - IR_EXPR_FUNCTOR_DISPATCH(Variable); - IR_EXPR_FUNCTOR_DISPATCH(Load); - IR_EXPR_FUNCTOR_DISPATCH(Let); - IR_EXPR_FUNCTOR_DISPATCH(Call); - IR_EXPR_FUNCTOR_DISPATCH(Add); - IR_EXPR_FUNCTOR_DISPATCH(Sub); - IR_EXPR_FUNCTOR_DISPATCH(Mul); - IR_EXPR_FUNCTOR_DISPATCH(Div); - IR_EXPR_FUNCTOR_DISPATCH(Mod); - IR_EXPR_FUNCTOR_DISPATCH(FloorDiv); - IR_EXPR_FUNCTOR_DISPATCH(FloorMod); - IR_EXPR_FUNCTOR_DISPATCH(Min); - IR_EXPR_FUNCTOR_DISPATCH(Max); - IR_EXPR_FUNCTOR_DISPATCH(EQ); - IR_EXPR_FUNCTOR_DISPATCH(NE); - IR_EXPR_FUNCTOR_DISPATCH(LT); - IR_EXPR_FUNCTOR_DISPATCH(LE); - IR_EXPR_FUNCTOR_DISPATCH(GT); - IR_EXPR_FUNCTOR_DISPATCH(GE); - IR_EXPR_FUNCTOR_DISPATCH(And); - IR_EXPR_FUNCTOR_DISPATCH(Or); - IR_EXPR_FUNCTOR_DISPATCH(Reduce); - IR_EXPR_FUNCTOR_DISPATCH(Cast); - IR_EXPR_FUNCTOR_DISPATCH(Not); - IR_EXPR_FUNCTOR_DISPATCH(Select); - IR_EXPR_FUNCTOR_DISPATCH(Ramp); - IR_EXPR_FUNCTOR_DISPATCH(Shuffle); - IR_EXPR_FUNCTOR_DISPATCH(Broadcast); - IR_EXPR_FUNCTOR_DISPATCH(IntImm); - IR_EXPR_FUNCTOR_DISPATCH(UIntImm); - IR_EXPR_FUNCTOR_DISPATCH(FloatImm); - IR_EXPR_FUNCTOR_DISPATCH(StringImm); + IR_EXPR_FUNCTOR_DISPATCH(VarNode); + IR_EXPR_FUNCTOR_DISPATCH(LoadNode); + IR_EXPR_FUNCTOR_DISPATCH(LetNode); + IR_EXPR_FUNCTOR_DISPATCH(CallNode); + IR_EXPR_FUNCTOR_DISPATCH(AddNode); + IR_EXPR_FUNCTOR_DISPATCH(SubNode); + IR_EXPR_FUNCTOR_DISPATCH(MulNode); + IR_EXPR_FUNCTOR_DISPATCH(DivNode); + IR_EXPR_FUNCTOR_DISPATCH(ModNode); + IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode); + IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); + IR_EXPR_FUNCTOR_DISPATCH(MinNode); + IR_EXPR_FUNCTOR_DISPATCH(MaxNode); + IR_EXPR_FUNCTOR_DISPATCH(EQNode); + IR_EXPR_FUNCTOR_DISPATCH(NENode); + IR_EXPR_FUNCTOR_DISPATCH(LTNode); + IR_EXPR_FUNCTOR_DISPATCH(LENode); + IR_EXPR_FUNCTOR_DISPATCH(GTNode); + IR_EXPR_FUNCTOR_DISPATCH(GENode); + IR_EXPR_FUNCTOR_DISPATCH(AndNode); + IR_EXPR_FUNCTOR_DISPATCH(OrNode); + IR_EXPR_FUNCTOR_DISPATCH(ReduceNode); + IR_EXPR_FUNCTOR_DISPATCH(CastNode); + IR_EXPR_FUNCTOR_DISPATCH(NotNode); + IR_EXPR_FUNCTOR_DISPATCH(SelectNode); + IR_EXPR_FUNCTOR_DISPATCH(RampNode); + IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); + IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); + IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); + IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode); + IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); + IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); return vtable; } }; @@ -241,20 +241,20 @@ class StmtFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -264,20 +264,20 @@ class StmtFunctor { // initialize the vtable. static FType InitVTable() { FType vtable; - IR_STMT_FUNCTOR_DISPATCH(LetStmt); - IR_STMT_FUNCTOR_DISPATCH(AttrStmt); - IR_STMT_FUNCTOR_DISPATCH(IfThenElse); - IR_STMT_FUNCTOR_DISPATCH(For); - IR_STMT_FUNCTOR_DISPATCH(Allocate); - IR_STMT_FUNCTOR_DISPATCH(Store); - IR_STMT_FUNCTOR_DISPATCH(Free); - IR_STMT_FUNCTOR_DISPATCH(AssertStmt); - IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer); - IR_STMT_FUNCTOR_DISPATCH(Provide); - IR_STMT_FUNCTOR_DISPATCH(Realize); - IR_STMT_FUNCTOR_DISPATCH(Prefetch); + IR_STMT_FUNCTOR_DISPATCH(LetStmtNode); + IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); + IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); + IR_STMT_FUNCTOR_DISPATCH(ForNode); + IR_STMT_FUNCTOR_DISPATCH(AllocateNode); + IR_STMT_FUNCTOR_DISPATCH(StoreNode); + IR_STMT_FUNCTOR_DISPATCH(FreeNode); + IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); + IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode); + IR_STMT_FUNCTOR_DISPATCH(ProvideNode); + IR_STMT_FUNCTOR_DISPATCH(RealizeNode); + IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); - IR_STMT_FUNCTOR_DISPATCH(Evaluate); + IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); return vtable; } }; @@ -298,38 +298,38 @@ class TVM_DLL ExprVisitor : protected: using ExprFunctor::VisitExpr; // list of functions to override. - void VisitExpr_(const Variable* op) override; - void VisitExpr_(const Load* op) override; - void VisitExpr_(const Let* op) override; - void VisitExpr_(const Call* op) override; - void VisitExpr_(const Add* op) override; - void VisitExpr_(const Sub* op) override; - void VisitExpr_(const Mul* op) override; - void VisitExpr_(const Div* op) override; - void VisitExpr_(const Mod* op) override; - void VisitExpr_(const FloorDiv* op) override; - void VisitExpr_(const FloorMod* op) override; - void VisitExpr_(const Min* op) override; - void VisitExpr_(const Max* op) override; - void VisitExpr_(const EQ* op) override; - void VisitExpr_(const NE* op) override; - void VisitExpr_(const LT* op) override; - void VisitExpr_(const LE* op) override; - void VisitExpr_(const GT* op) override; - void VisitExpr_(const GE* op) override; - void VisitExpr_(const And* op) override; - void VisitExpr_(const Or* op) override; - void VisitExpr_(const Reduce* op) override; - void VisitExpr_(const Cast* op) override; - void VisitExpr_(const Not* op) override; - void VisitExpr_(const Select* op) override; - void VisitExpr_(const Ramp* op) override; - void VisitExpr_(const Broadcast* op) override; - void VisitExpr_(const Shuffle* op) override; - void VisitExpr_(const IntImm* op) override; - void VisitExpr_(const UIntImm* op) override; - void VisitExpr_(const FloatImm* op) override; - void VisitExpr_(const StringImm* op) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const LoadNode* op) override; + void VisitExpr_(const LetNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const AddNode* op) override; + void VisitExpr_(const SubNode* op) override; + void VisitExpr_(const MulNode* op) override; + void VisitExpr_(const DivNode* op) override; + void VisitExpr_(const ModNode* op) override; + void VisitExpr_(const FloorDivNode* op) override; + void VisitExpr_(const FloorModNode* op) override; + void VisitExpr_(const MinNode* op) override; + void VisitExpr_(const MaxNode* op) override; + void VisitExpr_(const EQNode* op) override; + void VisitExpr_(const NENode* op) override; + void VisitExpr_(const LTNode* op) override; + void VisitExpr_(const LENode* op) override; + void VisitExpr_(const GTNode* op) override; + void VisitExpr_(const GENode* op) override; + void VisitExpr_(const AndNode* op) override; + void VisitExpr_(const OrNode* op) override; + void VisitExpr_(const ReduceNode* op) override; + void VisitExpr_(const CastNode* op) override; + void VisitExpr_(const NotNode* op) override; + void VisitExpr_(const SelectNode* op) override; + void VisitExpr_(const RampNode* op) override; + void VisitExpr_(const BroadcastNode* op) override; + void VisitExpr_(const ShuffleNode* op) override; + void VisitExpr_(const IntImmNode* op) override; + void VisitExpr_(const UIntImmNode* op) override; + void VisitExpr_(const FloatImmNode* op) override; + void VisitExpr_(const StringImmNode* op) override; }; /*! @@ -343,38 +343,38 @@ class TVM_DLL ExprMutator : protected: using ExprFunctor::VisitExpr; // list of functions to override. - Expr VisitExpr_(const Variable* op) override; - Expr VisitExpr_(const Load* op) override; - Expr VisitExpr_(const Let* op) override; - Expr VisitExpr_(const Call* op) override; - Expr VisitExpr_(const Add* op) override; - Expr VisitExpr_(const Sub* op) override; - Expr VisitExpr_(const Mul* op) override; - Expr VisitExpr_(const Div* op) override; - Expr VisitExpr_(const Mod* op) override; - Expr VisitExpr_(const FloorDiv* op) override; - Expr VisitExpr_(const FloorMod* op) override; - Expr VisitExpr_(const Min* op) override; - Expr VisitExpr_(const Max* op) override; - Expr VisitExpr_(const EQ* op) override; - Expr VisitExpr_(const NE* op) override; - Expr VisitExpr_(const LT* op) override; - Expr VisitExpr_(const LE* op) override; - Expr VisitExpr_(const GT* op) override; - Expr VisitExpr_(const GE* op) override; - Expr VisitExpr_(const And* op) override; - Expr VisitExpr_(const Or* op) override; - Expr VisitExpr_(const Reduce* op) override; - Expr VisitExpr_(const Cast* op) override; - Expr VisitExpr_(const Not* op) override; - Expr VisitExpr_(const Select* op) override; - Expr VisitExpr_(const Ramp* op) override; - Expr VisitExpr_(const Broadcast* op) override; - Expr VisitExpr_(const Shuffle* op) override; - Expr VisitExpr_(const IntImm* op) override; - Expr VisitExpr_(const UIntImm* op) override; - Expr VisitExpr_(const FloatImm* op) override; - Expr VisitExpr_(const StringImm* op) override; + Expr VisitExpr_(const VarNode* op) override; + Expr VisitExpr_(const LoadNode* op) override; + Expr VisitExpr_(const LetNode* op) override; + Expr VisitExpr_(const CallNode* op) override; + Expr VisitExpr_(const AddNode* op) override; + Expr VisitExpr_(const SubNode* op) override; + Expr VisitExpr_(const MulNode* op) override; + Expr VisitExpr_(const DivNode* op) override; + Expr VisitExpr_(const ModNode* op) override; + Expr VisitExpr_(const FloorDivNode* op) override; + Expr VisitExpr_(const FloorModNode* op) override; + Expr VisitExpr_(const MinNode* op) override; + Expr VisitExpr_(const MaxNode* op) override; + Expr VisitExpr_(const EQNode* op) override; + Expr VisitExpr_(const NENode* op) override; + Expr VisitExpr_(const LTNode* op) override; + Expr VisitExpr_(const LENode* op) override; + Expr VisitExpr_(const GTNode* op) override; + Expr VisitExpr_(const GENode* op) override; + Expr VisitExpr_(const AndNode* op) override; + Expr VisitExpr_(const OrNode* op) override; + Expr VisitExpr_(const ReduceNode* op) override; + Expr VisitExpr_(const CastNode* op) override; + Expr VisitExpr_(const NotNode* op) override; + Expr VisitExpr_(const SelectNode* op) override; + Expr VisitExpr_(const RampNode* op) override; + Expr VisitExpr_(const BroadcastNode* op) override; + Expr VisitExpr_(const ShuffleNode* op) override; + Expr VisitExpr_(const IntImmNode* op) override; + Expr VisitExpr_(const UIntImmNode* op) override; + Expr VisitExpr_(const FloatImmNode* op) override; + Expr VisitExpr_(const StringImmNode* op) override; }; /*! @@ -396,20 +396,20 @@ class TVM_DLL StmtVisitor : */ virtual void VisitExpr(const Expr& e) {} // statement visitor - void VisitStmt_(const AttrStmt* op) override; - void VisitStmt_(const IfThenElse* op) override; - void VisitStmt_(const LetStmt* op) override; - void VisitStmt_(const For* op) override; - void VisitStmt_(const Allocate* op) override; - void VisitStmt_(const Store* op) override; - void VisitStmt_(const Free* op) override; - void VisitStmt_(const AssertStmt* op) override; - void VisitStmt_(const ProducerConsumer* op) override; - void VisitStmt_(const Provide* op) override; - void VisitStmt_(const Realize* op) override; - void VisitStmt_(const Prefetch* op) override; + void VisitStmt_(const AttrStmtNode* op) override; + void VisitStmt_(const IfThenElseNode* op) override; + void VisitStmt_(const LetStmtNode* op) override; + void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const AllocateNode* op) override; + void VisitStmt_(const StoreNode* op) override; + void VisitStmt_(const FreeNode* op) override; + void VisitStmt_(const AssertStmtNode* op) override; + void VisitStmt_(const ProducerConsumerNode* op) override; + void VisitStmt_(const ProvideNode* op) override; + void VisitStmt_(const RealizeNode* op) override; + void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; - void VisitStmt_(const Evaluate* op) override; + void VisitStmt_(const EvaluateNode* op) override; }; /*! @@ -490,20 +490,20 @@ class TVM_DLL StmtMutator : return e; } // statement visitor - Stmt VisitStmt_(const AttrStmt* op) override; - Stmt VisitStmt_(const IfThenElse* op) override; - Stmt VisitStmt_(const LetStmt* op) override; - Stmt VisitStmt_(const For* op) override; - Stmt VisitStmt_(const Allocate* op) override; - Stmt VisitStmt_(const Store* op) override; - Stmt VisitStmt_(const Free* op) override; - Stmt VisitStmt_(const AssertStmt* op) override; - Stmt VisitStmt_(const ProducerConsumer* op) override; - Stmt VisitStmt_(const Provide* op) override; - Stmt VisitStmt_(const Realize* op) override; - Stmt VisitStmt_(const Prefetch* op) override; + Stmt VisitStmt_(const AttrStmtNode* op) override; + Stmt VisitStmt_(const IfThenElseNode* op) override; + Stmt VisitStmt_(const LetStmtNode* op) override; + Stmt VisitStmt_(const ForNode* op) override; + Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const StoreNode* op) override; + Stmt VisitStmt_(const FreeNode* op) override; + Stmt VisitStmt_(const AssertStmtNode* op) override; + Stmt VisitStmt_(const ProducerConsumerNode* op) override; + Stmt VisitStmt_(const ProvideNode* op) override; + Stmt VisitStmt_(const RealizeNode* op) override; + Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; - Stmt VisitStmt_(const Evaluate* op) override; + Stmt VisitStmt_(const EvaluateNode* op) override; /*! * \brief Alternative advance method for SeqStmtNode. * diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5a81d59..aa1415e 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v); * \param vset The variable set. * \return Whether e uses vset. */ -bool ExprUseVar(const Expr& e, const std::unordered_set& vset); +bool ExprUseVar(const Expr& e, const std::unordered_set& vset); /*! * \brief Convert a IR node to be SSA form. @@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt); * \return The converted form. */ Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); + const std::unordered_map& value_map); /*! * \brief Substitute the var specified in key->var to be value. @@ -157,7 +157,7 @@ Stmt Substitute(Stmt stmt, * \return The converted expression. */ Expr Substitute(Expr expr, - const std::unordered_map& value_map); + const std::unordered_map& value_map); /*! * \brief Substitute the var specified in key->var to be value. diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 681d068..ad8f825 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -109,7 +109,7 @@ class OperationNode : public ir::FunctionBaseNode { virtual void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. @@ -173,7 +173,7 @@ class PlaceholderOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -251,7 +251,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -304,7 +304,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -379,7 +379,7 @@ class ScanOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -446,7 +446,7 @@ class ExternOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -514,7 +514,7 @@ class HybridOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index c8a02a8..31e85f9 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -39,7 +39,7 @@ namespace tvm { namespace relay { -using Any = tvm::ir::Any; +using Any = tvm::ir::AnyNode; using Kind = TypeKind; using Type = tvm::Type; using TypeNode = tvm::TypeNode; diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 034405f..ba04239 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -33,7 +33,7 @@ namespace ir { TVM_REGISTER_GLOBAL("_Var") .set_body_typed([](std::string s, DataType t) { - return Variable::make(t, s); + return VarNode::make(t, s); }); TVM_REGISTER_GLOBAL("make.abs") @@ -73,7 +73,7 @@ TVM_REGISTER_GLOBAL("make.For") .set_body_typed([]( VarExpr loop_var, Expr min, Expr extent, int for_type, int device_api, Stmt body) { - return For::make(loop_var, + return ForNode::make(loop_var, min, extent, static_cast(for_type), @@ -85,9 +85,9 @@ TVM_REGISTER_GLOBAL("make.Load") .set_body([](TVMArgs args, TVMRetValue *ret) { DataType t = args[0]; if (args.size() == 3) { - *ret = Load::make(t, args[1], args[2], const_true(t.lanes())); + *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); } else { - *ret = Load::make(t, args[1], args[2], args[3]); + *ret = LoadNode::make(t, args[1], args[2], args[3]); } }); @@ -95,14 +95,14 @@ TVM_REGISTER_GLOBAL("make.Store") .set_body([](TVMArgs args, TVMRetValue *ret) { Expr value = args[1]; if (args.size() == 3) { - *ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes())); + *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); } else { - *ret = Store::make(args[0], value, args[2], args[3]); + *ret = StoreNode::make(args[0], value, args[2], args[3]); } }); TVM_REGISTER_GLOBAL("make.Realize") -.set_body_typed(Realize::make); +.set_body_typed(RealizeNode::make); TVM_REGISTER_GLOBAL("make.Call") .set_body_typed([]( @@ -110,10 +110,10 @@ TVM_REGISTER_GLOBAL("make.Call") Array args, int call_type, FunctionRef func, int value_index ) { - return Call::make(type, + return CallNode::make(type, name, args, - static_cast(call_type), + static_cast(call_type), func, value_index); }); @@ -122,9 +122,10 @@ TVM_REGISTER_GLOBAL("make.CommReducer") .set_body_typed(CommReducerNode::make); // make from two arguments -#define REGISTER_MAKE(Node) \ - TVM_REGISTER_GLOBAL("make."#Node) \ - .set_body_typed(Node::make); \ +#define REGISTER_MAKE(NodeName) \ + TVM_REGISTER_GLOBAL("make."#NodeName) \ + .set_body_typed(NodeName ## Node::make); \ + REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); @@ -174,7 +175,7 @@ TVM_REGISTER_GLOBAL("make.Allocate") .set_body_typed([]( VarExpr buffer_var, DataType type, Array extents, Expr condition, Stmt body ){ - return Allocate::make(buffer_var, type, extents, condition, body); + return AllocateNode::make(buffer_var, type, extents, condition, body); }); // operator overloading, smarter than make diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 804d8f1..4e635ad 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("_const") }); TVM_REGISTER_GLOBAL("_str") -.set_body_typed(ir::StringImm::make); +.set_body_typed(ir::StringImmNode::make); TVM_REGISTER_GLOBAL("_Array") @@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("_MapItems") auto* n = static_cast(ptr); auto rkvs = make_object(); for (const auto& kv : n->data) { - rkvs->data.push_back(ir::StringImm::make(kv.first)); + rkvs->data.push_back(ir::StringImmNode::make(kv.first)); rkvs->data.push_back(kv.second); } *ret = Array(rkvs); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 404f88d..68e0b05 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() { } bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value >= lower_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { } bool Analyzer::CanProve(const Expr& expr) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value != 0; } auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } return false; diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index bb2e340..40f86de 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -78,8 +78,8 @@ class BoundDeducer: public ExprVisitor { friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(Expr target, Expr expr, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} void Deduce(); @@ -94,29 +94,29 @@ class BoundDeducer: public ExprVisitor { } } - void VisitExpr_(const LT* op) final { + void VisitExpr_(const LTNode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const LE* op) final { + void VisitExpr_(const LENode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const GT* op) final { + void VisitExpr_(const GTNode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const GE* op) final { + void VisitExpr_(const GENode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const Add* op) final { + void VisitExpr_(const AddNode* op) final { bool left = op->a.get() == path_[iter_]; result_ -= left ? op->b : op->a; this->VisitExpr(left ? op->a : op->b); } - void VisitExpr_(const Sub* op) final { + void VisitExpr_(const SubNode* op) final { bool left = op->a.get() == path_[iter_]; if (left) { result_ += op->b; @@ -128,7 +128,7 @@ class BoundDeducer: public ExprVisitor { this->VisitExpr(left ? op->a : op->b); } - void VisitExpr_(const Mul* op) final { + void VisitExpr_(const MulNode* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; Expr target_var = left ? op->a : op->b; @@ -187,8 +187,8 @@ class BoundDeducer: public ExprVisitor { CompareOp ReverseOp(CompareOp comp_op); Expr target_; Expr expr_; - const std::unordered_map& hint_map_; - const std::unordered_map& relax_map_; + const std::unordered_map& hint_map_; + const std::unordered_map& relax_map_; ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; @@ -233,7 +233,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { void BoundDeducer::Transform() { // We will ensure to set expr_ such that it contains target_ - if (const LT* op = expr_.as()) { + if (const LTNode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a < b -> b >= a + 1 comp_op = kGreater; @@ -245,7 +245,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b - 1; } - } else if (const LE* op = expr_.as()) { + } else if (const LENode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a <= b -> b >= a comp_op = kGreater; @@ -256,7 +256,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b; } - } else if (const GT* op = expr_.as()) { + } else if (const GTNode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a > b -> b <= a - 1 comp_op = kLess; @@ -268,7 +268,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b + 1; } - } else if (const GE* op = expr_.as()) { + } else if (const GENode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a >= b -> b <= a comp_op = kLess; @@ -279,7 +279,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b; } - } else if (const EQ* op = expr_.as()) { + } else if (const EQNode* op = expr_.as()) { comp_op = kEqual; if (GetPath(target_, op->a).empty()) { // if the b == a -> a == b @@ -330,8 +330,8 @@ void BoundDeducer::Relax() { } IntSet DeduceBound(Expr v, Expr e, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) { + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success_) return IntSet::nothing(); @@ -352,11 +352,11 @@ IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e, const Map& hint_map, const Map& relax_map) { - std::unordered_map hmap; + std::unordered_map hmap; for (auto kv : hint_map) { hmap[kv.first.get()] = kv.second; } - std::unordered_map rmap; + std::unordered_map rmap; for (auto kv : relax_map) { rmap[kv.first.get()] = kv.second; } diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index d05ee2d..e33b0c5 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -450,14 +450,14 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } using Rewriter::VisitExpr_; - Expr VisitExpr_(const Add* op) final; - Expr VisitExpr_(const Sub* op) final; - Expr VisitExpr_(const Mul* op) final; - Expr VisitExpr_(const Div* op) final; - Expr VisitExpr_(const Mod* op) final; - Expr VisitExpr_(const FloorDiv* op) final; - Expr VisitExpr_(const FloorMod* op) final; - Expr VisitExpr_(const Reduce* op) final; + Expr VisitExpr_(const AddNode* op) final; + Expr VisitExpr_(const SubNode* op) final; + Expr VisitExpr_(const MulNode* op) final; + Expr VisitExpr_(const DivNode* op) final; + Expr VisitExpr_(const ModNode* op) final; + Expr VisitExpr_(const FloorDivNode* op) final; + Expr VisitExpr_(const FloorModNode* op) final; + Expr VisitExpr_(const ReduceNode* op) final; private: /*! @@ -553,7 +553,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } ObjectPtr n = make_object(); n->dtype = expr.dtype(); - if (const auto* op = expr.as()) { + if (const auto* op = expr.as()) { n->base = op->value; return SumExpr(n); } else { @@ -562,11 +562,11 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } } // Simplify the combiner used in reduce. - Expr SimplifyReduceCombiner(const Reduce* op); + Expr SimplifyReduceCombiner(const ReduceNode* op); }; Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Add* op) { +VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -575,13 +575,13 @@ VisitExpr_(const Add* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. SumExpr ret = ToSumExpr(std::move(a)); - if (const auto* op = b.as()) { + if (const auto* op = b.as()) { ret.CopyOnWrite()->AddToSelf(op->value); } else if (const auto* op = b.as()) { ret.CopyOnWrite()->AddToSelf(GetRef(op), 1); @@ -592,7 +592,7 @@ VisitExpr_(const Add* op) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Sub* op) { +VisitExpr_(const SubNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -601,13 +601,13 @@ VisitExpr_(const Sub* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. SumExpr ret = ToSumExpr(std::move(a)); - if (const auto* op = b.as()) { + if (const auto* op = b.as()) { ret.CopyOnWrite()->AddToSelf(-op->value); } else if (const auto* op = b.as()) { ret.CopyOnWrite()->AddToSelf(GetRef(op), -1); @@ -619,7 +619,7 @@ VisitExpr_(const Sub* op) { Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Mul* op) { +VisitExpr_(const MulNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -628,14 +628,14 @@ VisitExpr_(const Mul* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // x * c - if (a.as()) { + if (a.as()) { std::swap(a, b); } - if (const auto* bconst = b.as()) { + if (const auto* bconst = b.as()) { if (a.as()) { SumExpr ret = Downcast(std::move(a)); ret.CopyOnWrite()->MulToSelf(bconst->value); @@ -653,7 +653,7 @@ VisitExpr_(const Mul* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return Mul::make(a, b); + return MulNode::make(a, b); } } @@ -726,7 +726,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Div* op) { +VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -735,7 +735,7 @@ VisitExpr_(const Div* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold
(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -756,7 +756,7 @@ VisitExpr_(const Div* op) { analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { lhs.CopyOnWrite()->DivideBy(cval); Expr temp = Normalize(extra); - if (const auto* pconst = temp.as()) { + if (const auto* pconst = temp.as()) { lhs.CopyOnWrite()->AddToSelf(pconst->value / cval); } else { // if 0 <= extra < cval, it means the extra can be eliminated. @@ -782,12 +782,12 @@ VisitExpr_(const Div* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return Div::make(a, b); + return DivNode::make(a, b); } } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorDiv* op) { +VisitExpr_(const FloorDivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -795,7 +795,7 @@ VisitExpr_(const FloorDiv* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -813,7 +813,7 @@ VisitExpr_(const FloorDiv* op) { // continue simplification. lhs.CopyOnWrite()->DivideBy(cval); Expr temp = Normalize(extra); - if (const auto* pconst = temp.as()) { + if (const auto* pconst = temp.as()) { lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval)); } else { // if 0 <= extra < cval, it means the extra can be eliminated. @@ -838,7 +838,7 @@ VisitExpr_(const FloorDiv* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorDiv::make(a, b); + return FloorDivNode::make(a, b); } } @@ -893,7 +893,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Mod* op) { +VisitExpr_(const ModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -902,7 +902,7 @@ VisitExpr_(const Mod* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -919,7 +919,7 @@ VisitExpr_(const Mod* op) { if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { Expr temp = Normalize(extra); - if (temp.as()) { + if (temp.as()) { return truncmod(temp, c1.Eval()); } else { // If temp < cval && temp >=0 then can remove the mod. @@ -958,12 +958,12 @@ VisitExpr_(const Mod* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return Mod::make(a, b); + return ModNode::make(a, b); } } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorMod* op) { +VisitExpr_(const FloorModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -972,7 +972,7 @@ VisitExpr_(const FloorMod* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -983,7 +983,7 @@ VisitExpr_(const FloorMod* op) { SumExpr lhs, extra; SeparateDivisibleParts(psum, cval, &lhs, &extra); Expr temp = Normalize(extra); - if (temp.as()) { + if (temp.as()) { return floormod(temp, c1.Eval()); } else { // If temp < cval && temp >=0 then can remove the mod. @@ -1018,13 +1018,13 @@ VisitExpr_(const FloorMod* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorMod::make(a, b); + return FloorModNode::make(a, b); } } // Simplify reduce expression. Expr CanonicalSimplifier::Impl:: -SimplifyReduceCombiner(const Reduce* op) { +SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results Array simplified_result; for (const auto& res : op->combiner->result) { @@ -1089,15 +1089,15 @@ SimplifyReduceCombiner(const Reduce* op) { CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return Reduce::make( + return ReduceNode::make( new_combiner, new_source, op->axis, op->condition, new_value_index); } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Reduce* op) { +VisitExpr_(const ReduceNode* op) { // Recursively call simplification when necessary. Expr ret = RewriteSimplifier::Impl::VisitExpr_(op); - op = ret.as(); + op = ret.as(); // already been simplified by const reduction axis removal if (op == nullptr) return ret; if (op->axis.empty()) { @@ -1106,7 +1106,7 @@ VisitExpr_(const Reduce* op) { // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. return this->VisitExpr( - Select::make(op->condition, + SelectNode::make(op->condition, op->source[op->value_index], op->combiner->identity_element[op->value_index])); } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index 806587a..aca26e8 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -77,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) { } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a + b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a - b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a * b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return truncdiv(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return truncmod(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return max(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return min(a, b); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 8b4ea2f..db98a7e 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -76,21 +76,21 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImm; \ - using ir::UIntImm; \ - using ir::FloatImm; \ - const IntImm* pa = a.as(); \ - const IntImm* pb = b.as(); \ - const FloatImm* fa = a.as(); \ - const FloatImm* fb = b.as(); \ + using ir::IntImmNode; \ + using ir::UIntImmNode; \ + using ir::FloatImmNode; \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + const FloatImmNode* fa = a.as(); \ + const FloatImmNode* fb = b.as(); \ BODY; #define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImm; \ - using ir::UIntImm; \ - const IntImm* pa = a.as(); \ - const IntImm* pb = b.as(); \ + using ir::IntImmNode; \ + using ir::UIntImmNode; \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ const DataType& ta = a.dtype(); \ const DataType& tb = b.dtype(); \ if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ @@ -100,13 +100,13 @@ inline bool IsIndexType(const DataType& type) { // specialization of constant folders. template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); + if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value); + if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); @@ -114,22 +114,22 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); + if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); + if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); + if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; @@ -138,7 +138,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pb->value == 1) return a; if (pb->value == 0) return b; } - if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value); + if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value); if (fa) { if (fa->value == 1) return b; if (fa->value == 0) return a; @@ -152,14 +152,14 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { // due to division and mod can have different modes // NOTE: this will assumes truc div. CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm::make(rtype, pa->value / pb->value); + return IntImmNode::make(rtype, pa->value / pb->value); } if (pa) { if (pa->value == 0) return a; @@ -169,7 +169,7 @@ inline Expr TryConstFold(Expr a, Expr b) { CHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - return FloatImm::make(rtype, fa->value / fb->value); + return FloatImmNode::make(rtype, fa->value / fb->value); } if (fa && fa->value == 0) return a; if (fb) { @@ -181,11 +181,11 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImm::make(rtype, pa->value % pb->value); + return IntImmNode::make(rtype, pa->value % pb->value); } if (pa) { if (pa->value == 0) return a; @@ -199,12 +199,12 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm::make(rtype, arith::floordiv(pa->value, pb->value)); + return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -214,7 +214,7 @@ inline Expr TryConstFold(Expr a, Expr b) { CHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - return FloatImm::make(rtype, std::floor(fa->value / fb->value)); + return FloatImmNode::make(rtype, std::floor(fa->value / fb->value)); } if (fa && fa->value == 0) return a; if (fb) { @@ -226,11 +226,11 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImm::make(rtype, arith::floormod(pa->value, pb->value)); + return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -244,86 +244,86 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); + if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); + if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); }); return Expr(); } template<> -inline Expr TryConstFold(Expr a, Expr b) { - using ir::UIntImm; - const UIntImm* pa = a.as(); - const UIntImm* pb = b.as(); +inline Expr TryConstFold(Expr a, Expr b) { + using ir::UIntImmNode; + const UIntImmNode* pa = a.as(); + const UIntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; @@ -332,10 +332,10 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { - using ir::UIntImm; - const UIntImm* pa = a.as(); - const UIntImm* pb = b.as(); +inline Expr TryConstFold(Expr a, Expr b) { + using ir::UIntImmNode; + const UIntImmNode* pa = a.as(); + const UIntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; @@ -344,11 +344,11 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a) { - using ir::UIntImm; - const UIntImm* pa = a.as(); +inline Expr TryConstFold(Expr a) { + using ir::UIntImmNode; + const UIntImmNode* pa = a.as(); if (pa) { - return UIntImm::make(DataType::UInt(1), !(pa->value)); + return UIntImmNode::make(DataType::UInt(1), !(pa->value)); } return Expr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index ef405d8..d3f885a 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -140,17 +140,17 @@ class ConstIntBoundAnalyzer::Impl : return res; } - Entry VisitExpr_(const Cast* op) final { + Entry VisitExpr_(const CastNode* op) final { Entry a = VisitExpr(op->value); Entry b = Everything(op->dtype); return Intersect(a, b); } - Entry VisitExpr_(const IntImm* op) final { + Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } - Entry VisitExpr_(const UIntImm* op) final { + Entry VisitExpr_(const UIntImmNode* op) final { if (op->value <= static_cast(kPosInf)) { return MakeBound(op->value, op->value); } else { @@ -158,7 +158,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const Add* op) final { + Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Sub* op) final { + Entry VisitExpr_(const SubNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -176,13 +176,13 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Mul* op) final { + Entry VisitExpr_(const MulNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); return BinaryOpBoundry(a, b, InfAwareMul); } - Entry VisitExpr_(const Div* op) final { + Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); CHECK(!b.is_const(0)) << "divide by zero"; @@ -192,7 +192,7 @@ class ConstIntBoundAnalyzer::Impl : return BinaryOpBoundry(a, b, InfAwareDiv); } - Entry VisitExpr_(const Mod* op) final { + Entry VisitExpr_(const ModNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); if (b.min_value > 0) { @@ -215,7 +215,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const FloorDiv* op) final { + Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); CHECK(!b.is_const(0)) << "floordiv by zero"; @@ -225,7 +225,7 @@ class ConstIntBoundAnalyzer::Impl : return BinaryOpBoundry(a, b, InfAwareFloorDiv); } - Entry VisitExpr_(const FloorMod* op) final { + Entry VisitExpr_(const FloorModNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); if (b.min_value > 0) { @@ -246,7 +246,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const Min* op) final { + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -255,7 +255,7 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Max* op) final { + Entry VisitExpr_(const MaxNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -264,25 +264,25 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Select* op) final { + Entry VisitExpr_(const SelectNode* op) final { Entry a = VisitExpr(op->true_value); Entry b = VisitExpr(op->false_value); return Union(a, b); } - Entry VisitExpr_(const Call* op) final { + Entry VisitExpr_(const CallNode* op) final { // only special handle >> and & which can be // used for index calculation. - if (op->is_intrinsic(Call::shift_right)) { + if (op->is_intrinsic(CallNode::shift_right)) { return VisitRightShift(op); - } else if (op->is_intrinsic(Call::bitwise_and)) { + } else if (op->is_intrinsic(CallNode::bitwise_and)) { return VisitBitwiseAnd(op); } else { return Everything(op->dtype); } } - Entry VisitExpr_(const Variable* op) final { + Entry VisitExpr_(const VarNode* op) final { Var v = GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { @@ -292,13 +292,13 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitRightShift(const Call* op) { + Entry VisitRightShift(const CallNode* op) { Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); return BinaryOpBoundry(a, b, InfAwareRightShift); } - Entry VisitBitwiseAnd(const Call* op) { + Entry VisitBitwiseAnd(const CallNode* op) { Entry a = VisitExpr(op->args[0]); Entry b = VisitExpr(op->args[1]); // handle positive index case. @@ -375,7 +375,7 @@ class ConstIntBoundAnalyzer::Impl : return kNegInf; } if (y == kPosInf || y == kNegInf) return y; - if (WillOverflow(x, y, kNegInf, kPosInf)) { + if (WillOverflow(x, y, kNegInf, kPosInf)) { if (x > 0) return kPosInf; return kNegInf; } @@ -388,7 +388,7 @@ class ConstIntBoundAnalyzer::Impl : * \return the result. */ static int64_t InfAwareMul(int64_t x, int64_t y) { - if (!WillOverflow(x, y, kNegInf, kPosInf)) return x * y; + if (!WillOverflow(x, y, kNegInf, kPosInf)) return x * y; if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf; return kNegInf; } diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index b8ec974..7785801 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -60,7 +60,7 @@ class LinearEqDetector return true; } - LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -70,7 +70,7 @@ class LinearEqDetector return ret; } - LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -80,7 +80,7 @@ class LinearEqDetector return ret; } - LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -96,7 +96,7 @@ class LinearEqDetector ret.coeff = MulCombine(a.base, b.coeff); return ret; } - LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final { LinearEqEntry ret; if (op == var_.get()) { ret.coeff = make_const(op->dtype, 1); @@ -152,7 +152,7 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { base = std::move(ret.base); } - std::unordered_set vset; + std::unordered_set vset; for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable @@ -167,11 +167,11 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { // Detect clip condition as min max value bool DetectClipBound( const Expr& cond, - std::unordered_map* bmap) { + std::unordered_map* bmap) { int flag = 0; Var var; auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { - if (const Variable* v = n.as()) { + if (const VarNode* v = n.as()) { if (bmap->count(v)) { if (flag == 0) { var = Downcast(n); @@ -188,16 +188,16 @@ bool DetectClipBound( if (flag != 1) return false; // canonical form: exp >= 0 Expr canonical; - if (const LT* op = cond.as()) { + if (const LTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a - make_const(op->a.dtype(), 1); - } else if (const LE* op = cond.as()) { + } else if (const LENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a; - } else if (const GT* op = cond.as()) { + } else if (const GTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b - make_const(op->a.dtype(), 1); - } else if (const GE* op = cond.as()) { + } else if (const GENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; } else { @@ -210,7 +210,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift if (p.min_value.defined()) { - p.min_value = ir::Max::make(p.min_value, -ret.base); + p.min_value = ir::MaxNode::make(p.min_value, -ret.base); } else { p.min_value = -ret.base; } @@ -219,7 +219,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, -1)) { // -var + shift >=0 -> var <= shift if (p.max_value.defined()) { - p.max_value = ir::Min::make(p.max_value, ret.base); + p.max_value = ir::MinNode::make(p.max_value, ret.base); } else { p.max_value = ret.base; } @@ -243,8 +243,8 @@ void SplitCommExpr(const Expr& e, std::vector* ret) { // e must be connected by and. Array DetectClipBound(const Expr& e, const Array& vars) { std::vector splits; - SplitCommExpr(e, &splits); - std::unordered_map rmap; + SplitCommExpr(e, &splits); + std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); } diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index 02f3578..1821c16 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -53,15 +53,15 @@ class FuncTouchedDomain final : public StmtExprVisitor { return ret; } - void VisitStmt_(const For *op) final { - const Variable* var = op->loop_var.get(); + void VisitStmt_(const ForNode *op) final { + const VarNode* var = op->loop_var.get(); dom_map_[var] = IntSet::range( Range::make_by_min_extent(op->min, op->extent)); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } - void VisitStmt_(const LetStmt* op) final { + void VisitStmt_(const LetStmtNode* op) final { dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); StmtExprVisitor::VisitStmt_(op); @@ -69,11 +69,11 @@ class FuncTouchedDomain final : public StmtExprVisitor { } /* TODO: Thread extent unitest not generated.*/ - void VisitStmt_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); - const Variable* var = thread_axis->var.get(); + const VarNode* var = thread_axis->var.get(); dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); @@ -82,7 +82,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - void VisitExpr_(const Call* op) final { + void VisitExpr_(const CallNode* op) final { if (consider_calls_ && tensor_->op.same_as(op->func) && tensor_->value_index == op->value_index) { Touch(op->args); @@ -90,7 +90,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const Provide* op) final { + void VisitStmt_(const ProvideNode* op) final { if (consider_provides_ && tensor_->op.same_as(op->func) && tensor_->value_index == op->value_index) { Touch(op->args); @@ -111,7 +111,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { const Tensor &tensor_; bool consider_calls_, consider_provides_; std::vector > bounds_; - std::unordered_map dom_map_; + std::unordered_map dom_map_; }; Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) { diff --git a/src/arithmetic/int_operator.h b/src/arithmetic/int_operator.h index e3adf1f..fd51091 100644 --- a/src/arithmetic/int_operator.h +++ b/src/arithmetic/int_operator.h @@ -47,30 +47,30 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { if ((y > 0) && (x > max_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true; return false; } template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { if ((y > 0) && (x < min_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true; return false; } template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { if (y == 0) return false; if (y > 0) { if (x < min_value / y) return true; @@ -84,10 +84,10 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +inline bool WillOverflow(int64_t x, + int64_t y, + int64_t min_value, + int64_t max_value) { return y == 0; } diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index bf1cdf0..c60c825 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -83,15 +83,15 @@ struct is_logical_op { static const bool value = true; \ }; -TVM_DECLARE_LOGICAL_OP(And); -TVM_DECLARE_LOGICAL_OP(Or); -TVM_DECLARE_LOGICAL_OP(EQ); -TVM_DECLARE_LOGICAL_OP(NE); -TVM_DECLARE_LOGICAL_OP(GE); -TVM_DECLARE_LOGICAL_OP(GT); -TVM_DECLARE_LOGICAL_OP(LE); -TVM_DECLARE_LOGICAL_OP(LT); -TVM_DECLARE_LOGICAL_OP(Not); +TVM_DECLARE_LOGICAL_OP(AndNode); +TVM_DECLARE_LOGICAL_OP(OrNode); +TVM_DECLARE_LOGICAL_OP(EQNode); +TVM_DECLARE_LOGICAL_OP(NENode); +TVM_DECLARE_LOGICAL_OP(GENode); +TVM_DECLARE_LOGICAL_OP(GTNode); +TVM_DECLARE_LOGICAL_OP(LENode); +TVM_DECLARE_LOGICAL_OP(LTNode); +TVM_DECLARE_LOGICAL_OP(NotNode); /*! * \brief Combine two interval set under arithmetic operations. @@ -118,9 +118,9 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -136,9 +136,9 @@ inline IntervalSet Combine(Analyzer* analyer, } template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -155,9 +155,9 @@ inline IntervalSet Combine(Analyzer* analyer, template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -178,11 +178,11 @@ inline IntervalSet Combine(Analyzer* analyzer, Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::Select; + using ir::SelectNode; Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = a->min_value * b->min_value; Expr e2 = a->max_value * b->min_value; - return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); + return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Mul"; @@ -190,9 +190,9 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -213,11 +213,11 @@ inline IntervalSet Combine(Analyzer* analyzer, Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::Select; + using ir::SelectNode; Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = a->min_value / b->min_value; Expr e2 = a->max_value / b->min_value; - return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); + return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; @@ -225,9 +225,9 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -256,9 +256,9 @@ inline IntervalSet Combine(Analyzer* analyzer, template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -279,11 +279,11 @@ inline IntervalSet Combine(Analyzer* analyzer, Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::Select; + using ir::SelectNode; Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = floordiv(a->min_value, b->min_value); Expr e2 = floordiv(a->max_value, b->min_value); - return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); + return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; @@ -291,9 +291,9 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -317,9 +317,9 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -330,9 +330,9 @@ inline IntervalSet Combine(Analyzer* analzyer, } template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -380,15 +380,15 @@ class IntervalSetEvaluator : return IntervalSet(min_set->min_value, max_set->max_value); } - IntervalSet VisitExpr_(const IntImm* op) final { + IntervalSet VisitExpr_(const IntImmNode* op) final { return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const UIntImm* op) final { + IntervalSet VisitExpr_(const UIntImmNode* op) final { return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const Variable* op) final { + IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); if (it != dom_map_.end()) { @@ -405,75 +405,75 @@ class IntervalSetEvaluator : } } - IntervalSet VisitExpr_(const Add* op) final { + IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Sub* op) final { + IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Mul* op) final { + IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Div* op) final { + IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Mod* op) final { + IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorDiv* op) final { + IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorMod* op) final { + IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Min* op) final { + IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Max* op) final { + IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const EQ* op) final { + IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const NE* op) final { + IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LT* op) final { + IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LE* op) final { + IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GT* op) final { + IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GE* op) final { + IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const And* op) final { + IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Or* op) final { + IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Ramp* op) final { + IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); IntervalSet base = Eval(op->base); PVar stride; @@ -481,12 +481,12 @@ class IntervalSetEvaluator : DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; if (vstride> 0) { - return Combine( + return Combine( analyzer_, base, IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { - return Combine( + return Combine( analyzer_, base, IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); @@ -496,12 +496,12 @@ class IntervalSetEvaluator : return IntervalSet::Everything(); } - IntervalSet VisitExpr_(const Broadcast* op) final { + IntervalSet VisitExpr_(const BroadcastNode* op) final { CHECK(eval_vec_); return VisitExpr(op->value); } - IntervalSet VisitExpr_(const Select* op) final { + IntervalSet VisitExpr_(const SelectNode* op) final { IntervalSet true_set = this->Eval(op->true_value); IntervalSet false_set = this->Eval(op->false_value); return Union(analyzer_, false_set, true_set); @@ -720,7 +720,7 @@ Map ConvertDomMap(const Map& dom_map) { } Map ConvertDomMap( - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { Map dmap; for (auto kv : dom_map) { dmap.Set(GetRef(kv.first), kv.second); @@ -746,7 +746,7 @@ IntSet EvalSet(Expr e, } IntSet EvalSet(Expr e, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } @@ -761,12 +761,12 @@ IntSet EvalSet(Range r, } IntSet EvalSet(Range r, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); IntervalSetEvaluator m(&ana, dmap); @@ -796,7 +796,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { ExprIntSetMap EvalSetForEachSubExpr( Expr e, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); SubExprIntervalSetEvaluator m(&ana, dmap); diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index bfce2c2..961c476 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -30,14 +30,14 @@ namespace arith { using namespace ir; Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const For* op) { +VisitStmt_(const ForNode* op) { analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprMutator::VisitStmt_(op); } Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const LetStmt* op) { +VisitStmt_(const LetStmtNode* op) { Expr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -57,7 +57,7 @@ VisitStmt_(const LetStmt* op) { } Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const IfThenElse* op) { +VisitStmt_(const IfThenElseNode* op) { Expr condition = this->VisitExpr(op->condition); Stmt then_case, else_case; { @@ -66,7 +66,7 @@ VisitStmt_(const IfThenElse* op) { } if (op->else_case.defined()) { With ctx(analyzer_, - analyzer_->rewrite_simplify(Not::make(condition))); + analyzer_->rewrite_simplify(NotNode::make(condition))); else_case = this->VisitStmt(op->else_case); } if (is_one(condition)) return then_case; @@ -74,7 +74,7 @@ VisitStmt_(const IfThenElse* op) { if (else_case.defined()) { return else_case; } - return Evaluate::make(0); + return EvaluateNode::make(0); } if (condition.same_as(op->condition) && @@ -91,7 +91,7 @@ VisitStmt_(const IfThenElse* op) { } Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AttrStmt* op) { +VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); @@ -106,7 +106,7 @@ VisitStmt_(const AttrStmt* op) { } Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AssertStmt* op) { +VisitStmt_(const AssertStmtNode* op) { Expr condition = this->VisitExpr(op->condition); Expr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); @@ -126,7 +126,7 @@ VisitStmt_(const AssertStmt* op) { } Expr IRMutatorWithAnalyzer:: -VisitExpr_(const Call* op) { +VisitExpr_(const CallNode* op) { // add condition context to if_then_else if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { Expr cond = this->VisitExpr(op->args[0]); @@ -137,7 +137,7 @@ VisitExpr_(const Call* op) { } { With constraint(analyzer_, - analyzer_->rewrite_simplify(Not::make(cond))); + analyzer_->rewrite_simplify(NotNode::make(cond))); false_value = this->VisitExpr(op->args[2]); } if (is_zero(cond)) { @@ -151,7 +151,7 @@ VisitExpr_(const Call* op) { false_value.same_as(op->args[2])) { return GetRef(op); } else { - return Call::make(op->dtype, op->name, + return CallNode::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } @@ -160,7 +160,7 @@ VisitExpr_(const Call* op) { } Expr IRMutatorWithAnalyzer:: -VisitExpr_(const Let* op) { +VisitExpr_(const LetNode* op) { Expr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -172,12 +172,12 @@ VisitExpr_(const Let* op) { body.same_as(op->body)) { return GetRef(op); } else { - return Let::make(op->var, value, body); + return LetNode::make(op->var, value, body); } } Expr IRMutatorWithAnalyzer:: -VisitExpr_(const Select* op) { +VisitExpr_(const SelectNode* op) { Expr cond = this->VisitExpr(op->condition); Expr true_value, false_value; { @@ -186,7 +186,7 @@ VisitExpr_(const Select* op) { } { With constraint(analyzer_, - analyzer_->rewrite_simplify(Not::make(cond))); + analyzer_->rewrite_simplify(NotNode::make(cond))); false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { @@ -201,12 +201,12 @@ VisitExpr_(const Select* op) { false_value.same_as(op->false_value)) { return GetRef(op); } else { - return Select::make(cond, true_value, false_value); + return SelectNode::make(cond, true_value, false_value); } } Expr IRMutatorWithAnalyzer:: -VisitExpr_(const Reduce* op) { +VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_->Bind(iv->var, iv->dom); diff --git a/src/arithmetic/ir_mutator_with_analyzer.h b/src/arithmetic/ir_mutator_with_analyzer.h index 9e3a86b..1e96c0a 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.h +++ b/src/arithmetic/ir_mutator_with_analyzer.h @@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator { using StmtExprMutator::VisitExpr_; // override functions that need to populate the context information. - Stmt VisitStmt_(const ir::For* op) override; - Stmt VisitStmt_(const ir::LetStmt* op) override; - Stmt VisitStmt_(const ir::IfThenElse* op) override; - Stmt VisitStmt_(const ir::AttrStmt* op) override; - Stmt VisitStmt_(const ir::AssertStmt* op) override; - Expr VisitExpr_(const ir::Let* op) override; - Expr VisitExpr_(const ir::Select* op) override; - Expr VisitExpr_(const ir::Call* op) override; - Expr VisitExpr_(const ir::Reduce* op) override; + Stmt VisitStmt_(const ir::ForNode* op) override; + Stmt VisitStmt_(const ir::LetStmtNode* op) override; + Stmt VisitStmt_(const ir::IfThenElseNode* op) override; + Stmt VisitStmt_(const ir::AttrStmtNode* op) override; + Stmt VisitStmt_(const ir::AssertStmtNode* op) override; + Expr VisitExpr_(const ir::LetNode* op) override; + Expr VisitExpr_(const ir::SelectNode* op) override; + Expr VisitExpr_(const ir::CallNode* op) override; + Expr VisitExpr_(const ir::ReduceNode* op) override; protected: /*! \brief internal analyzer field. */ diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h index b8750df..07ec186 100644 --- a/src/arithmetic/ir_visitor_with_analyzer.h +++ b/src/arithmetic/ir_visitor_with_analyzer.h @@ -38,13 +38,13 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor { return analyzer_.Simplify(expr); } - void VisitStmt_(const For* op) { + void VisitStmt_(const ForNode* op) { analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const AttrStmt* op) { + void VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); @@ -57,7 +57,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor { } } - void VisitExpr_(const Reduce* op) { + void VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_.Bind(iv->var, iv->dom); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index a83e987..8e2e065 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -124,15 +124,15 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const Cast* op) final { + Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } - Entry VisitExpr_(const IntImm* op) final { + Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } - Entry VisitExpr_(const UIntImm* op) final { + Entry VisitExpr_(const UIntImmNode* op) final { if (op->value < std::numeric_limits::max()) { return Entry(0, static_cast(op->value)); } else { @@ -140,21 +140,21 @@ class ModularSetAnalyzer::Impl : } } - Entry VisitExpr_(const Add* op) final { + Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); return Entry(coeff, a.base + b.base); } - Entry VisitExpr_(const Sub* op) final { + Entry VisitExpr_(const SubNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); return Entry(coeff, a.base - b.base); } - Entry VisitExpr_(const Mul* op) final { + Entry VisitExpr_(const MulNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); // Simplification rule, x, y, z are in Z @@ -188,7 +188,7 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const Div* op) final { + Entry VisitExpr_(const DivNode* op) final { Entry b = VisitExpr(op->b); if (b.is_const()) { return DivByConst(op->a, b.base, false); @@ -196,7 +196,7 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const FloorDiv* op) final { + Entry VisitExpr_(const FloorDivNode* op) final { Entry b = VisitExpr(op->b); if (b.is_const()) { return DivByConst(op->a, b.base, true); @@ -204,35 +204,35 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const Min* op) final { + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); return Union(a, b); } - Entry VisitExpr_(const Max* op) final { + Entry VisitExpr_(const MaxNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); return Union(a, b); } - Entry VisitExpr_(const Select* op) final { + Entry VisitExpr_(const SelectNode* op) final { Entry a = VisitExpr(op->true_value); Entry b = VisitExpr(op->false_value); return Union(a, b); } - Entry VisitExpr_(const Call* op) final { + Entry VisitExpr_(const CallNode* op) final { // only special handle >> which can be // used for index calculation. - if (op->is_intrinsic(Call::shift_right)) { + if (op->is_intrinsic(CallNode::shift_right)) { return VisitRightShift(op); } else { return Everything(); } } - Entry VisitExpr_(const Variable* op) final { + Entry VisitExpr_(const VarNode* op) final { Var v = GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { @@ -242,7 +242,7 @@ class ModularSetAnalyzer::Impl : } } - Entry VisitRightShift(const Call* op) { + Entry VisitRightShift(const CallNode* op) { Entry b = VisitExpr(op->args[1]); // a c x / c -> a x if (b.is_const()) { diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index bff9564..e964abb 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -283,7 +283,7 @@ class PConstWithTypeLike : void InitMatch_() const {} bool Match_(const ObjectRef& node) const { - if (const ir::IntImm* ptr = node.as()) { + if (const ir::IntImmNode* ptr = node.as()) { return ptr->value == value_; } else { return false; @@ -325,30 +325,30 @@ class PConstWithTypeLike : // raise ambiguity error for operator overload of / and % -TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a)); -TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a)); // arithmetic expressions -TVM_PATTERN_BINARY_OP(operator+, ir::Add); -TVM_PATTERN_BINARY_OP(operator-, ir::Sub); -TVM_PATTERN_BINARY_OP(operator*, ir::Mul); -TVM_PATTERN_BINARY_OP(min, ir::Min); -TVM_PATTERN_BINARY_OP(max, ir::Max); -TVM_PATTERN_BINARY_OP(div, ir::Div); -TVM_PATTERN_BINARY_OP(truncdiv, ir::Div); -TVM_PATTERN_BINARY_OP(truncmod, ir::Mod); -TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv); -TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod); +TVM_PATTERN_BINARY_OP(operator+, ir::AddNode); +TVM_PATTERN_BINARY_OP(operator-, ir::SubNode); +TVM_PATTERN_BINARY_OP(operator*, ir::MulNode); +TVM_PATTERN_BINARY_OP(min, ir::MinNode); +TVM_PATTERN_BINARY_OP(max, ir::MaxNode); +TVM_PATTERN_BINARY_OP(div, ir::DivNode); +TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode); +TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode); +TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode); +TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode); // logical expressions -TVM_PATTERN_BINARY_OP(operator>, ir::GT); -TVM_PATTERN_BINARY_OP(operator>=, ir::GE); -TVM_PATTERN_BINARY_OP(operator<, ir::LT); -TVM_PATTERN_BINARY_OP(operator<=, ir::LE); -TVM_PATTERN_BINARY_OP(operator==, ir::EQ); -TVM_PATTERN_BINARY_OP(operator!=, ir::NE); -TVM_PATTERN_BINARY_OP(operator&&, ir::And); -TVM_PATTERN_BINARY_OP(operator||, ir::Or); +TVM_PATTERN_BINARY_OP(operator>, ir::GTNode); +TVM_PATTERN_BINARY_OP(operator>=, ir::GENode); +TVM_PATTERN_BINARY_OP(operator<, ir::LTNode); +TVM_PATTERN_BINARY_OP(operator<=, ir::LENode); +TVM_PATTERN_BINARY_OP(operator==, ir::EQNode); +TVM_PATTERN_BINARY_OP(operator!=, ir::NENode); +TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode); +TVM_PATTERN_BINARY_OP(operator||, ir::OrNode); /*! * \brief Pattern not expression. @@ -365,7 +365,7 @@ class PNotExpr : public Pattern > { } bool Match_(const ObjectRef& node) const { - if (const ir::Not* ptr = node.as()) { + if (const ir::NotNode* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; } else { @@ -374,7 +374,7 @@ class PNotExpr : public Pattern > { } Expr Eval() const { - return ir::Not::make(value_.Eval()); + return ir::NotNode::make(value_.Eval()); } private: @@ -411,7 +411,7 @@ class PSelectExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Select* ptr = node.as()) { + if (const ir::SelectNode* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; if (!false_value_.Match_(ptr->false_value)) return false; @@ -422,7 +422,7 @@ class PSelectExpr : } Expr Eval() const { - return ir::Select::make( + return ir::SelectNode::make( condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } @@ -473,7 +473,7 @@ class PCastExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Cast* ptr = node.as()) { + if (const ir::CastNode* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; @@ -483,7 +483,7 @@ class PCastExpr : } Expr Eval() const { - return ir::Cast::make(dtype_.Eval(), value_.Eval()); + return ir::CastNode::make(dtype_.Eval(), value_.Eval()); } private: @@ -531,7 +531,7 @@ class PRampExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Ramp* ptr = node.as()) { + if (const ir::RampNode* ptr = node.as()) { if (!base_.Match_(ptr->base)) return false; if (!stride_.Match_(ptr->stride)) return false; if (!lanes_.Match_(ptr->lanes)) return false; @@ -542,7 +542,7 @@ class PRampExpr : } Expr Eval() const { - return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); + return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: @@ -593,7 +593,7 @@ class PBroadcastExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Broadcast* ptr = node.as()) { + if (const ir::BroadcastNode* ptr = node.as()) { if (!value_.Match_(ptr->value)) return false; if (!lanes_.Match_(ptr->lanes)) return false; return true; @@ -603,7 +603,7 @@ class PBroadcastExpr : } Expr Eval() const { - return ir::Broadcast::make(value_.Eval(), lanes_.Eval()); + return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); } private: @@ -662,10 +662,10 @@ struct PCallExprInitMatchFunctor { }; struct PCallExprMatchFunctor { - const ir::Call* call_; + const ir::CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const ir::Call* call) + explicit PCallExprMatchFunctor(const ir::CallNode* call) : call_(call) {} template @@ -705,7 +705,7 @@ class PCallExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Call* ptr = node.as()) { + if (const ir::CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (ptr->name != Op::kName) return false; detail::PCallExprMatchFunctor fmatch(ptr); @@ -727,18 +727,18 @@ class PCallExpr : }; // arithemetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static Expr Eval(Array args) { \ - return ir::Call::make(args[0].dtype(), kName, args, \ - ir::Call::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - return PCallExpr(a.derived(), b.derived()); \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static Expr Eval(Array args) { \ + return ir::CallNode::make(args[0].dtype(), kName, args, \ + ir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr \ + FuncName(const Pattern& a, const Pattern& b) { \ + return PCallExpr(a.derived(), b.derived()); \ } TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); @@ -748,18 +748,18 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static Expr Eval(Array args) { \ - return ir::Call::make(args[0].dtype(), kName, args, \ - ir::Call::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static Expr Eval(Array args) { \ + return ir::CallNode::make(args[0].dtype(), kName, args, \ + ir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr \ + FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); @@ -767,9 +767,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { static Expr Eval(Array args) { - return ir::Call::make( + return ir::CallNode::make( args[1].dtype(), kName, args, - ir::Call::PureIntrinsic); + ir::CallNode::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index f883bf1..2421e10 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -69,7 +69,7 @@ using namespace ir; RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: TryCompare(const Expr& x, int64_t val) { Expr diff = this->VisitExpr(x); - if (const auto* ptr = diff.as()) { + if (const auto* ptr = diff.as()) { if (ptr->value == val) { return kEQ; } else if (ptr->value > val) { @@ -116,10 +116,10 @@ Update(const Var& var, const Expr& info, bool override) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Add* op) { +VisitExpr_(const AddNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -231,10 +231,10 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& const } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Sub* op) { +VisitExpr_(const SubNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -430,10 +430,10 @@ VisitExpr_(const Sub* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Mul* op) { +VisitExpr_(const MulNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -469,10 +469,10 @@ VisitExpr_(const Mul* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Div* op) { +VisitExpr_(const DivNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as
(); - Expr const_res = TryConstFold
(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -482,7 +482,7 @@ VisitExpr_(const Div* op) { PVar lanes; // x / 2.0 = x * 0.5 - if (const FloatImm* ptr = op->b.as()) { + if (const FloatImmNode* ptr = op->b.as()) { CHECK(op->dtype.is_float()); return op->a * make_const(op->b.dtype(), 1.0 / ptr->value); } @@ -691,10 +691,10 @@ VisitExpr_(const Div* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Mod* op) { +VisitExpr_(const ModNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -781,10 +781,10 @@ VisitExpr_(const Mod* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const FloorDiv* op) { +VisitExpr_(const FloorDivNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -925,10 +925,10 @@ VisitExpr_(const FloorDiv* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const FloorMod* op) { +VisitExpr_(const FloorModNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -995,10 +995,10 @@ VisitExpr_(const FloorMod* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Min* op) { +VisitExpr_(const MinNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1180,10 +1180,10 @@ VisitExpr_(const Min* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Max* op) { +VisitExpr_(const MaxNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1353,10 +1353,10 @@ VisitExpr_(const Max* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const EQ* op) { +VisitExpr_(const EQNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1387,30 +1387,30 @@ VisitExpr_(const EQ* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const NE* op) { - return this->VisitExpr(Not::make(op->a == op->b)); +VisitExpr_(const NENode* op) { + return this->VisitExpr(NotNode::make(op->a == op->b)); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const LE* op) { - return this->VisitExpr(Not::make(op->b < op->a)); +VisitExpr_(const LENode* op) { + return this->VisitExpr(NotNode::make(op->b < op->a)); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const GT* op) { +VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const GE* op) { - return this->VisitExpr(Not::make(op->a < op->b)); +VisitExpr_(const GENode* op) { + return this->VisitExpr(NotNode::make(op->a < op->b)); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const LT* op) { +VisitExpr_(const LTNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1563,10 +1563,10 @@ VisitExpr_(const LT* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Not* op) { +VisitExpr_(const NotNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a); + op = ret.as(); + Expr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y; @@ -1588,10 +1588,10 @@ VisitExpr_(const Not* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const And* op) { +VisitExpr_(const AndNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1637,10 +1637,10 @@ VisitExpr_(const And* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Or* op) { +VisitExpr_(const OrNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1687,9 +1687,9 @@ VisitExpr_(const Or* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Select* op) { +VisitExpr_(const SelectNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as()) { +bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) { + if (const auto* rhs = other.as()) { return Equal(lhs->condition, rhs->condition) && Equal(lhs->true_value, rhs->true_value) && @@ -220,19 +220,19 @@ size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) { } } -size_t AttrsHashHandler::VisitAttr_(const IntImm* op) { +size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) { +size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) { +size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) { return std::hash()(op->value); } -size_t AttrsHashHandler::VisitAttr_(const StringImm* op) { +size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) { return std::hash()(op->value); } @@ -265,31 +265,31 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { return Combine(key, Combine(Hash(op->a), Hash(op->b))); \ } \ -TVM_DEFINE_ATTRS_BINOP_HASH(Add); -TVM_DEFINE_ATTRS_BINOP_HASH(Sub); -TVM_DEFINE_ATTRS_BINOP_HASH(Mul); -TVM_DEFINE_ATTRS_BINOP_HASH(Div); -TVM_DEFINE_ATTRS_BINOP_HASH(Mod); -TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv); -TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod); -TVM_DEFINE_ATTRS_BINOP_HASH(Max); -TVM_DEFINE_ATTRS_BINOP_HASH(Min); -TVM_DEFINE_ATTRS_BINOP_HASH(GE); -TVM_DEFINE_ATTRS_BINOP_HASH(GT); -TVM_DEFINE_ATTRS_BINOP_HASH(LE); -TVM_DEFINE_ATTRS_BINOP_HASH(LT); -TVM_DEFINE_ATTRS_BINOP_HASH(EQ); -TVM_DEFINE_ATTRS_BINOP_HASH(NE); -TVM_DEFINE_ATTRS_BINOP_HASH(And); -TVM_DEFINE_ATTRS_BINOP_HASH(Or); - -size_t AttrsHashHandler::VisitAttr_(const Not* op) { - static size_t key = std::hash()(Not::_type_key); +TVM_DEFINE_ATTRS_BINOP_HASH(AddNode); +TVM_DEFINE_ATTRS_BINOP_HASH(SubNode); +TVM_DEFINE_ATTRS_BINOP_HASH(MulNode); +TVM_DEFINE_ATTRS_BINOP_HASH(DivNode); +TVM_DEFINE_ATTRS_BINOP_HASH(ModNode); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode); +TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode); +TVM_DEFINE_ATTRS_BINOP_HASH(MinNode); +TVM_DEFINE_ATTRS_BINOP_HASH(GENode); +TVM_DEFINE_ATTRS_BINOP_HASH(GTNode); +TVM_DEFINE_ATTRS_BINOP_HASH(LENode); +TVM_DEFINE_ATTRS_BINOP_HASH(LTNode); +TVM_DEFINE_ATTRS_BINOP_HASH(EQNode); +TVM_DEFINE_ATTRS_BINOP_HASH(NENode); +TVM_DEFINE_ATTRS_BINOP_HASH(AndNode); +TVM_DEFINE_ATTRS_BINOP_HASH(OrNode); + +size_t AttrsHashHandler::VisitAttr_(const NotNode* op) { + static size_t key = std::hash()(NotNode::_type_key); return Combine(key, Hash(op->a)); } -size_t AttrsHashHandler::VisitAttr_(const Cast* op) { - static size_t key = std::hash()(Cast::_type_key); +size_t AttrsHashHandler::VisitAttr_(const CastNode* op) { + static size_t key = std::hash()(CastNode::_type_key); AttrsHash hasher; size_t res = key; res = Combine(res, hasher(op->dtype)); @@ -297,8 +297,8 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) { return res; } -size_t AttrsHashHandler::VisitAttr_(const Call* op) { - static size_t key = std::hash()(Call::_type_key); +size_t AttrsHashHandler::VisitAttr_(const CallNode* op) { + static size_t key = std::hash()(CallNode::_type_key); AttrsHash hasher; size_t res = key; res = Combine(res, hasher(op->name)); @@ -307,8 +307,8 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) { return res; } -size_t AttrsHashHandler::VisitAttr_(const Select* op) { - static size_t key = std::hash()(Select::_type_key); +size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) { + static size_t key = std::hash()(SelectNode::_type_key); size_t res = key; res = Combine(res, Hash(op->condition)); res = Combine(res, Hash(op->true_value)); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 22efa1d..d96033d 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -31,8 +31,8 @@ namespace tvm { // TODO(tqchen): change to floormod/div -using IndexMod = ir::FloorMod; -using IndexDiv = ir::FloorDiv; +using IndexMod = ir::FloorModNode; +using IndexDiv = ir::FloorDivNode; Array SimplifyArray(Array array) { for (size_t i = 0; i < array.size(); ++i) { @@ -65,7 +65,7 @@ inline std::vector ExprSplitAddition(const Expr &expr) { while (!split_buffer.empty()) { const Expr* top_ele = split_buffer.top(); split_buffer.pop(); - auto expr_add_match = top_ele->as(); + auto expr_add_match = top_ele->as(); if (expr_add_match) { split_buffer.push(&expr_add_match->b); split_buffer.push(&expr_add_match->a); @@ -88,13 +88,13 @@ inline std::pair MergeMulModInner(const Expr &mult_expr, const Expr &mod_l_expr, const Expr &mod_r_expr) { using namespace ir; - const Mul* mult_ptr = mult_expr.as(); + const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, Expr()); Expr mult_outer = mult_ptr->b; const Expr* inner = &(mult_ptr->a); // 1. Calculate the outer multiplier while (true) { - mult_ptr = inner->as(); + mult_ptr = inner->as(); if (mult_ptr) { inner = &(mult_ptr->a); mult_outer = mult_ptr->b * mult_outer; @@ -113,8 +113,8 @@ inline std::pair MergeMulModInner(const Expr &mult_expr, Expr no_opt_sum; // Sum of the exprs that cannot be optimized while (true) { auto inner_div_ptr = search_ptr->as(); - auto inner_mult_ptr = search_ptr->as(); - auto inner_add_ptr = search_ptr->as(); + auto inner_mult_ptr = search_ptr->as(); + auto inner_add_ptr = search_ptr->as(); if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) { return std::make_pair(false, Expr()); } else if (inner_div_ptr) { @@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, *has_mod = false; for (const Expr* ele : eles) { auto mod_ptr = ele->as(); - auto mult_ptr = ele->as(); + auto mult_ptr = ele->as(); if (mod_ptr) { *has_mod = true; mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b))); @@ -252,7 +252,7 @@ inline Expr ElemOffset(const BufferNode* n, Array index) { if (n->strides.size() == 0) { // Scalar case if (n->shape.size() == 0 && index.size() == 1) { - auto is_int = index[0].as(); + auto is_int = index[0].as(); CHECK(is_int && is_int->value == 0); base = base + index[0]; } else { @@ -285,7 +285,7 @@ inline Expr BufferOffset(const BufferNode* n, Array index, DataType dtype) offset = offset * make_const(offset.dtype(), dtype.lanes()); } if (dtype.lanes() != 1) { - return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + return ir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } else { return offset; } @@ -299,13 +299,13 @@ Expr Buffer::vload(Array begin, DataType dtype) const { << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { - return ir::Cast::make( + return ir::CastNode::make( DataType::Bool(), - ir::Load::make( + ir::LoadNode::make( DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), const_true())); } else { - return ir::Load::make( + return ir::LoadNode::make( dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } @@ -320,12 +320,12 @@ Stmt Buffer::vstore(Array begin, Expr value) const { << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return ir::Store::make(n->data, - ir::Cast::make(DataType::Int(8), value), + return ir::StoreNode::make(n->data, + ir::CastNode::make(DataType::Int(8), value), BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { - return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype), + return ir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } } @@ -391,7 +391,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E int highest_dim = 0; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { - extent = arith::ComputeReduce(self->shape, Expr()) - offset; + extent = arith::ComputeReduce(self->shape, Expr()) - offset; } Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { @@ -405,8 +405,8 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E Array acc_args{ e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return ir::Call::make( - ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic); + return ir::CallNode::make( + ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::CallNode::Intrinsic); } Buffer BufferNode::make(Var data, diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index c4a6b35..c30f344 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -72,7 +72,7 @@ Layout::Layout(const Array& axes) { node->axes = axes; std::ostringstream repr; for (const IterVar& axis : axes) { - if (const auto* factor = axis->dom->extent.as()) { + if (const auto* factor = axis->dom->extent.as()) { CHECK_GT(factor->value, 0); repr << factor->value; } @@ -186,7 +186,7 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { if (!this->defined()) return -1; for (const IterVar& itvar : operator->()->axes) { if (sub == LayoutAxis::Get(itvar)) { - const auto* factor = itvar->dom->extent.as(); + const auto* factor = itvar->dom->extent.as(); CHECK(factor); return factor->value; } @@ -251,7 +251,7 @@ inline Array TransformIndex(const Array& src_index, const Array& src_axis, const Array& transform_rule) { Array result; - std::unordered_map bind_map; + std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } @@ -287,18 +287,18 @@ inline Array TransformShape(const Array& src_shape, // for major-axis, bind the corresponding size // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 - std::unordered_map bind_map; + std::unordered_map bind_map; std::unordered_set symbolic_var_set; for (size_t i = 0; i < src_shape.size(); ++i) { Expr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; - if (orig_shape.as()) { + if (orig_shape.as()) { symbolic_var_set.insert(i); } if (!LayoutAxis::Get(orig_axis).IsPrimal()) { if (orig_shape.defined()) { - const auto* orig_shape_const = orig_shape.as(); - const auto* orig_axis_extent = orig_axis->dom->extent.as(); + const auto* orig_shape_const = orig_shape.as(); + const auto* orig_axis_extent = orig_axis->dom->extent.as(); if (orig_shape_const) { CHECK_EQ(orig_shape_const->value, orig_axis_extent->value) << "Input shape mismatch at index " << i << ". Expected " @@ -322,7 +322,7 @@ inline Array TransformShape(const Array& src_shape, result.push_back(axis->dom->extent); } else { if (symbolic_var_set.count(i)) { - result.push_back(ir::Any::make()); + result.push_back(ir::AnyNode::make()); } else { result.push_back(ir::Simplify(ir::Substitute(rule, bind_map))); } diff --git a/src/lang/expr.cc b/src/lang/expr.cc index eed6938..58a97ed 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -30,19 +30,19 @@ namespace tvm { Expr::Expr(int32_t value) - : Expr(IntImm::make(DataType::Int(32), value)) {} + : Expr(IntImmNode::make(DataType::Int(32), value)) {} Expr::Expr(float value) - : Expr(ir::FloatImm::make(DataType::Float(32), value)) {} + : Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {} Expr::Expr(std::string str) - : Expr(ir::StringImm::make(str)) {} + : Expr(ir::StringImmNode::make(str)) {} Var::Var(std::string name_hint, DataType t) - : Var(Variable::make(t, name_hint)) {} + : Var(VarNode::make(t, name_hint)) {} -Var Variable::make(DataType t, std::string name_hint) { - ObjectPtr node = make_object(); +Var VarNode::make(DataType t, std::string name_hint) { + ObjectPtr node = make_object(); node->dtype = t; node->name_hint = std::move(name_hint); return Var(node); @@ -54,10 +54,10 @@ Range::Range(Expr begin, Expr end) is_zero(begin) ? end : (end - begin))) { } -Integer IntImm::make(DataType t, int64_t value) { +Integer IntImmNode::make(DataType t, int64_t value) { CHECK(t.is_int() && t.is_scalar()) << "ValueError: IntImm can only take scalar."; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Integer(node); @@ -98,8 +98,8 @@ Var var(std::string name_hint, DataType t) { } TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); if (op->dtype == DataType::Int(32)) { p->stream << op->value; } else { diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 1166e7e..34fac72 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -32,7 +32,7 @@ namespace tvm { // simple cast that only checks if type matches and cast inline Expr SimpleCast(const DataType& t, Expr value) { if (value.dtype() == t) return value; - return ir::Cast::make(t, value); + return ir::CastNode::make(t, value); } // The public function with a quick checking path. @@ -41,9 +41,9 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); if (ltype.lanes() == 1 && rtype.lanes() != 1) { - lhs = ir::Broadcast::make(lhs, rtype.lanes()); + lhs = ir::BroadcastNode::make(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { - rhs = ir::Broadcast::make(rhs, ltype.lanes()); + rhs = ir::BroadcastNode::make(rhs, ltype.lanes()); } else { CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; @@ -85,27 +85,27 @@ Expr max_value(const DataType& dtype) { CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImm::make(dtype, std::numeric_limits::max()); + return IntImmNode::make(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { int64_t val = 1; val = (val << (dtype.bits() - 1)) - 1; - return IntImm::make(dtype, val); + return IntImmNode::make(dtype, val); } } else if (dtype.is_uint()) { if (dtype.bits() == 64) { - return UIntImm::make(dtype, std::numeric_limits::max()); + return UIntImmNode::make(dtype, std::numeric_limits::max()); } else if (dtype.bits() < 64) { uint64_t val = 1; val = (val << static_cast(dtype.bits())) - 1; - return UIntImm::make(dtype, val); + return UIntImmNode::make(dtype, val); } } else if (dtype.is_float()) { if (dtype.bits() == 64) { - return FloatImm::make(dtype, std::numeric_limits::max()); + return FloatImmNode::make(dtype, std::numeric_limits::max()); } else if (dtype.bits() == 32) { - return FloatImm::make(dtype, std::numeric_limits::max()); + return FloatImmNode::make(dtype, std::numeric_limits::max()); } else if (dtype.bits() == 16) { - return FloatImm::make(dtype, 65504.0); + return FloatImmNode::make(dtype, 65504.0); } } LOG(FATAL) << "Cannot decide max_value for type" << dtype; @@ -117,21 +117,21 @@ Expr min_value(const DataType& dtype) { CHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { - return IntImm::make(dtype, std::numeric_limits::lowest()); + return IntImmNode::make(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() < 64) { int64_t val = 1; val = -(val << (dtype.bits() - 1)); - return IntImm::make(dtype, val); + return IntImmNode::make(dtype, val); } } else if (dtype.is_uint()) { - return UIntImm::make(dtype, 0); + return UIntImmNode::make(dtype, 0); } else if (dtype.is_float()) { if (dtype.bits() == 64) { - return FloatImm::make(dtype, std::numeric_limits::lowest()); + return FloatImmNode::make(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() == 32) { - return FloatImm::make(dtype, std::numeric_limits::lowest()); + return FloatImmNode::make(dtype, std::numeric_limits::lowest()); } else if (dtype.bits() == 16) { - return FloatImm::make(dtype, -65504.0); + return FloatImmNode::make(dtype, -65504.0); } } LOG(FATAL) << "Cannot decide min_value for type" << dtype; @@ -153,9 +153,9 @@ inline bool ConstPowerHelper(ValueType val, int *shift) { } bool is_const_power_of_two_integer(const Expr& x, int* shift) { - if (const auto* op = x.as()) { + if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); - } else if (const auto* op = x.as()) { + } else if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); } else { return false; @@ -163,85 +163,86 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { } Expr cast(const DataType& t, Expr value) { - using ir::IntImm; - using ir::UIntImm; - using ir::FloatImm; + using ir::IntImmNode; + using ir::UIntImmNode; + using ir::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { - if (const IntImm* op = value.as()) { + if (const IntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const UIntImm* op = value.as()) { + } else if (const UIntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const FloatImm* op = value.as()) { + } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } - return ir::Cast::make(t, value); + return ir::CastNode::make(t, value); } else { if (value.dtype().lanes() == 1) { // manually unroll cast DataType vtype = t.element_of(); if (value.dtype() != vtype) { - if (const IntImm* op = value.as()) { + if (const IntImmNode* op = value.as()) { value = make_const(vtype, op->value); - } else if (const UIntImm* op = value.as()) { + } else if (const UIntImmNode* op = value.as()) { return make_const(t, op->value); - } else if (const FloatImm* op = value.as()) { + } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { - value = ir::Cast::make(vtype, value); + value = ir::CastNode::make(vtype, value); } } - return ir::Broadcast::make(value, t.lanes()); + return ir::BroadcastNode::make(value, t.lanes()); } else { CHECK(value.dtype().lanes() == t.lanes()); - return ir::Cast::make(t, value); + return ir::CastNode::make(t, value); } } } Expr reinterpret(const DataType& t, Expr value) { if (value.dtype() == t) return value; - return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic); } Expr operator+(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Add::make(a, b); + return ir::AddNode::make(a, b); } // negation Expr operator-(Expr a) { - using ir::IntImm; - using ir::FloatImm; - const IntImm* pa = a.as(); - const FloatImm* fa = a.as(); - if (pa) return ir::IntImm::make(a.dtype(), -pa->value); - if (fa) return ir::FloatImm::make(a.dtype(), -fa->value); + using ir::IntImmNode; + using ir::FloatImmNode; + const IntImmNode* pa = a.as(); + const FloatImmNode* fa = a.as(); + if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value); + if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value); return make_zero(a.dtype()) - a; } Expr operator-(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Sub::make(a, b); + return ir::SubNode::make(a, b); } Expr operator*(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Mul::make(a, b); + return ir::MulNode::make(a, b); } Expr div(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Div::make(a, b); + return ir::DivNode::make(a, b); } Expr truncdiv(Expr a, Expr b) { @@ -252,9 +253,9 @@ Expr truncdiv(Expr a, Expr b) { Expr truncmod(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Mod::make(a, b); + return ir::ModNode::make(a, b); } Expr operator/(Expr a, Expr b) { @@ -278,18 +279,18 @@ Expr floordiv(Expr a, Expr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::FloorDiv::make(a, b); + return ir::FloorDivNode::make(a, b); } Expr floormod(Expr a, Expr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::FloorMod::make(a, b); + return ir::FloorModNode::make(a, b); } Expr min(Expr a, Expr b) { @@ -301,9 +302,9 @@ Expr min(Expr a, Expr b) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Min::make(a, b); + return ir::MinNode::make(a, b); } Expr max(Expr a, Expr b) { @@ -315,184 +316,194 @@ Expr max(Expr a, Expr b) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Max::make(a, b); + return ir::MaxNode::make(a, b); } Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { - using ir::IntImm; - using ir::UIntImm; + using ir::IntImmNode; + using ir::UIntImmNode; CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value); - if (const UIntImm* op = cond.as()) { + if (const UIntImmNode* op = cond.as()) { if (op->value != 0) { return true_value; } else { return false_value; } - } else if (const IntImm* op = cond.as()) { + } else if (const IntImmNode* op = cond.as()) { if (op->value != 0) { return true_value; } else { return false_value; } } - return ir::Call::make( + return ir::CallNode::make( true_value.dtype(), ir::intrinsic::tvm_if_then_else, {cond, true_value, false_value}, - ir::Call::PureIntrinsic); + ir::CallNode::PureIntrinsic); } Expr likely(Expr cond) { if (is_const(cond)) return cond; - return ir::Call::make(cond.dtype(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic); + return ir::CallNode::make(cond.dtype(), + ir::CallNode::likely, + { cond }, + ir::CallNode::PureIntrinsic); } Expr operator>(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::GT::make(a, b); + return ir::GTNode::make(a, b); } Expr operator>=(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::GE::make(a, b); + return ir::GENode::make(a, b); } Expr operator<(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::LT::make(a, b); + return ir::LTNode::make(a, b); } Expr operator<=(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::LE::make(a, b); + return ir::LENode::make(a, b); } Expr operator==(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::EQ::make(a, b); + return ir::EQNode::make(a, b); } Expr operator!=(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::NE::make(a, b); + return ir::NENode::make(a, b); } Expr operator&&(Expr a, Expr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::And::make(a, b); + return ir::AndNode::make(a, b); } Expr operator||(Expr a, Expr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Or::make(a, b); + return ir::OrNode::make(a, b); } Expr operator!(Expr a) { CHECK(a.dtype().is_bool()); - Expr ret = arith::TryConstFold(a); + Expr ret = arith::TryConstFold(a); if (ret.defined()) return ret; - return ir::Not::make(a); + return ir::NotNode::make(a); } Expr operator>>(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); + if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value)); if (pb) { if (pb->value == 0) return a; } }); - return ir::Call::make(a.dtype(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic); } Expr operator<<(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); + if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value)); if (pb) { if (pb->value == 0) return a; } }); - return ir::Call::make(a.dtype(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + a.dtype(), ir::CallNode::shift_left, { a, b }, ir::CallNode::PureIntrinsic); } Expr operator&(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); + if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value)); }); - return ir::Call::make(a.dtype(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic); } Expr operator|(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); + if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value)); }); - return ir::Call::make(a.dtype(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic); } Expr operator^(Expr a, Expr b) { BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); + if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value)); }); - return ir::Call::make(a.dtype(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic); } Expr operator~(Expr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return ir::Call::make(a.dtype(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic); } Expr pow(Expr x, Expr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return ir::Call::make(x.dtype(), "pow", { x, y }, ir::Call::PureIntrinsic); + return ir::CallNode::make( + x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic); } Expr abs(Expr x) { if (x.dtype().is_int()) { - using ir::IntImm; - const IntImm* px = x.as(); + using ir::IntImmNode; + const IntImmNode* px = x.as(); if (px) { - return ir::IntImm::make(x.dtype(), std::abs(px->value)); + return ir::IntImmNode::make(x.dtype(), std::abs(px->value)); } - return ir::Select::make(x >= make_zero(x.dtype()), x, -x); + return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { - using ir::FloatImm; - const FloatImm* fx = x.as(); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); if (fx) { - return ir::FloatImm::make(x.dtype(), std::fabs(fx->value)); + return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value)); } - return ir::Call::make(x.dtype(), "fabs", {x}, ir::Call::PureIntrinsic); + return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { @@ -507,17 +518,17 @@ Expr isnan(Expr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); } else if (x.dtype().is_float()) { - using ir::FloatImm; - const FloatImm* fx = x.as(); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); if (fx) { return make_const(t, std::isnan(fx->value)); } if (x.dtype().bits() == 16) { - return ir::Call::make(t, ir::Call::isnan, + return ir::CallNode::make(t, ir::CallNode::isnan, {cast(DataType::Float(32, t.lanes()), std::move(x))}, - ir::Call::PureIntrinsic); + ir::CallNode::PureIntrinsic); } else { - return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic); + return ir::CallNode::make(t, ir::CallNode::isnan, {x}, ir::CallNode::PureIntrinsic); } } else { LOG(FATAL) << "Data type " << x.dtype() @@ -528,102 +539,102 @@ Expr isnan(Expr x) { Expr sum(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Add::make(x, y); + Expr result = ir::AddNode::make(x, y); Expr identity_element = make_zero(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr all(Expr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::And::make(x, y); + Expr result = ir::AndNode::make(x, y); Expr identity_element = make_const(source.dtype(), true); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr any(Expr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Or::make(x, y); + Expr result = ir::OrNode::make(x, y); Expr identity_element = make_const(source.dtype(), false); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr max(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Max::make(x, y); + Expr result = ir::MaxNode::make(x, y); Expr identity_element = min_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr min(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Min::make(x, y); + Expr result = ir::MinNode::make(x, y); Expr identity_element = max_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr prod(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Mul::make(x, y); + Expr result = ir::MulNode::make(x, y); Expr identity_element = make_const(source.dtype(), 1); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } Expr fmod(Expr x, Expr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return ir::Call::make(x.dtype(), "fmod", { x, y }, ir::Call::PureIntrinsic); + return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic); } Expr floor(Expr x) { - using ir::FloatImm; - const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.dtype(), std::floor(fx->value)); - return ir::Call::make(x.dtype(), "floor", {x}, ir::Call::PureIntrinsic); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); + if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value)); + return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic); } Expr ceil(Expr x) { - using ir::FloatImm; - const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.dtype(), std::ceil(fx->value)); - return ir::Call::make(x.dtype(), "ceil", {x}, ir::Call::PureIntrinsic); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); + if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value)); + return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic); } Expr round(Expr x) { - using ir::FloatImm; - const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value)); - return ir::Call::make(x.dtype(), "round", {x}, ir::Call::PureIntrinsic); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); + if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); + return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic); } Expr nearbyint(Expr x) { - using ir::FloatImm; - const FloatImm* fx = x.as(); - if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value)); - return ir::Call::make(x.dtype(), "nearbyint", {x}, ir::Call::PureIntrinsic); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); + if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value)); + return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic); } Expr trunc(Expr x) { - using ir::FloatImm; - const FloatImm* fx = x.as(); + using ir::FloatImmNode; + const FloatImmNode* fx = x.as(); if (fx) { - return FloatImm::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : + return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } - return ir::Call::make(x.dtype(), "trunc", {x}, ir::Call::PureIntrinsic); + return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic); } } // namespace tvm diff --git a/src/lang/ir.cc b/src/lang/ir.cc index de047f3..6b777cc 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -31,79 +31,79 @@ namespace tvm { namespace ir { // constructors -Expr UIntImm::make(DataType t, uint64_t value) { +Expr UIntImmNode::make(DataType t, uint64_t value) { CHECK(t.is_uint() && t.lanes() == 1) << "ValueError: UIntImm can only take scalar"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Expr(node); } -Expr FloatImm::make(DataType t, double value) { +Expr FloatImmNode::make(DataType t, double value) { CHECK_EQ(t.lanes(), 1) << "ValueError: FloatImm can only take scalar"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = t; node->value = value; return Expr(node); } -Expr StringImm::make(std::string value) { - ObjectPtr node = make_object(); +Expr StringImmNode::make(std::string value) { + ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); return Expr(node); } -Expr Cast::make(DataType t, Expr value) { +Expr CastNode::make(DataType t, Expr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); return Expr(node); } -Expr And::make(Expr a, Expr b) { +Expr AndNode::make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); } -Expr Or::make(Expr a, Expr b) { +Expr OrNode::make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); } -Expr Not::make(Expr a) { +Expr NotNode::make(Expr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); return Expr(node); } -Expr Select::make(Expr condition, Expr true_value, Expr false_value) { +Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; @@ -111,7 +111,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes()); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; - ObjectPtr(); + ObjectPtr node = make_object(); node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -119,14 +119,14 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { return Expr(node); } -Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { +Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { CHECK(buffer_var.defined()); CHECK(predicate.defined()); CHECK(index.defined()); CHECK_EQ(dtype.lanes(), index.dtype().lanes()); CHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = dtype; node->buffer_var = std::move(buffer_var); node->index = std::move(index); @@ -135,7 +135,7 @@ Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) { return Expr(node); } -Expr Ramp::make(Expr base, Expr stride, int lanes) { +Expr RampNode::make(Expr base, Expr stride, int lanes) { CHECK(base.defined()); CHECK(stride.defined()); CHECK(base.dtype().is_scalar()); @@ -143,7 +143,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { CHECK_GT(lanes, 1); CHECK_EQ(stride.dtype(), base.dtype()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = base.dtype().with_lanes(lanes); node->base = base; node->stride = stride; @@ -151,24 +151,24 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { return Expr(node); } -Expr Broadcast::make(Expr value, int lanes) { +Expr BroadcastNode::make(Expr value, int lanes) { CHECK(value.defined()); CHECK(value.dtype().is_scalar()); CHECK_GT(lanes, 1); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; return Expr(node); } -Expr Let::make(Var var, Expr value, Expr body) { +Expr LetNode::make(Var var, Expr value, Expr body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = body.dtype(); node->var = std::move(var); node->value = std::move(value); @@ -176,23 +176,23 @@ Expr Let::make(Var var, Expr value, Expr body) { return Expr(node); } -const char* Call::vectorizable_intrinsics[] = { +const char* CallNode::vectorizable_intrinsics[] = { "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", - "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right, - ir::Call::likely, ir::Call::popcount + "log", "sin", "cos", "pow", ir::CallNode::shift_left, ir::CallNode::shift_right, + ir::CallNode::likely, ir::CallNode::popcount }; -bool Call::is_vectorizable() const { - size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*); +bool CallNode::is_vectorizable() const { + size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); for (size_t i = 0; i < cnt; ++i) { - if (name == Call::vectorizable_intrinsics[i]) { + if (name == CallNode::vectorizable_intrinsics[i]) { return true; } } return false; } -Expr Call::make(DataType dtype, +Expr CallNode::make(DataType dtype, std::string name, Array args, CallType call_type, @@ -208,7 +208,7 @@ Expr Call::make(DataType dtype, } } - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = dtype; node->name = std::move(name); node->args = std::move(args); @@ -218,7 +218,7 @@ Expr Call::make(DataType dtype, return Expr(node); } -Expr Shuffle::make(Array vectors, +Expr ShuffleNode::make(Array vectors, Array indices) { CHECK_NE(vectors.size(), 0U); CHECK_NE(indices.size(), 0U); @@ -232,14 +232,14 @@ Expr Shuffle::make(Array vectors, } CHECK_LE(indices.size(), static_cast(total_lanes)); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); return Expr(node); } -Expr Shuffle::make_concat(Array vectors) { +Expr ShuffleNode::make_concat(Array vectors) { CHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; @@ -248,13 +248,13 @@ Expr Shuffle::make_concat(Array vectors) { int index = 0; for (const Expr& e : vectors) { for (int i = 0; i < e.dtype().lanes(); ++i) { - indices.push_back(IntImm::make(DataType::Int(32), index++)); + indices.push_back(IntImmNode::make(DataType::Int(32), index++)); } } return make(vectors, indices); } -Expr Shuffle::make_extract_element(Expr vector, int index) { +Expr ShuffleNode::make_extract_element(Expr vector, int index) { return make({vector}, {Integer(index)}); } @@ -284,7 +284,7 @@ Array CommReducerNode::operator()(Array a, Array b) const { }); } -Expr Reduce::make(CommReducer combiner, Array source, +Expr ReduceNode::make(CommReducer combiner, Array source, Array axis, Expr condition, int value_index) { for (size_t i = 0; i < axis.size(); ++i) { CHECK_EQ(axis[i]->iter_type, kCommReduce) @@ -293,7 +293,7 @@ Expr Reduce::make(CommReducer combiner, Array source, if (!condition.defined()) { condition = const_true(); } - auto n = make_object(); + auto n = make_object(); CHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); @@ -307,28 +307,28 @@ Expr Reduce::make(CommReducer combiner, Array source, return Expr(n); } -Expr Any::make() { - auto n = make_object(); +Expr AnyNode::make() { + auto n = make_object(); return Expr(n); } -Stmt LetStmt::make(Var var, Expr value, Stmt body) { +Stmt LetStmtNode::make(Var var, Expr value, Stmt body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); return Stmt(node); } -Stmt AttrStmt::make(ObjectRef node, +Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, Expr value, Stmt body) { - auto n = make_object(); + auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); @@ -336,31 +336,31 @@ Stmt AttrStmt::make(ObjectRef node, return Stmt(n); } -Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) { +Stmt AssertStmtNode::make(Expr condition, Expr message, Stmt body) { CHECK(condition.defined()); CHECK(message.dtype() == DataType::Int(32) || - message.as()) + message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); return Stmt(node); } -Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) { +Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) { CHECK(body.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->func = std::move(func); node->is_producer = is_producer; node->body = std::move(body); return Stmt(node); } -Stmt For::make(Var loop_var, +Stmt ForNode::make(Var loop_var, Expr min, Expr extent, ForType for_type, @@ -373,7 +373,7 @@ Stmt For::make(Var loop_var, CHECK(loop_var.dtype().is_scalar()); CHECK(body.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); @@ -383,14 +383,14 @@ Stmt For::make(Var loop_var, return Stmt(node); } -Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) { +Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) { CHECK(value.defined()); CHECK(index.defined()); CHECK(predicate.defined()); CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->value = std::move(value); node->index = std::move(index); @@ -398,7 +398,7 @@ Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) { return Stmt(node); } -Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array args) { +Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array args) { CHECK(value_index >=0 && value_index < func->num_outputs()) << "value index output function return value bound"; CHECK(value.defined()) << "Provide of undefined value\n"; @@ -407,7 +407,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array ar CHECK(args[i].defined()) << "Provide to undefined location\n"; } - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->value = std::move(value); @@ -415,7 +415,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array ar return Stmt(node); } -Stmt Allocate::make(Var buffer_var, +Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array extents, Expr condition, @@ -430,7 +430,7 @@ Stmt Allocate::make(Var buffer_var, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; node->extents = std::move(extents); @@ -441,10 +441,10 @@ Stmt Allocate::make(Var buffer_var, return Stmt(node); } -int32_t Allocate::constant_allocation_size(const Array& extents) { +int32_t AllocateNode::constant_allocation_size(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImm *int_size = extents[i].as()) { + if (const IntImmNode *int_size = extents[i].as()) { result *= int_size->value; if (result > std::numeric_limits::max()) { return 0; @@ -456,13 +456,13 @@ int32_t Allocate::constant_allocation_size(const Array& extents) { return static_cast(result); } -Stmt Free::make(Var buffer_var) { - ObjectPtr node = make_object(); +Stmt FreeNode::make(Var buffer_var) { + ObjectPtr node = make_object(); node->buffer_var = buffer_var; return Stmt(node); } -Stmt Realize::make(FunctionRef func, +Stmt RealizeNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds, @@ -478,7 +478,7 @@ Stmt Realize::make(FunctionRef func, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->dtype = dtype; @@ -488,7 +488,7 @@ Stmt Realize::make(FunctionRef func, return Stmt(node); } -Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { +Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); CHECK(bounds[i]->extent.defined()); @@ -496,7 +496,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo CHECK(bounds[i]->extent.dtype().is_scalar()); } - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->func = std::move(func); node->value_index = value_index; node->dtype = dtype; @@ -510,36 +510,36 @@ SeqStmt::SeqStmt(Array seq) { data_ = std::move(node); } -Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { +Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) { CHECK(condition.defined()); CHECK(then_case.defined()); // else_case may be null. - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); return Stmt(node); } -Stmt Evaluate::make(Expr value) { +Stmt EvaluateNode::make(Expr value) { CHECK(value.defined()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->value = std::move(value); return Stmt(node); } // Printers TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "(" << op->dtype << ")" << op->value; }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); auto& stream = p->stream; switch (op->dtype.bits()) { case 64: @@ -557,8 +557,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); auto& stream = p->stream; stream << '"'; for (size_t i = 0; i < op->value.size(); ++i) { @@ -593,116 +593,116 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << op->dtype << '('; p->Print(op->value); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); // omit the type // stream << op->name << "." << op->type; p->stream << op->name_hint; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " + "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " - "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << "*"; p->Print(op->b); p->stream << ')'; }) -.set_dispatch
([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << "/"; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " % "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "min("; p->Print(op->a); p->stream << ", "; p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "max("; p->Print(op->a); p->stream << ", "; p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " == "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " != "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " < "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " <= "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " > "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " >= "; @@ -711,20 +711,20 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "floordiv(" << op->a << ", " << op->b << ")"; }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " && "; @@ -733,8 +733,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " || "; @@ -743,15 +743,15 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '!'; p->Print(op->a); }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch(); + void VisitExpr_(const SelectNode *op, const Expr& other) final { + const SelectNode* rhs = other.as(); if (CompareExpr(op->condition, rhs->condition) != 0) return; if (CompareExpr(op->true_value, rhs->true_value) != 0) return; if (CompareExpr(op->false_value, rhs->false_value) != 0) return; } - void VisitExpr_(const Ramp *op, const Expr& other) final { - const Ramp* rhs = other.as(); + void VisitExpr_(const RampNode *op, const Expr& other) final { + const RampNode* rhs = other.as(); if (CompareExpr(op->base, rhs->base) != 0) return; if (CompareExpr(op->stride, rhs->stride) != 0) return; if (CompareValue(op->lanes, rhs->lanes) != 0) return; } - void VisitExpr_(const Broadcast *op, const Expr& other) final { - const Broadcast* rhs = other.as(); + void VisitExpr_(const BroadcastNode *op, const Expr& other) final { + const BroadcastNode* rhs = other.as(); if (CompareExpr(op->value, rhs->value) != 0) return; if (CompareValue(op->lanes, rhs->lanes) != 0) return; } - void VisitExpr_(const Shuffle *op, const Expr& other) final { - const Shuffle* rhs = other.as(); + void VisitExpr_(const ShuffleNode *op, const Expr& other) final { + const ShuffleNode* rhs = other.as(); if (CompareArray(op->vectors, rhs->vectors) != 0) return; if (CompareArray(op->indices, rhs->indices) != 0) return; } - DEFINE_BIOP_EXPR_CMP_(Add) - DEFINE_BIOP_EXPR_CMP_(Sub) - DEFINE_BIOP_EXPR_CMP_(Mul) - DEFINE_BIOP_EXPR_CMP_(Div) - DEFINE_BIOP_EXPR_CMP_(Mod) - DEFINE_BIOP_EXPR_CMP_(FloorDiv) - DEFINE_BIOP_EXPR_CMP_(FloorMod) - DEFINE_BIOP_EXPR_CMP_(Min) - DEFINE_BIOP_EXPR_CMP_(Max) - DEFINE_BIOP_EXPR_CMP_(EQ) - DEFINE_BIOP_EXPR_CMP_(NE) - DEFINE_BIOP_EXPR_CMP_(LT) - DEFINE_BIOP_EXPR_CMP_(LE) - DEFINE_BIOP_EXPR_CMP_(GT) - DEFINE_BIOP_EXPR_CMP_(GE) - DEFINE_BIOP_EXPR_CMP_(And) - DEFINE_BIOP_EXPR_CMP_(Or) + DEFINE_BIOP_EXPR_CMP_(AddNode) + DEFINE_BIOP_EXPR_CMP_(SubNode) + DEFINE_BIOP_EXPR_CMP_(MulNode) + DEFINE_BIOP_EXPR_CMP_(DivNode) + DEFINE_BIOP_EXPR_CMP_(ModNode) + DEFINE_BIOP_EXPR_CMP_(FloorDivNode) + DEFINE_BIOP_EXPR_CMP_(FloorModNode) + DEFINE_BIOP_EXPR_CMP_(MinNode) + DEFINE_BIOP_EXPR_CMP_(MaxNode) + DEFINE_BIOP_EXPR_CMP_(EQNode) + DEFINE_BIOP_EXPR_CMP_(NENode) + DEFINE_BIOP_EXPR_CMP_(LTNode) + DEFINE_BIOP_EXPR_CMP_(LENode) + DEFINE_BIOP_EXPR_CMP_(GTNode) + DEFINE_BIOP_EXPR_CMP_(GENode) + DEFINE_BIOP_EXPR_CMP_(AndNode) + DEFINE_BIOP_EXPR_CMP_(OrNode) private: int CompareExpr(const Expr& lhs, const Expr& rhs) { @@ -430,7 +430,7 @@ class IRDeepCompare : // Only equality/non-equality information is valid. bool tie_def_{false}; // varaible remap if any - std::unordered_map vmap_; + std::unordered_map vmap_; }; diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index dddf90e..b7a7362 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -123,7 +123,7 @@ Stmt IRTransform(Stmt ir_node, const Array& only_enable) { std::unordered_set only_type_index; for (Expr s : only_enable) { - only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); + only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str())); } IRTransformer transform(f_preorder, f_postorder, only_type_index); return transform(std::move(ir_node)); @@ -137,23 +137,23 @@ inline void VisitArray(const Array& arr, F fvisit) { } } -void StmtVisitor::VisitStmt_(const LetStmt* op) { +void StmtVisitor::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const AttrStmt* op) { +void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const For* op) { +void StmtVisitor::VisitStmt_(const ForNode* op) { this->VisitExpr(op->min); this->VisitExpr(op->extent); this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const Allocate* op) { +void StmtVisitor::VisitStmt_(const AllocateNode* op) { VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); }); this->VisitStmt(op->body); this->VisitExpr(op->condition); @@ -162,13 +162,13 @@ void StmtVisitor::VisitStmt_(const Allocate* op) { } } -void StmtVisitor::VisitStmt_(const Store* op) { +void StmtVisitor::VisitStmt_(const StoreNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->index); this->VisitExpr(op->predicate); } -void StmtVisitor::VisitStmt_(const IfThenElse* op) { +void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); if (op->else_case.defined()) { @@ -176,24 +176,24 @@ void StmtVisitor::VisitStmt_(const IfThenElse* op) { } } -void StmtVisitor::VisitStmt_(const Free* op) {} +void StmtVisitor::VisitStmt_(const FreeNode* op) {} -void StmtVisitor::VisitStmt_(const AssertStmt* op) { +void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitExpr(op->condition); this->VisitExpr(op->message); this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProducerConsumer* op) { +void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const Provide* op) { +void StmtVisitor::VisitStmt_(const ProvideNode* op) { VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); }); this->VisitExpr(op->value); } -void StmtVisitor::VisitStmt_(const Realize* op) { +void StmtVisitor::VisitStmt_(const RealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); this->VisitExpr(r->extent); @@ -202,7 +202,7 @@ void StmtVisitor::VisitStmt_(const Realize* op) { this->VisitExpr(op->condition); } -void StmtVisitor::VisitStmt_(const Prefetch* op) { +void StmtVisitor::VisitStmt_(const PrefetchNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); this->VisitExpr(r->extent); @@ -215,23 +215,23 @@ void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { }); } -void StmtVisitor::VisitStmt_(const Evaluate* op) { +void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const Variable* op) {} +void ExprVisitor::VisitExpr_(const VarNode* op) {} -void ExprVisitor::VisitExpr_(const Load* op) { +void ExprVisitor::VisitExpr_(const LoadNode* op) { this->VisitExpr(op->index); this->VisitExpr(op->predicate); } -void ExprVisitor::VisitExpr_(const Let* op) { +void ExprVisitor::VisitExpr_(const LetNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->body); } -void ExprVisitor::VisitExpr_(const Call* op) { +void ExprVisitor::VisitExpr_(const CallNode* op) { VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); }); } @@ -241,30 +241,30 @@ void ExprVisitor::VisitExpr_(const Call* op) { this->VisitExpr(op->b); \ } -DEFINE_BINOP_VISIT_(Add); -DEFINE_BINOP_VISIT_(Sub); -DEFINE_BINOP_VISIT_(Mul); -DEFINE_BINOP_VISIT_(Div); -DEFINE_BINOP_VISIT_(Mod); -DEFINE_BINOP_VISIT_(FloorDiv); -DEFINE_BINOP_VISIT_(FloorMod); -DEFINE_BINOP_VISIT_(Min); -DEFINE_BINOP_VISIT_(Max); -DEFINE_BINOP_VISIT_(EQ); -DEFINE_BINOP_VISIT_(NE); -DEFINE_BINOP_VISIT_(LT); -DEFINE_BINOP_VISIT_(LE); -DEFINE_BINOP_VISIT_(GT); -DEFINE_BINOP_VISIT_(GE); -DEFINE_BINOP_VISIT_(And); -DEFINE_BINOP_VISIT_(Or); - -void ExprVisitor::VisitExpr_(const IntImm* op) {} -void ExprVisitor::VisitExpr_(const UIntImm* op) {} -void ExprVisitor::VisitExpr_(const FloatImm* op) {} -void ExprVisitor::VisitExpr_(const StringImm* op) {} - -void ExprVisitor::VisitExpr_(const Reduce* op) { +DEFINE_BINOP_VISIT_(AddNode); +DEFINE_BINOP_VISIT_(SubNode); +DEFINE_BINOP_VISIT_(MulNode); +DEFINE_BINOP_VISIT_(DivNode); +DEFINE_BINOP_VISIT_(ModNode); +DEFINE_BINOP_VISIT_(FloorDivNode); +DEFINE_BINOP_VISIT_(FloorModNode); +DEFINE_BINOP_VISIT_(MinNode); +DEFINE_BINOP_VISIT_(MaxNode); +DEFINE_BINOP_VISIT_(EQNode); +DEFINE_BINOP_VISIT_(NENode); +DEFINE_BINOP_VISIT_(LTNode); +DEFINE_BINOP_VISIT_(LENode); +DEFINE_BINOP_VISIT_(GTNode); +DEFINE_BINOP_VISIT_(GENode); +DEFINE_BINOP_VISIT_(AndNode); +DEFINE_BINOP_VISIT_(OrNode); + +void ExprVisitor::VisitExpr_(const IntImmNode* op) {} +void ExprVisitor::VisitExpr_(const UIntImmNode* op) {} +void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} +void ExprVisitor::VisitExpr_(const StringImmNode* op) {} + +void ExprVisitor::VisitExpr_(const ReduceNode* op) { VisitArray(op->axis, [this](const IterVar& r) { this->VisitExpr(r->dom->min); this->VisitExpr(r->dom->extent); @@ -273,31 +273,31 @@ void ExprVisitor::VisitExpr_(const Reduce* op) { this->VisitExpr(op->condition); } -void ExprVisitor::VisitExpr_(const Cast* op) { +void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const Not* op) { +void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } -void ExprVisitor::VisitExpr_(const Select* op) { +void ExprVisitor::VisitExpr_(const SelectNode* op) { this->VisitExpr(op->condition); this->VisitExpr(op->true_value); this->VisitExpr(op->false_value); } -void ExprVisitor::VisitExpr_(const Ramp* op) { +void ExprVisitor::VisitExpr_(const RampNode* op) { this->VisitExpr(op->base); this->VisitExpr(op->stride); } -void ExprVisitor::VisitExpr_(const Shuffle* op) { +void ExprVisitor::VisitExpr_(const ShuffleNode* op) { VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); }); VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); }); } -void ExprVisitor::VisitExpr_(const Broadcast* op) { +void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } @@ -344,7 +344,7 @@ class StmtMutator::Internal { } }; -Stmt StmtMutator::VisitStmt_(const AttrStmt* op) { +Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { Expr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && @@ -358,7 +358,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmt* op) { } } -Stmt StmtMutator::VisitStmt_(const LetStmt* op) { +Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { Expr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && @@ -372,7 +372,7 @@ Stmt StmtMutator::VisitStmt_(const LetStmt* op) { } } -Stmt StmtMutator::VisitStmt_(const For* op) { +Stmt StmtMutator::VisitStmt_(const ForNode* op) { Expr min = this->VisitExpr(op->min); Expr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); @@ -389,7 +389,7 @@ Stmt StmtMutator::VisitStmt_(const For* op) { } } -Stmt StmtMutator::VisitStmt_(const Allocate* op) { +Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Array extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); Expr condition = this->VisitExpr(op->condition); @@ -412,7 +412,7 @@ Stmt StmtMutator::VisitStmt_(const Allocate* op) { } } -Stmt StmtMutator::VisitStmt_(const IfThenElse* op) { +Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { Expr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); Stmt else_case; @@ -432,7 +432,7 @@ Stmt StmtMutator::VisitStmt_(const IfThenElse* op) { } } -Stmt StmtMutator::VisitStmt_(const Store* op) { +Stmt StmtMutator::VisitStmt_(const StoreNode* op) { Expr value = this->VisitExpr(op->value); Expr index = this->VisitExpr(op->index); Expr predicate = this->VisitExpr(op->predicate); @@ -449,7 +449,7 @@ Stmt StmtMutator::VisitStmt_(const Store* op) { } } -Stmt StmtMutator::VisitStmt_(const Provide* op) { +Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Array args = Internal::Mutate(this, op->args); Expr value = this->VisitExpr(op->value); if (args.same_as(op->args) && @@ -463,7 +463,7 @@ Stmt StmtMutator::VisitStmt_(const Provide* op) { } } -Stmt StmtMutator::VisitStmt_(const Realize* op) { +Stmt StmtMutator::VisitStmt_(const RealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); Expr condition = this->VisitExpr(op->condition); @@ -480,7 +480,7 @@ Stmt StmtMutator::VisitStmt_(const Realize* op) { } } -Stmt StmtMutator::VisitStmt_(const Prefetch* op) { +Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { Region bounds = Internal::Mutate(this, op->bounds); if (bounds.same_as(op->bounds)) { return GetRef(op); @@ -548,7 +548,7 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, } } -Stmt StmtMutator::VisitStmt_(const AssertStmt* op) { +Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { Expr condition = this->VisitExpr(op->condition); Expr message = this->VisitExpr(op->message); Stmt body = this->VisitStmt(op->body); @@ -566,7 +566,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmt* op) { } } -Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) { +Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) { Stmt body = this->VisitStmt(op->body); if (body.same_as(op->body)) { return GetRef(op); @@ -577,7 +577,7 @@ Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) { } } -Stmt StmtMutator::VisitStmt_(const Evaluate* op) { +Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); @@ -588,44 +588,44 @@ Stmt StmtMutator::VisitStmt_(const Evaluate* op) { } } -Stmt StmtMutator::VisitStmt_(const Free* op) { +Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const Variable* op) { +Expr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const Load* op) { +Expr ExprMutator::VisitExpr_(const LoadNode* op) { Expr index = this->VisitExpr(op->index); Expr predicate = this->VisitExpr(op->predicate); if (index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { - return Load::make(op->dtype, op->buffer_var, index, predicate); + return LoadNode::make(op->dtype, op->buffer_var, index, predicate); } } -Expr ExprMutator::VisitExpr_(const Let* op) { +Expr ExprMutator::VisitExpr_(const LetNode* op) { Expr value = this->VisitExpr(op->value); Expr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return Let::make(op->var, value, body); + return LetNode::make(op->var, value, body); } } -Expr ExprMutator::VisitExpr_(const Call* op) { +Expr ExprMutator::VisitExpr_(const CallNode* op) { auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); }; Array args = MutateArray(op->args, fmutate); if (args.same_as(op->args)) { return GetRef(op); } else { - return Call::make(op->dtype, + return CallNode::make(op->dtype, op->name, args, op->call_type, @@ -639,10 +639,10 @@ Expr ExprMutator::VisitExpr_(const Call* op) { return GetRef(op); \ } -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) #define DEFINE_BIOP_EXPR_MUTATE_(OP) \ Expr ExprMutator::VisitExpr_(const OP* op) { \ @@ -656,25 +656,25 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) } \ } -DEFINE_BIOP_EXPR_MUTATE_(Add); -DEFINE_BIOP_EXPR_MUTATE_(Sub); -DEFINE_BIOP_EXPR_MUTATE_(Mul); -DEFINE_BIOP_EXPR_MUTATE_(Div); -DEFINE_BIOP_EXPR_MUTATE_(Mod); -DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); -DEFINE_BIOP_EXPR_MUTATE_(FloorMod); -DEFINE_BIOP_EXPR_MUTATE_(Min); -DEFINE_BIOP_EXPR_MUTATE_(Max); -DEFINE_BIOP_EXPR_MUTATE_(EQ); -DEFINE_BIOP_EXPR_MUTATE_(NE); -DEFINE_BIOP_EXPR_MUTATE_(LT); -DEFINE_BIOP_EXPR_MUTATE_(LE); -DEFINE_BIOP_EXPR_MUTATE_(GT); -DEFINE_BIOP_EXPR_MUTATE_(GE); -DEFINE_BIOP_EXPR_MUTATE_(And); -DEFINE_BIOP_EXPR_MUTATE_(Or); - -Expr ExprMutator::VisitExpr_(const Reduce* op) { +DEFINE_BIOP_EXPR_MUTATE_(AddNode); +DEFINE_BIOP_EXPR_MUTATE_(SubNode); +DEFINE_BIOP_EXPR_MUTATE_(MulNode); +DEFINE_BIOP_EXPR_MUTATE_(DivNode); +DEFINE_BIOP_EXPR_MUTATE_(ModNode); +DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); +DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); +DEFINE_BIOP_EXPR_MUTATE_(MinNode); +DEFINE_BIOP_EXPR_MUTATE_(MaxNode); +DEFINE_BIOP_EXPR_MUTATE_(EQNode); +DEFINE_BIOP_EXPR_MUTATE_(NENode); +DEFINE_BIOP_EXPR_MUTATE_(LTNode); +DEFINE_BIOP_EXPR_MUTATE_(LENode); +DEFINE_BIOP_EXPR_MUTATE_(GTNode); +DEFINE_BIOP_EXPR_MUTATE_(GENode); +DEFINE_BIOP_EXPR_MUTATE_(AndNode); +DEFINE_BIOP_EXPR_MUTATE_(OrNode); + +Expr ExprMutator::VisitExpr_(const ReduceNode* op) { auto fitervar = [this](const IterVar& v) { Range r = v->dom; Expr min = this->VisitExpr(r->min); @@ -700,30 +700,30 @@ Expr ExprMutator::VisitExpr_(const Reduce* op) { condition.same_as(op->condition)) { return GetRef(op); } else { - return Reduce::make( + return ReduceNode::make( op->combiner, source, axis, condition, op->value_index); } } -Expr ExprMutator::VisitExpr_(const Cast* op) { +Expr ExprMutator::VisitExpr_(const CastNode* op) { Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Cast::make(op->dtype, value); + return CastNode::make(op->dtype, value); } } -Expr ExprMutator::VisitExpr_(const Not* op) { +Expr ExprMutator::VisitExpr_(const NotNode* op) { Expr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { return GetRef(op); } else { - return Not::make(a); + return NotNode::make(a); } } -Expr ExprMutator::VisitExpr_(const Select* op) { +Expr ExprMutator::VisitExpr_(const SelectNode* op) { Expr condition = this->VisitExpr(op->condition); Expr true_value = this->VisitExpr(op->true_value); Expr false_value = this->VisitExpr(op->false_value); @@ -732,37 +732,37 @@ Expr ExprMutator::VisitExpr_(const Select* op) { false_value.same_as(op->false_value)) { return GetRef(op); } else { - return Select::make(condition, true_value, false_value); + return SelectNode::make(condition, true_value, false_value); } } -Expr ExprMutator::VisitExpr_(const Ramp* op) { +Expr ExprMutator::VisitExpr_(const RampNode* op) { Expr base = this->VisitExpr(op->base); Expr stride = this->VisitExpr(op->stride); if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { - return Ramp::make(base, stride, op->lanes); + return RampNode::make(base, stride, op->lanes); } } -Expr ExprMutator::VisitExpr_(const Broadcast* op) { +Expr ExprMutator::VisitExpr_(const BroadcastNode* op) { Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Broadcast::make(value, op->lanes); + return BroadcastNode::make(value, op->lanes); } } -Expr ExprMutator::VisitExpr_(const Shuffle* op) { +Expr ExprMutator::VisitExpr_(const ShuffleNode* op) { auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); }; auto vectors = MutateArray(op->vectors, fexpr); if (vectors.same_as(op->vectors)) { return GetRef(op); } else { - return Shuffle::make(vectors, op->indices); + return ShuffleNode::make(vectors, op->indices); } } diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index 8956a4d..8ecfbff 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -30,23 +30,23 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { // use reverse iteration for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { Stmt s = *ri; - if (const auto* for_ = s.as()) { - auto n = make_object(*for_); + if (const auto* for_ = s.as()) { + auto n = make_object(*for_); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* let = s.as()) { - auto n = make_object(*let); + } else if (const auto* let = s.as()) { + auto n = make_object(*let); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* attr = s.as()) { - auto n = make_object(*attr); + } else if (const auto* attr = s.as()) { + auto n = make_object(*attr); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* ite = s.as()) { - auto n = make_object(*ite); + } else if (const auto* ite = s.as()) { + auto n = make_object(*ite); CHECK(is_no_op(n->then_case)); CHECK(!n->else_case.defined()); n->then_case = body; @@ -56,13 +56,13 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); n->seq.Set(n->size() - 1, body); body = Stmt(n); - } else if (const auto* assert_ = s.as()) { - auto n = make_object(*assert_); + } else if (const auto* assert_ = s.as()) { + auto n = make_object(*assert_); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); + } else if (const auto* alloc = s.as()) { + auto n = make_object(*alloc); CHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 900d6d5..74d5781 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -88,7 +88,7 @@ inline Expr TVMStructGet( handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; - return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic); + return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); } /*! @@ -98,11 +98,11 @@ inline Expr TVMStructGet( * \param offset the offset index. */ inline Expr AddressOffset(Var handle, DataType dtype, int offset) { - return Call::make( + return CallNode::make( DataType::Handle(), intrinsic::tvm_address_of, - {Load::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), + {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), const_true(dtype.lanes()))}, - Call::PureIntrinsic); + CallNode::PureIntrinsic); } /*! @@ -114,13 +114,13 @@ inline Expr AddressOffset(Var handle, DataType dtype, int offset) { inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) { if (dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); - offset = Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return Call::make( + return CallNode::make( DataType::Handle(), intrinsic::tvm_address_of, - {Load::make(dtype, handle, offset, + {LoadNode::make(dtype, handle, offset, const_true(dtype.lanes()))}, - Call::PureIntrinsic); + CallNode::PureIntrinsic); } /*! @@ -139,8 +139,8 @@ inline Stmt TVMStructSet( make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; - return Evaluate::make( - Call::make(DataType::Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic)); + return EvaluateNode::make( + CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } /*! @@ -183,7 +183,7 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { * \return true if pattern match success and store the base to base. */ inline bool GetRamp1Base(Expr index, int lanes, Expr *base) { - const Ramp* r = index.as(); + const RampNode* r = index.as(); if (!r) return false; if (!is_one(r->stride)) return false; CHECK_EQ(r->lanes, lanes); diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index 4f2df7b..9a97031 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -40,23 +40,23 @@ class AttrScopeLifter : public StmtMutator { Stmt Lift(Stmt stmt) { stmt = operator()(std::move(stmt)); if (attr_node_.defined()) { - stmt = AttrStmt::make( + stmt = AttrStmtNode::make( attr_node_, attr_key_, attr_value_, stmt); } return stmt; } // do not go beyond - Stmt VisitStmt_(const Allocate* op) final { + Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); if (attr_node_.defined()) { - Stmt body = AttrStmt::make( + Stmt body = AttrStmtNode::make( attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); attr_value_ = Expr(); - return Allocate::make( + return AllocateNode::make( op->buffer_var, op->dtype, op->extents, op->condition, body, op->new_expr, op->free_function); @@ -65,7 +65,7 @@ class AttrScopeLifter : public StmtMutator { } } - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr_key_) { attr_node_ = op->node; attr_value_ = op->value; @@ -116,7 +116,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt stmt = SeqStmt::Flatten(seq); if (attr_node[begin].defined()) { - stmt = AttrStmt::make( + stmt = AttrStmtNode::make( attr_node[begin], attr_key_, attr_value[begin], stmt); } reorg.push_back(stmt); @@ -127,7 +127,7 @@ class AttrScopeLifter : public StmtMutator { return SeqStmt::Flatten(reorg); } - Stmt VisitStmt_(const IfThenElse* op) final { + Stmt VisitStmt_(const IfThenElseNode* op) final { if (!op->else_case.defined()) { return StmtMutator::VisitStmt_(op); } @@ -147,15 +147,15 @@ class AttrScopeLifter : public StmtMutator { else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElse::make(op->condition, then_case, else_case); + return IfThenElseNode::make(op->condition, then_case, else_case); } } else { if (first_node.defined()) { - then_case = AttrStmt::make( + then_case = AttrStmtNode::make( first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { - else_case = AttrStmt::make( + else_case = AttrStmtNode::make( attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); @@ -165,7 +165,7 @@ class AttrScopeLifter : public StmtMutator { else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElse::make(op->condition, then_case, else_case); + return IfThenElseNode::make(op->condition, then_case, else_case); } } } @@ -177,11 +177,11 @@ class AttrScopeLifter : public StmtMutator { if (!a.defined() || !b.defined()) return false; if (a->type_index() != b->type_index()) return false; if (a.dtype() != b.dtype()) return false; - if (const IntImm* op = a.as()) { - return op->value == b.as()->value; + if (const IntImmNode* op = a.as()) { + return op->value == b.as()->value; } - if (const UIntImm* op = a.as()) { - return op->value == b.as()->value; + if (const UIntImmNode* op = a.as()) { + return op->value == b.as()->value; } return false; } diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index aa8ebe1..7d9ce62 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -49,10 +49,10 @@ struct PartitionKeyHash { // condition cond is proven to have value cond_value (true or false) in interval. using Partition = std::unordered_map; -bool ExprUseVars(Expr expr, const std::unordered_set& vars) { +bool ExprUseVars(Expr expr, const std::unordered_set& vars) { bool success = false; PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { - if (const Variable* v = node.as()) { + if (const VarNode* v = node.as()) { if (vars.count(v)) { success = true; return; @@ -72,10 +72,10 @@ class CandidateSelector final : public StmtExprVisitor { explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {} - void VisitStmt_(const For* op) final { + void VisitStmt_(const ForNode* op) final { // partition const loop when sets split_const_loop_ if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) { - const Variable* var = op->loop_var.get(); + const VarNode* var = op->loop_var.get(); record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { @@ -87,7 +87,7 @@ class CandidateSelector final : public StmtExprVisitor { } } - void VisitStmt_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { const IterVarNode *iv = op->node.as(); CHECK(iv); @@ -118,8 +118,8 @@ class CandidateSelector final : public StmtExprVisitor { } } - void VisitExpr_(const Call* op) final { - if (op->is_intrinsic(Call::likely)) { + void VisitExpr_(const CallNode* op) final { + if (op->is_intrinsic(CallNode::likely)) { in_likely_ = true; StmtExprVisitor::VisitExpr_(op); in_likely_ = false; @@ -132,7 +132,7 @@ class CandidateSelector final : public StmtExprVisitor { } } - void VisitExpr_(const Variable* op) final { + void VisitExpr_(const VarNode* op) final { if (in_likely_ && record_.count(op)) { record_.at(op) = true; } @@ -144,7 +144,7 @@ class CandidateSelector final : public StmtExprVisitor { bool in_likely_{false}; bool no_split_{false}; bool split_const_loop_{false}; - std::unordered_map record_; + std::unordered_map record_; }; // Populate partitions data structure, i.e., for a specific variable, @@ -153,8 +153,8 @@ class CandidateSelector final : public StmtExprVisitor { class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(VarExpr current_var, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { for (const auto& kv : hint_map) { out_vars_.insert(kv.first); @@ -164,10 +164,10 @@ class PartitionFinder : public StmtExprVisitor { } } - void VisitStmt_(const For* op) final { + void VisitStmt_(const ForNode* op) final { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; - const Variable* var = op->loop_var.get(); + const VarNode* var = op->loop_var.get(); hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); StmtExprVisitor::VisitStmt_(op); @@ -175,12 +175,12 @@ class PartitionFinder : public StmtExprVisitor { hint_map_.erase(var); } - void VisitStmt_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmtNode* op) final { // handle thread_axis if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); - const Variable* var = thread_axis->var.get(); + const VarNode* var = thread_axis->var.get(); IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); hint_map_.insert({var, dom}); relax_map_.insert({var, dom}); @@ -192,11 +192,11 @@ class PartitionFinder : public StmtExprVisitor { } } - void VisitExpr_(const Call* op) final { - if (op->is_intrinsic(Call::likely)) { + void VisitExpr_(const CallNode* op) final { + if (op->is_intrinsic(CallNode::likely)) { Expr cond = op->args[0]; if (ExprUseVars(cond, - std::unordered_set({current_var_.get()}))) { + std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. @@ -226,32 +226,32 @@ class PartitionFinder : public StmtExprVisitor { private: Expr InverseCond(const Expr& cond) { Expr inverse_cond; - if (const LT* op = cond.as()) { + if (const LTNode* op = cond.as()) { // a < b -> a >= b - inverse_cond = GE::make(op->a, op->b); - } else if (const GT* op = cond.as()) { + inverse_cond = GENode::make(op->a, op->b); + } else if (const GTNode* op = cond.as()) { // a > b -> a <= b - inverse_cond = LE::make(op->a, op->b); - } else if (const LE* op = cond.as()) { + inverse_cond = LENode::make(op->a, op->b); + } else if (const LENode* op = cond.as()) { // a <= b -> a > b - inverse_cond = GT::make(op->a, op->b); - } else if (const GE* op = cond.as()) { + inverse_cond = GTNode::make(op->a, op->b); + } else if (const GENode* op = cond.as()) { // a >= b -> a < b - inverse_cond = LT::make(op->a, op->b); - } else if (const EQ* op = cond.as()) { + inverse_cond = LTNode::make(op->a, op->b); + } else if (const EQNode* op = cond.as()) { // a == b -> a != b - inverse_cond = NE::make(op->a, op->b); + inverse_cond = NENode::make(op->a, op->b); // a != b -> a == b - } else if (const NE* op = cond.as()) { - inverse_cond = EQ::make(op->a, op->b); + } else if (const NENode* op = cond.as()) { + inverse_cond = EQNode::make(op->a, op->b); } return inverse_cond; } VarExpr current_var_; - std::unordered_set out_vars_; - std::unordered_map hint_map_; - std::unordered_map relax_map_; + std::unordered_set out_vars_; + std::unordered_map hint_map_; + std::unordered_map relax_map_; }; // Replace the set of conditions given by ps with cond_value (true or false) @@ -279,16 +279,16 @@ class ThreadPartitionInserter : public StmtMutator { explicit ThreadPartitionInserter(const std::unordered_set& ps, Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { innermost_thread_scope_ = true; Stmt stmt = StmtMutator::VisitStmt_(op); // add branch code inside the innermost thread scope if (innermost_thread_scope_) { Stmt simplified_body = ConditionEliminator(ps_)(op->body); - Stmt body = IfThenElse::make(cond_, simplified_body, op->body); + Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body); Expr value = this->VisitExpr(op->value); - stmt = AttrStmt::make(op->node, op->attr_key, value, body); + stmt = AttrStmtNode::make(op->node, op->attr_key, value, body); } innermost_thread_scope_ = false; return stmt; @@ -315,7 +315,7 @@ class LoopPartitioner : public StmtMutator { return operator()(std::move(stmt)); } - Stmt VisitStmt_(const For* op) final { + Stmt VisitStmt_(const ForNode* op) final { if (selector.candidates.count(op)) { Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, op->body, false); @@ -331,7 +331,7 @@ class LoopPartitioner : public StmtMutator { return res; } - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key != attr::thread_extent) { return StmtMutator::VisitStmt_(op); } @@ -374,8 +374,8 @@ class LoopPartitioner : public StmtMutator { inline Stmt MakeFor(const Object* op, Expr extent, Stmt body); /* Candidate IRs that may be partitioned potentially */ - std::unordered_map hint_map_; - std::unordered_map relax_map_; + std::unordered_map hint_map_; + std::unordered_map relax_map_; arith::Analyzer analyzer_; CandidateSelector selector; }; @@ -506,7 +506,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); + body_begin = MaxNode::make(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; } @@ -532,7 +532,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max+1); + post_doubt_begin = MinNode::make(post_doubt_begin, max+1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } @@ -581,21 +581,21 @@ Stmt LoopPartitioner::TryPartition(const Object* node, } inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) { - const For *for_node = static_cast(node); + const ForNode *for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - return For::make(for_node->loop_var, 0, extent, + return ForNode::make(for_node->loop_var, 0, extent, for_node->for_type, for_node->device_api, body); } } class RemoveLikelyTags : public StmtExprMutator { public: - Expr VisitExpr_(const Call *op) final { - if (op->is_intrinsic(Call::likely)) { + Expr VisitExpr_(const CallNode *op) final { + if (op->is_intrinsic(CallNode::likely)) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); } else { diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index 2440b1f..ded17d4 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -41,14 +41,14 @@ class CustomDatatypesLowerer : public StmtExprMutator { public: explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} - inline Expr VisitExpr_(const Cast* op) final { + inline Expr VisitExpr_(const CastNode* op) final { auto type_code = op->dtype.code(); auto src_type_code = op->value.dtype().code(); // If either datatype is a registered custom datatype, we must lower. bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || datatype::Registry::Global()->GetTypeRegistered(src_type_code); Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); if (toBeLowered) { auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); CHECK(lower) << "Cast lowering function for target " << target_ << " destination type " @@ -59,7 +59,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; } - inline Expr VisitExpr_(const FloatImm* imm) final { + inline Expr VisitExpr_(const FloatImmNode* imm) final { auto type_code = imm->dtype.code(); auto e = GetRef(imm); if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { @@ -71,37 +71,37 @@ class CustomDatatypesLowerer : public StmtExprMutator { return e; } - inline Stmt VisitStmt_(const Allocate* allocate) final { + inline Stmt VisitStmt_(const AllocateNode* allocate) final { bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code()); Stmt stmt = StmtExprMutator::VisitStmt_(allocate); - allocate = stmt.as(); + allocate = stmt.as(); if (toBeLowered) { auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); - return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents, + return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents, allocate->condition, allocate->body, allocate->new_expr, allocate->free_function); } return stmt; } - inline Expr VisitExpr_(const Load* load) final { + inline Expr VisitExpr_(const LoadNode* load) final { bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code()); Expr expr = StmtExprMutator::VisitExpr_(load); - load = expr.as(); + load = expr.as(); if (toBeLowered) { auto new_load_type = DataType::UInt(load->dtype.bits()); - return Load::make(new_load_type, load->buffer_var, load->index, load->predicate); + return LoadNode::make(new_load_type, load->buffer_var, load->index, load->predicate); } return expr; } -#define DEFINE_MUTATE__(OP) \ - inline Expr VisitExpr_(const OP* op) final { \ +#define DEFINE_MUTATE__(OP, NodeName) \ + inline Expr VisitExpr_(const NodeName* op) final { \ auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ Expr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ + op = expr.as(); \ if (toBeLowered) { \ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ @@ -111,19 +111,19 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; \ } - DEFINE_MUTATE__(Add) - DEFINE_MUTATE__(Sub) - DEFINE_MUTATE__(Mul) - DEFINE_MUTATE__(Div) - DEFINE_MUTATE__(Mod) - DEFINE_MUTATE__(Min) - DEFINE_MUTATE__(Max) - DEFINE_MUTATE__(EQ) - DEFINE_MUTATE__(NE) - DEFINE_MUTATE__(LT) - DEFINE_MUTATE__(LE) - DEFINE_MUTATE__(GT) - DEFINE_MUTATE__(GE) + DEFINE_MUTATE__(Add, AddNode); + DEFINE_MUTATE__(Sub, SubNode); + DEFINE_MUTATE__(Mul, MulNode); + DEFINE_MUTATE__(Div, DivNode); + DEFINE_MUTATE__(Mod, ModNode); + DEFINE_MUTATE__(Min, MinNode); + DEFINE_MUTATE__(Max, MaxNode); + DEFINE_MUTATE__(EQ, EQNode); + DEFINE_MUTATE__(NE, NENode); + DEFINE_MUTATE__(LT, LTNode); + DEFINE_MUTATE__(LE, LENode); + DEFINE_MUTATE__(GT, GTNode); + DEFINE_MUTATE__(GE, GENode); // Later changes may need to add more mutate functions as we support workloads with more ops. private: diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 0f49710..b46bf18 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -53,19 +53,19 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } - Expr VisitExpr_(const Call* op) final { - if (op->call_type == Call::Intrinsic || - op->call_type == Call::PureIntrinsic) { + Expr VisitExpr_(const CallNode* op) final { + if (op->call_type == CallNode::Intrinsic || + op->call_type == CallNode::PureIntrinsic) { Expr r = ApplyPattern(op->name, GetRef(op)); if (r.defined()) return r; } return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const Add* op) final { - if (const Mul* mb = op->b.as()) { + Expr VisitExpr_(const AddNode* op) final { + if (const MulNode* mb = op->b.as()) { return MakeFMA(mb->a, mb->b, op->a, op); - } else if (const Mul* ma = op->a.as()) { + } else if (const MulNode* ma = op->a.as()) { return MakeFMA(ma->a, ma->b, op->b, op); } return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -73,10 +73,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // We use floordiv for integer analysis, // but will need to lower them to native truncdiv instructions - Expr VisitExpr_(const FloorDiv* op) final { + Expr VisitExpr_(const FloorDivNode* op) final { auto e = GetRef(op); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + op = ret.as(); if (op == nullptr) return ret; int shift; const DataType& dtype = op->dtype; @@ -104,7 +104,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return ir::Select::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + return ir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -114,15 +114,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) Expr rdiv = truncdiv(op->a, op->b); Expr rmod = truncmod(op->a, op->b); - return ir::Select::make( + return ir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, rdiv - make_const(dtype, 1)); } } - Expr VisitExpr_(const FloorMod* op) final { + Expr VisitExpr_(const FloorModNode* op) final { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + op = ret.as(); if (op == nullptr) return ret; // Lower floordiv to native truncdiv. int shift; @@ -153,7 +153,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // -> rmod >= 0 ? 0 : b return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - return ir::Select::make(rmod >= 0, rmod, rmod + op->b); + return ir::SelectNode::make(rmod >= 0, rmod, rmod + op->b); } } } else { @@ -164,13 +164,13 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return ir::Select::make( + return ir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b); } } - Expr VisitExpr_(const Max* op) final { + Expr VisitExpr_(const MaxNode* op) final { using namespace arith; PVar x, y; PVar c; @@ -183,7 +183,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const EQ* op) final { + Expr VisitExpr_(const EQNode* op) final { using namespace arith; PVar x, y; auto e = GetRef(op); @@ -193,7 +193,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const NE* op) final { + Expr VisitExpr_(const NENode* op) final { using namespace arith; PVar x, y; auto e = GetRef(op); @@ -209,8 +209,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // For some targets, LLVM will generate more efficient FMA // instruction with the latter. For example, vmla vs. vmlal // on ARM. - if (const Broadcast* bcast = e.as()) { - if (const Cast* cast = bcast->value.as()) { + if (const BroadcastNode* bcast = e.as()) { + if (const CastNode* cast = bcast->value.as()) { auto should_swap = [&]() { // Maintain behaviour (int8 -> int16, fp16 -> fp32). if (cast->dtype.bits() == cast->value.dtype().bits() * 2) { @@ -228,8 +228,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { }; if (should_swap()) { - Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); - return Cast::make(bcast->dtype, new_bcast); + Expr new_bcast = BroadcastNode::make(cast->value, bcast->lanes); + return CastNode::make(bcast->dtype, new_bcast); } } } @@ -237,19 +237,19 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c, - const Add* op) { + const AddNode* op) { // emit fma instruction: a * b + c Expr lhs = SwapBroadcastCast(a); Expr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - Expr r = (*fma_)(Call::make( - op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic)); + Expr r = (*fma_)(CallNode::make( + op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { - Expr mul = this->VisitExpr(Mul::make(lhs, rhs)); - return Add::make(mul, this->VisitExpr(c)); + Expr mul = this->VisitExpr(MulNode::make(lhs, rhs)); + return AddNode::make(mul, this->VisitExpr(c)); } } return IRMutatorWithAnalyzer::VisitExpr_(op); diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 4712bcc..d38d1da 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -37,7 +37,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { explicit ThreadAllreduceBuilder(int warp_size) : warp_size_(warp_size) {} - Stmt VisitStmt_(const AttrStmt *op) final { + Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == attr::thread_extent) { thread_extents_.push_back(op); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -45,8 +45,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return ret; } else if (op->attr_key == attr::storage_scope) { Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - const Variable* v = op->node.as(); + op = ret.as(); + const VarNode* v = op->node.as(); if (alloc_remap_.count(v)) { return op->body; } else { @@ -63,37 +63,37 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } } - Stmt VisitStmt_(const Evaluate* op) final { + Stmt VisitStmt_(const EvaluateNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - const Call* call = op->value.as(); + op = stmt.as(); + const CallNode* call = op->value.as(); if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) { return MakeAllreduce(call); } else { return stmt; } } - Stmt VisitStmt_(const Allocate* op) final { + Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - const Allocate* repl = it->second.as(); + const AllocateNode* repl = it->second.as(); // use volatile access to shared buffer. - stmt = AttrStmt::make( + stmt = AttrStmtNode::make( repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = Allocate::make( + stmt = AllocateNode::make( repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmt::make( + stmt = AttrStmtNode::make( repl->buffer_var, attr::storage_scope, - StringImm::make("shared"), stmt); + StringImmNode::make("shared"), stmt); return stmt; } else { return stmt; } } - Expr VisitExpr_(const Load* op) final { + Expr VisitExpr_(const LoadNode* op) final { auto it = load_remap_.find(op->buffer_var.get()); if (it != load_remap_.end()) { CHECK(is_zero(op->index)); @@ -115,12 +115,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } }; // make allreduce. - Stmt MakeAllreduce(const Call* call) { + Stmt MakeAllreduce(const CallNode* call) { CHECK(!reduce_combiner_.empty()); const CommReducerNode *combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const UIntImm *size_of_args = call->args[0].as(); + const UIntImmNode *size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; @@ -130,26 +130,26 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { values[idx] = call->args[1+idx]; if (!is_one(cond)) { - values[idx] = Select::make(cond, values[idx], inits[idx]); + values[idx] = SelectNode::make(cond, values[idx], inits[idx]); } types[idx] = values[idx].dtype(); } - std::vector buffers(size); + std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { - const Variable* buffer = call->args[2+size+idx].as(); + const VarNode* buffer = call->args[2+size+idx].as(); CHECK(buffer); buffers[idx] = buffer; } - std::unordered_set reduce_set; + std::unordered_set reduce_set; for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) { - const Variable* v = call->args[i].as(); + const VarNode* v = call->args[i].as(); CHECK(v); reduce_set.insert(v); } size_t nmatch = 0; std::vector vred, vpar; - for (const AttrStmt* attr : thread_extents_) { + for (const AttrStmtNode* attr : thread_extents_) { ThreadEntry e; IterVar iv = Downcast(attr->node); e.scope = runtime::ThreadScope::make(iv->thread_tag); @@ -183,7 +183,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Expr pred = const_true(types[i].lanes()); Var buffer_var = Downcast(call->args[2+size+i]); - stores[i] = Store::make(buffer_var, values[i], 0, pred); + stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); } return SeqStmt::Flatten(stores); } @@ -199,7 +199,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); Expr pred = const_true(types[idx].lanes()); - seq.emplace_back(Store::make( + seq.emplace_back(StoreNode::make( shared_bufs[idx], values[idx], BufIndex(reduce_index, group_index, reduce_extent), pred)); } @@ -210,13 +210,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { CHECK(!load_remap_.count(buffers[idx])); Expr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = Load::make( + load_remap_[buffers[idx]] = LoadNode::make( types[idx], shared_bufs[idx], BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = Allocate::make( + alloc_remap_[buffers[idx]] = AllocateNode::make( shared_bufs[idx], types[idx], {Expr(group_extent), Expr(reduce_extent)}, - pred, Evaluate::make(0)); + pred, EvaluateNode::make(0)); } return SeqStmt::Flatten(seq); } @@ -242,15 +242,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto freduce = [&](int offset) { Array a, b; for (size_t i = 0; i < size; ++i) { - b.push_back(Load::make(types[i], shared_bufs[i], + b.push_back(LoadNode::make(types[i], shared_bufs[i], BufIndex(reduce_index + offset, group_index, reduce_extent), const_true())); - a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true())); + a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true())); } Array ret = (*combiner)(a, b); std::vector stores(size); for (size_t i = 0; i < size; ++i) { - stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true()); + stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true()); } return SeqStmt::Flatten(stores); }; @@ -259,7 +259,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // reduction with the boundary condition reduce_align = reduce_align >> 1; Expr cond = reduce_index < (reduce_extent - reduce_align); - seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align))); + seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } CHECK(threadx_extent >= 1 && warp_size_ >= 1); @@ -268,7 +268,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { reduce_align > warp_size_) { reduce_align = reduce_align >> 1; Expr cond = reduce_index < reduce_align; - seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align))); + seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } // in warp synchronization. @@ -281,7 +281,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } if (in_warp_seq.size() != 0) { Stmt warp_body = SeqStmt::Flatten(in_warp_seq); - seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body)); + seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body)); seq.emplace_back(SyncThread("shared")); } return SeqStmt::Flatten(seq); @@ -310,10 +310,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return Evaluate::make( - Call::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImm::make(sync)}, - Call::Intrinsic)); + return EvaluateNode::make( + CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImmNode::make(sync)}, + CallNode::Intrinsic)); } // The local buffer index. static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) { @@ -327,12 +327,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { int warp_size_{1}; // surrounding scope of thread extent. - std::vector thread_extents_; + std::vector thread_extents_; std::vector reduce_combiner_; // The load remap - std::unordered_map load_remap_; + std::unordered_map load_remap_; // Allocate remap - std::unordered_map alloc_remap_; + std::unordered_map alloc_remap_; }; LoweredFunc diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index c0b9879..a9b401f 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -37,8 +37,11 @@ inline Expr ConstInt32(size_t index) { } inline Expr StackAlloca(std::string type, size_t num) { - Array args = {StringImm::make(type), ConstInt32(num)}; - return Call::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic); + Array args = {StringImmNode::make(type), ConstInt32(num)}; + return CallNode::make( + DataType::Handle(), + intrinsic::tvm_stack_alloca, + args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -52,17 +55,17 @@ class BuiltinLower : public StmtExprMutator { stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->VisitStmt(stmt); if (max_shape_stack_ != 0) { - stmt = LetStmt::make( + stmt = LetStmtNode::make( stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); } if (max_array_stack_ != 0) { - stmt = LetStmt::make( + stmt = LetStmtNode::make( stack_array_, StackAlloca("array", max_array_stack_), stmt); } if (max_arg_stack_ != 0) { - stmt = LetStmt::make( + stmt = LetStmtNode::make( stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); - stmt = LetStmt::make( + stmt = LetStmtNode::make( stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); } return stmt; @@ -82,10 +85,10 @@ class BuiltinLower : public StmtExprMutator { } } - Stmt VisitStmt_(const Allocate* op) { + Stmt VisitStmt_(const AllocateNode* op) { // Lower allocate to device allocate when needed. Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); if (op->new_expr.defined()) return stmt; // Get constant allocation bound. int64_t dev_type; @@ -106,45 +109,48 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = Evaluate::make(Call::make(DataType::Int(32), - intrinsic::tvm_throw_last_error, {}, - Call::Intrinsic)); + Stmt throw_last_error = EvaluateNode::make( + CallNode::make(DataType::Int(32), + intrinsic::tvm_throw_last_error, {}, + CallNode::Intrinsic)); Stmt body = SeqStmt({ - IfThenElse::make(Call::make(DataType::Bool(1), - intrinsic::tvm_handle_is_null, - {op->buffer_var}, Call::PureIntrinsic), - throw_last_error), + IfThenElseNode::make( + CallNode::make(DataType::Bool(1), + intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), + throw_last_error), op->body}); - Stmt alloca = LetStmt::make( + Stmt alloca = LetStmtNode::make( op->buffer_var, - Call::make(op->buffer_var.dtype(), - "TVMBackendAllocWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - cast(DataType::UInt(64), total_bytes), - IntImm::make(DataType::Int(32), op->dtype.code()), - IntImm::make(DataType::Int(32), op->dtype.bits())}, - Call::Extern), + CallNode::make(op->buffer_var.dtype(), + "TVMBackendAllocWorkspace", + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), + cast(DataType::UInt(64), total_bytes), + IntImmNode::make(DataType::Int(32), op->dtype.code()), + IntImmNode::make(DataType::Int(32), op->dtype.bits())}, + CallNode::Extern), body); - Expr free_op = Call::make(DataType::Int(32), - "TVMBackendFreeWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - op->buffer_var}, - Call::Extern); - Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error); + Expr free_op = CallNode::make(DataType::Int(32), + "TVMBackendFreeWorkspace", + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), + op->buffer_var}, + CallNode::Extern); + Stmt free_stmt = IfThenElseNode::make( + free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); - body = AttrStmt::make( + body = AttrStmtNode::make( op->buffer_var, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; } - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::device_context_id) { CHECK(!device_id_.defined()); device_id_ = op->value; @@ -157,7 +163,7 @@ class BuiltinLower : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } } - Expr VisitExpr_(const Call* op) final { + Expr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_call_packed)) { return MakeCallPacked(op); } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) { @@ -173,24 +179,24 @@ class BuiltinLower : public StmtExprMutator { } } // call shape - Expr MakeShape(const Call* op) { + Expr MakeShape(const CallNode* op) { size_t stack_begin = run_shape_stack_; run_shape_stack_ += op->args.size(); Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { prep_seq_.emplace_back( - Store::make(stack_shape_, cast(DataType::Int(64), op->args[i]), + StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), ConstInt32(stack_begin +i), const_true(1))); } return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } // make array - Expr MakeArray(const Call* op) { + Expr MakeArray(const CallNode* op) { size_t idx = run_array_stack_; run_array_stack_ += 1; Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); prep_seq_.emplace_back( @@ -233,32 +239,32 @@ class BuiltinLower : public StmtExprMutator { return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. - Expr MakeCallPacked(const Call* op) { + Expr MakeCallPacked(const CallNode* op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; run_arg_stack_ += op->args.size(); // Specially handle the buffer packed intrinsic Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); for (size_t i = 1; i < op->args.size(); ++i) { Expr stack_index = ConstInt32(arg_stack_begin + i - 1); Expr arg = op->args[i]; DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = Cast::make(api_type, arg); + arg = CastNode::make(api_type, arg); } prep_seq_.emplace_back(TVMStructSet( stack_value_, static_cast(arg_stack_begin + i - 1), intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); - if (api_type.is_handle() && arg.as()) { + if (api_type.is_handle() && arg.as()) { arg_tcode = kStr; } if (IsArrayHandle(arg)) arg_tcode = kArrayHandle; prep_seq_.emplace_back( - Store::make(stack_tcode_, + StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } @@ -276,12 +282,12 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1) }; - return Call::make( + return CallNode::make( DataType::Int(32), intrinsic::tvm_call_packed_lowered, - packed_args, Call::Intrinsic); + packed_args, CallNode::Intrinsic); } - Expr MakeCallTracePacked(const Call *op) { + Expr MakeCallTracePacked(const CallNode *op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; @@ -289,14 +295,14 @@ class BuiltinLower : public StmtExprMutator { size_t args_size = op->args.size(); CHECK_GT(args_size, 0); Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); for (size_t i = 1; i < op->args.size(); ++i) { Expr stack_index = ConstInt32(arg_stack_begin + i - 1); Expr arg = op->args[i]; DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = Cast::make(api_type, arg); + arg = CastNode::make(api_type, arg); } prep_seq_.emplace_back(TVMStructSet( stack_value_, static_cast(arg_stack_begin + i - 1), @@ -304,7 +310,7 @@ class BuiltinLower : public StmtExprMutator { int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( - Store::make(stack_tcode_, + StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } @@ -326,17 +332,17 @@ class BuiltinLower : public StmtExprMutator { // Pass traced value. op->args[args_size - 1] }; - return Call::make( + return CallNode::make( op->dtype, intrinsic::tvm_call_trace_packed_lowered, - packed_args, Call::Intrinsic); + packed_args, CallNode::Intrinsic); } private: bool IsArrayHandle(const Expr& arg) { // specially set array handle. - if (const Call* buf = arg.as()) { + if (const CallNode* buf = arg.as()) { if (buf->is_intrinsic(intrinsic::tvm_struct_get) && - buf->args[2].as()->value == intrinsic::kArrAddr) { + buf->args[2].as()->value == intrinsic::kArrAddr) { return true; } } diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 2d24ec4..75f128e 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -76,7 +76,7 @@ namespace ir { // store warp_mem[m * warp_index + (warp_size * m) * y + x] class WarpStoreCoeffFinder : private StmtVisitor { public: - WarpStoreCoeffFinder(const Variable* buffer, + WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) : buffer_(buffer), @@ -91,7 +91,7 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const Store *op) final { + void VisitStmt_(const StoreNode *op) final { if (op->buffer_var.get() == buffer_) { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); @@ -129,7 +129,7 @@ class WarpStoreCoeffFinder : private StmtVisitor { } // The buffer variable - const Variable* buffer_; + const VarNode* buffer_; // the warp index Var warp_index_; // the coefficient @@ -155,7 +155,7 @@ class WarpIndexFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const AttrStmt *op) final { + void VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { @@ -190,7 +190,7 @@ class WarpAccessRewriter : protected StmtExprMutator { : warp_size_(warp_size), analyzer_(analyzer) {} // Rewrite the allocate statement which transforms // warp memory to local memory. - Stmt Rewrite(const Allocate* op) { + Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); CHECK_GT(alloc_size, 0) @@ -202,7 +202,7 @@ class WarpAccessRewriter : protected StmtExprMutator { CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0) << "Warp memory must be multiple of warp size"; warp_group_ = alloc_size / (warp_size_ * warp_coeff_); - return Allocate::make( + return AllocateNode::make( op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / warp_size_)}, @@ -211,23 +211,23 @@ class WarpAccessRewriter : protected StmtExprMutator { } protected: - Expr Mutate_(const Variable* op) { + Expr Mutate_(const VarNode* op) { CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } - Stmt VisitStmt_(const Store* op) { + Stmt VisitStmt_(const StoreNode* op) { if (op->buffer_var.get() == buffer_) { Expr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); - return Store::make(op->buffer_var, op->value, local_index, op->predicate); + return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate); } else { return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Load* op) { + Expr Mutate_(const LoadNode* op) { if (op->buffer_var.get() == buffer_) { Expr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); @@ -235,12 +235,12 @@ class WarpAccessRewriter : protected StmtExprMutator { CHECK(!ExprUseVar(local_index, {warp_index_.get()})) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; - Expr load_value = Load::make( + Expr load_value = LoadNode::make( op->dtype, op->buffer_var, local_index, op->predicate); - return Call::make(load_value.dtype(), + return CallNode::make(load_value.dtype(), intrinsic::tvm_warp_shuffle, {load_value, group}, - Call::Intrinsic); + CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); } @@ -256,7 +256,7 @@ class WarpAccessRewriter : protected StmtExprMutator { CHECK(GetRamp1Base(index, index.dtype().lanes(), &base)); std::tie(local_index, group) = SplitIndexByGroup(base); local_index = - Ramp::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); + RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } Expr m = make_const(index.dtype(), warp_coeff_); @@ -281,7 +281,7 @@ class WarpAccessRewriter : protected StmtExprMutator { // the warp size int warp_size_{0}; // The buffer variable - const Variable* buffer_; + const VarNode* buffer_; // Warp index Var warp_index_; // the coefficient m @@ -301,13 +301,13 @@ class BindVarBoundInfo : public StmtVisitor { explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} - void VisitStmt_(const For* op) final { + void VisitStmt_(const ForNode* op) final { const Var& loop_var = op->loop_var; analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); StmtVisitor::VisitStmt_(op); } - void VisitStmt_(const AttrStmt* op) { + void VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); @@ -325,7 +325,7 @@ class BindVarBoundInfo : public StmtVisitor { // internal analyzer. arith::Analyzer* analyzer_; // variable domain - std::unordered_map var_dom_; + std::unordered_map var_dom_; }; // Mutator to change the read pattern @@ -345,7 +345,7 @@ class WarpMemoryRewriter : private StmtMutator { } private: - Stmt VisitStmt_(const Allocate* op) { + Stmt VisitStmt_(const AllocateNode* op) { if (warp_buffer_.count(op->buffer_var.get())) { WarpAccessRewriter rewriter(warp_size_, &analyzer_); return rewriter.Rewrite(op); @@ -354,27 +354,27 @@ class WarpMemoryRewriter : private StmtMutator { } } - Stmt VisitStmt_(const AttrStmt* op) { + Stmt VisitStmt_(const AttrStmtNode* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { - const Variable* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + const VarNode* buf = op->node.as(); + StorageScope scope = StorageScope::make(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - return AttrStmt::make( - op->node, op->attr_key, StringImm::make("local"), op->body); + op = ret.as(); + return AttrStmtNode::make( + op->node, op->attr_key, StringImmNode::make("local"), op->body); } } return StmtMutator::VisitStmt_(op); } int warp_size_{0}; - std::unordered_set warp_buffer_; + std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain - std::unordered_map var_dom_; + std::unordered_map var_dom_; }; LoweredFunc diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 03a6035..56609bb 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -36,7 +36,7 @@ namespace tvm { namespace ir { inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) { - return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0)); + return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0)); } LoweredFunc MakeAPI(Stmt body, @@ -44,7 +44,7 @@ LoweredFunc MakeAPI(Stmt body, Array api_args, int num_unpacked_args, bool is_restricted) { - const Stmt nop = Evaluate::make(0); + const Stmt nop = EvaluateNode::make(0); int num_args = static_cast(api_args.size()); CHECK_LE(num_unpacked_args, num_args); int num_packed_args = num_args - num_unpacked_args; @@ -62,23 +62,23 @@ LoweredFunc MakeAPI(Stmt body, // seq_init gives sequence of initialization // seq_check gives sequence of later checks after init std::vector seq_init, seq_check; - std::unordered_map vmap; + std::unordered_map vmap; ArgBinder binder(&vmap); // --------------------------- // local function definitions // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, - IntImm::make(DataType::Int(32), i), - IntImm::make(DataType::Int(32), intrinsic::kTVMValueContent)}; + IntImmNode::make(DataType::Int(32), i), + IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - Expr res = Call::make( + Expr res = CallNode::make( api_type, intrinsic::tvm_struct_get, call_args, - Call::PureIntrinsic); + CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { - res = Cast::make(t, res); + res = CastNode::make(t, res); } return res; }; @@ -86,7 +86,7 @@ LoweredFunc MakeAPI(Stmt body, auto f_arg_decl = [&](int i) { std::ostringstream os; os << "arg" << i; - const Variable* v = api_args[i].as(); + const VarNode* v = api_args[i].as(); return Var(os.str(), v ? v->dtype: DataType::Handle()); }; // --------------------------- @@ -110,40 +110,40 @@ LoweredFunc MakeAPI(Stmt body, Var v_arg = f_arg_decl(i); if (i < num_packed_args) { // Value loads - seq_init.emplace_back(LetStmt::make( + seq_init.emplace_back(LetStmtNode::make( v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmt::make( - tcode, Load::make( + seq_init.emplace_back(LetStmtNode::make( + tcode, LoadNode::make( DataType::Int(32), v_packed_arg_type_ids, - IntImm::make(DataType::Int(32), i), const_true(1)), + IntImmNode::make(DataType::Int(32), i), const_true(1)), nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name << ": Expect arg[" << i << "] to be pointer"; seq_check.emplace_back( - AssertStmt::make(tcode == kHandle || + AssertStmtNode::make(tcode == kHandle || tcode == kNDArrayContainer || tcode == kArrayHandle || tcode == kNull, msg.str(), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name << ": Expect arg[" << i << "] to be int"; - seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop)); + seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop)); } else { CHECK(t.is_float()); std::ostringstream msg; msg << name << ": Expect arg[" << i << "] to be float"; seq_check.emplace_back( - AssertStmt::make(tcode == kDLFloat, msg.str(), nop)); + AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop)); } } else { args.push_back(v_arg); } // add checks for functions. - if (api_args[i].as()) { + if (api_args[i].as()) { var_defs.emplace_back(std::make_pair(Downcast(api_args[i]), v_arg)); } else { // Buffer checks @@ -184,22 +184,22 @@ LoweredFunc MakeAPI(Stmt body, n->handle_data_type = binder.def_handle_dtype(); n->is_packed_func = num_unpacked_args == 0; n->is_restricted = is_restricted; - body = AttrStmt::make( + body = AttrStmtNode::make( make_zero(DataType::Int(32)), attr::compute_scope, - StringImm::make(name + "_compute_"), body); + StringImmNode::make(name + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - Expr node = StringImm::make("default"); + Expr node = StringImmNode::make("default"); CHECK(vmap.count(device_type.get())); - seq_check.push_back(AttrStmt::make( + seq_check.push_back(AttrStmtNode::make( node, attr::device_context_id, device_id, nop)); - seq_check.push_back(AttrStmt::make( + seq_check.push_back(AttrStmtNode::make( node, attr::device_context_type, device_type, nop)); - Stmt set_device = IfThenElse::make( - device_type != kDLCPU, Evaluate::make(Call::make( + Stmt set_device = IfThenElseNode::make( + device_type != kDLCPU, EvaluateNode::make(CallNode::make( DataType::Int(32), intrinsic::tvm_call_packed, - {StringImm::make(runtime::symbol::tvm_set_device), - device_type, device_id}, Call::Intrinsic))); + {StringImmNode::make(runtime::symbol::tvm_set_device), + device_type, device_id}, CallNode::Intrinsic))); body = SeqStmt({set_device, body}); } n->body = MergeNest( @@ -222,28 +222,28 @@ class DeviceTypeBinder: public StmtExprMutator { explicit DeviceTypeBinder(int device_type) : device_type_(device_type) {} - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::device_context_type) { - if (const Variable* var = op->value.as()) { + if (const VarNode* var = op->value.as()) { var_ = var; Expr value = make_const(op->value.dtype(), device_type_); Stmt body = StmtExprMutator::VisitStmt_(op); var_ = nullptr; std::ostringstream os; os << "device_type need to be " << device_type_; - return AssertStmt::make(op->value == value, os.str(), body); + return AssertStmtNode::make(op->value == value, os.str(), body); } } return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const IfThenElse* op) final { + Stmt VisitStmt_(const IfThenElseNode* op) final { // eager simplify if guard. Stmt res = StmtExprMutator::VisitStmt_(op); - op = res.as(); + op = res.as(); if (is_zero(op->condition)) { if (op->else_case.defined()) return op->else_case; - return Evaluate::make(0); + return EvaluateNode::make(0); } if (is_one(op->condition)) { return op->then_case; @@ -251,17 +251,17 @@ class DeviceTypeBinder: public StmtExprMutator { return res; } - Expr VisitExpr_(const NE* op) final { + Expr VisitExpr_(const NENode* op) final { // eager check NE for device check Expr res = StmtExprMutator::VisitExpr_(op); - op = res.as(); + op = res.as(); if (ir::Equal(op->a, op->b)) { return make_const(op->dtype, false); } return res; } - Expr VisitExpr_(const Variable* op) final { + Expr VisitExpr_(const VarNode* op) final { if (op == var_) { return make_const(op->dtype, device_type_); } else { @@ -270,7 +270,7 @@ class DeviceTypeBinder: public StmtExprMutator { } public: - const Variable* var_{nullptr}; + const VarNode* var_{nullptr}; int device_type_; }; diff --git a/src/pass/remap_thread_axis.cc b/src/pass/remap_thread_axis.cc index 92b941a..2a486b5 100644 --- a/src/pass/remap_thread_axis.cc +++ b/src/pass/remap_thread_axis.cc @@ -42,28 +42,28 @@ class ThreadAxisRewriter : private StmtExprMutator { } private: - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); auto it = tmap_.find(iv->thread_tag); if (it != tmap_.end()) { const IterVar& new_iv = it->second; - const Variable* v = iv->var.get(); + const VarNode* v = iv->var.get(); if (!vmap_.count(v)) { vmap_[v] = new_iv->var; } else { CHECK(vmap_[v].same_as(new_iv->var)); } Stmt body = this->VisitStmt(op->body); - return AttrStmt::make( + return AttrStmtNode::make( new_iv, op->attr_key, op->value, body); } } return StmtExprMutator::VisitStmt_(op); } - Expr VisitExpr_(const Variable* op) final { + Expr VisitExpr_(const VarNode* op) final { auto it = vmap_.find(op); if (it != vmap_.end()) return it->second; return StmtExprMutator::VisitExpr_(op); @@ -71,14 +71,14 @@ class ThreadAxisRewriter : private StmtExprMutator { // The thread map const std::unordered_map& tmap_; // variable map - std::unordered_map vmap_; + std::unordered_map vmap_; }; LoweredFunc RemapThreadAxis(LoweredFunc f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { - const StringImm* str = kv.first.as(); + const StringImmNode* str = kv.first.as(); CHECK(str != nullptr); tmap[str->value] = kv.second; } diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc index 6891870..3c9114d 100644 --- a/src/pass/remove_no_op.cc +++ b/src/pass/remove_no_op.cc @@ -32,28 +32,28 @@ namespace ir { // Mark the statment of each stage. class NoOpRemover : public StmtMutator { public: - Stmt VisitStmt_(const LetStmt* op) final { + Stmt VisitStmt_(const LetStmtNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; } - Stmt VisitStmt_(const AttrStmt* op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_debug_skip_region") { return MakeEvaluate(0); } Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; } - Stmt VisitStmt_(const IfThenElse* op) final { + Stmt VisitStmt_(const IfThenElseNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); if (op->else_case.defined()) { if (is_no_op(op->else_case)) { if (is_no_op(op->then_case)) { return MakeEvaluate(op->condition); } else { - return IfThenElse::make(op->condition, op->then_case); + return IfThenElseNode::make(op->condition, op->then_case); } } else { return stmt; @@ -66,32 +66,32 @@ class NoOpRemover : public StmtMutator { } } } - Stmt VisitStmt_(const For* op) final { + Stmt VisitStmt_(const ForNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); if (is_zero(op->extent)) { - return Evaluate::make(0); + return EvaluateNode::make(0); } return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt; } - Stmt VisitStmt_(const Allocate* op) final { + Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; } - Stmt VisitStmt_(const ProducerConsumer* op) final { + Stmt VisitStmt_(const ProducerConsumerNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? op->body : stmt; } - Stmt VisitStmt_(const Realize* op) final { + Stmt VisitStmt_(const RealizeNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? op->body : stmt; } - Stmt VisitStmt_(const Evaluate* op) final { + Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) return GetRef(op); - return Evaluate::make(0); + return EvaluateNode::make(0); } Stmt VisitStmt_(const SeqStmtNode* op) final { @@ -128,9 +128,9 @@ class NoOpRemover : public StmtMutator { private: Stmt MakeEvaluate(Expr value) { if (HasSideEffect(value)) { - return Evaluate::make(value); + return EvaluateNode::make(value); } else { - return Evaluate::make(0); + return EvaluateNode::make(0); } } Stmt MakeEvaluate(const Array& values) { @@ -138,13 +138,13 @@ class NoOpRemover : public StmtMutator { for (Expr e : values) { if (HasSideEffect(e)) { if (stmt.defined()) { - stmt = SeqStmt({stmt, Evaluate::make(e)}); + stmt = SeqStmt({stmt, EvaluateNode::make(e)}); } else { - stmt = Evaluate::make(e); + stmt = EvaluateNode::make(e); } } } - return stmt.defined() ? stmt : Evaluate::make(0); + return stmt.defined() ? stmt : EvaluateNode::make(0); } }; diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 0a27671..c38fac1 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -35,14 +35,14 @@ class UnsafeExprDetector : public ExprFunctor { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. - bool VisitExpr_(const Select* op) { + bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } - bool VisitExpr_(const Call* op) { + bool VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { return VisitExpr(op->args[0]); } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const Load* l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); return this->VisitExpr(l->index); } else if (op->is_pure()) { for (Expr e : op->args) { @@ -53,53 +53,53 @@ class UnsafeExprDetector : public ExprFunctor { return true; } } - bool VisitExpr_(const Load* op) { + bool VisitExpr_(const LoadNode* op) { // Load is considered unsafe. return true; } - bool VisitExpr_(const Add* op) final { return BinaryOp(op); } - bool VisitExpr_(const Sub* op) final { return BinaryOp(op); } - bool VisitExpr_(const Mul* op) final { return BinaryOp(op); } - bool VisitExpr_(const Div* op) final { return BinaryOp(op); } - bool VisitExpr_(const Mod* op) final { return BinaryOp(op); } - bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); } - bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); } - bool VisitExpr_(const Min* op) final { return BinaryOp(op); } - bool VisitExpr_(const Max* op) final { return BinaryOp(op); } - bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } - bool VisitExpr_(const NE* op) final { return BinaryOp(op); } - bool VisitExpr_(const LT* op) final { return BinaryOp(op); } - bool VisitExpr_(const LE* op) final { return BinaryOp(op); } - bool VisitExpr_(const GT* op) final { return BinaryOp(op); } - bool VisitExpr_(const GE* op) final { return BinaryOp(op); } - bool VisitExpr_(const And* op) final { return BinaryOp(op); } - bool VisitExpr_(const Or* op) final { return BinaryOp(op); } - bool VisitExpr_(const Not* op) final { + bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const NENode* op) final { return BinaryOp(op); } + bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const LENode* op) final { return BinaryOp(op); } + bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const GENode* op) final { return BinaryOp(op); } + bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); } - bool VisitExpr_(const Let* op) final { + bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } - bool VisitExpr_(const Cast* op) final { + bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } - bool VisitExpr_(const Broadcast* op) final { + bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } - bool VisitExpr_(const Ramp* op) final { + bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); } - bool VisitExpr_(const Shuffle* op) final { + bool VisitExpr_(const ShuffleNode* op) final { for (Expr e : op->vectors) { if (VisitExpr(e)) return true; } return false; } - bool VisitExpr_(const Variable* op) final { return false; } - bool VisitExpr_(const UIntImm* op) final { return false; } - bool VisitExpr_(const IntImm* op) final { return false; } - bool VisitExpr_(const FloatImm* op) final { return false; } - bool VisitExpr_(const StringImm* op) final { return false; } + bool VisitExpr_(const VarNode* op) final { return false; } + bool VisitExpr_(const UIntImmNode* op) final { return false; } + bool VisitExpr_(const IntImmNode* op) final { return false; } + bool VisitExpr_(const FloatImmNode* op) final { return false; } + bool VisitExpr_(const StringImmNode* op) final { return false; } private: template @@ -110,19 +110,19 @@ class UnsafeExprDetector : public ExprFunctor { class UnsafeSelectRewriter : public StmtExprMutator { public: - Expr VisitExpr_(const Select* op) { + Expr VisitExpr_(const SelectNode* op) { Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as