* \param expr The expression of interest.
* \return the result of the analysis.
*/
- ConstIntBound operator()(const Expr& expr);
+ ConstIntBound operator()(const PrimExpr& expr);
/*!
* \brief Update constant int bound information of var.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
- std::function<void()> EnterConstraint(const Expr& constraint);
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
* \param expr The expression of interest.
* \return the result of the analysis.
*/
- ModularSet operator()(const Expr& expr);
+ ModularSet operator()(const PrimExpr& expr);
/*!
* \brief Update constant int bound information of var.
*
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
- std::function<void()> EnterConstraint(const Expr& constraint);
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
* \param expr The expression of interest.
* \return the result of the analysis.
*/
- Expr operator()(const Expr& expr);
+ PrimExpr operator()(const PrimExpr& expr);
/*!
* \brief Update binding of var to a new expression.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
- const Expr& new_expr,
+ const PrimExpr& new_expr,
bool override = false);
- std::function<void()> EnterConstraint(const Expr& constraint);
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
private:
friend class Analyzer;
* \param expr The expression of interest.
* \return the result of the analysis.
*/
- Expr operator()(const Expr& expr);
+ PrimExpr operator()(const PrimExpr& expr);
/*!
* \brief Update binding of var to a new expression.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
- const Expr& new_expr,
+ const PrimExpr& new_expr,
bool override = false);
private:
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
- ConstraintContext(Analyzer* analyzer, Expr constraint)
+ ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
// enter the scope.
void EnterWithScope();
/*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
- Expr constraint_;
+ PrimExpr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};
*/
Range cover_range(Range max_range) const;
/*! \return Lower bound of the set */
- Expr min() const;
+ PrimExpr min() const;
/*! \return upper bound of the set */
- Expr max() const;
+ PrimExpr max() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
/*! \return Whether the set represent everything */
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
- Expr point_value() const;
+ PrimExpr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \param point The point in the set.
* \return construct a single point set
*/
- static IntSet single_point(Expr point);
+ static IntSet single_point(PrimExpr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
- static IntSet vector(Expr vec);
+ static IntSet vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
* \param max The maximum value of the interval.
* \return constructed set.
*/
- static IntSet interval(Expr min, Expr max);
+ static IntSet interval(PrimExpr min, PrimExpr max);
};
/*!
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
- IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
+ IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
private:
friend class Analyzer;
* \param var The variable.
* \param expr The expression we bind to.
*/
- void Bind(const VarExpr& var, const Expr& expr);
+ void Bind(const Var& var, const PrimExpr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
* \param var The variable.
* \param range The range we bind to.
*/
- void Bind(const VarExpr& var, const Range& range);
+ void Bind(const Var& var, const Range& range);
/*!
* \brief Whether can we prove expr >= val.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
- bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
+ bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove condition.
*
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
- bool CanProve(const Expr& cond);
+ bool CanProve(const PrimExpr& cond);
/*!
* \brief Simplify expr.
*
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
- Expr Simplify(const Expr& expr);
+ PrimExpr Simplify(const PrimExpr& expr);
};
//-----------------------------------------------
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
-using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
+using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
* \return the map from the expression to its possible value.
*/
ExprIntSetMap EvalSetForEachSubExpr(
- Expr e,
+ PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
-IntSet DeduceBound(Expr v, Expr cond,
+IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
/*!
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
-IntSet DeduceBound(Expr v, Expr cond,
+IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
-Array<Expr> DetectLinearEquation(const Expr& e,
+Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
const Array<Var>& vars);
/*!
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
-Array<Expr> DetectClipBound(const Expr& e,
+Array<PrimExpr> DetectClipBound(const PrimExpr& e,
const Array<Var>& vars);
// implementation
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
- Expr expr = val;
+ PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
- Expr expr = val;
+ PrimExpr expr = val;
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
- Expr expr = val;
+ PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
* If stride is not needed in the slice, it won't be presented
* \return the result buffer.
*/
- TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
+ TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
/*!
* \brief Get access ptr to the entire buffer.
* \param access_mask The access mask
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
- TVM_DLL Expr access_ptr(int access_mask,
+ TVM_DLL PrimExpr access_ptr(int access_mask,
DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
- Expr offset = make_const(DataType::Int(32), 0)) const;
+ PrimExpr offset = make_const(DataType::Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
- TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
+ TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
* \param value The value to be stored.
*/
- TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const;
+ TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The shape of the buffer */
- Array<Expr> shape;
+ Array<PrimExpr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
*/
- Array<Expr> strides;
+ Array<PrimExpr> strides;
/*! \brief The offset in terms of number of dtype elements (including lanes) */
- Expr elem_offset;
+ PrimExpr elem_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
// A default value will be picked.
TVM_DLL static Buffer make(Var ptr,
DataType dtype,
- Array<Expr> shape,
- Array<Expr> strides,
- Expr elem_offset,
+ Array<PrimExpr> shape,
+ Array<PrimExpr> strides,
+ PrimExpr elem_offset,
std::string name,
std::string scope,
int data_alignment,
* \return The created buffer.
* \sa BufferNode::make for complete constructor.
*/
-TVM_DLL Buffer decl_buffer(Array<Expr> shape,
+TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tvm
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
- Array<Expr> keys_array;
+ Array<PrimExpr> keys_array;
/*! \brief Options for this target */
- Array<Expr> options_array;
+ Array<PrimExpr> options_array;
/*! \brief Collection of imported libs */
- Array<Expr> libs_array;
+ Array<PrimExpr> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
/*! \brief Describes how source axes can be mapped to the destination axes,
* e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
*/
- Array<Expr> forward_rule;
+ Array<PrimExpr> forward_rule;
/*! \brief Describes how destination axes can be mapped to the source axes */
- Array<Expr> backward_rule;
+ Array<PrimExpr> backward_rule;
/*! \brief The source layout */
Layout src_layout;
explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
// Given the source shape, infer the destination shape.
- TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
+ TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
// Given the destination shape, recover the source shape.
- TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const;
+ TVM_DLL Array<PrimExpr> BackwardShape(const Array<PrimExpr>& dst_shape) const;
// Given the destination indices, infer the destination indices.
- TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const;
+ TVM_DLL Array<PrimExpr> ForwardIndex(const Array<PrimExpr>& index) const;
// Given the destination indices, recover the source indices.
- TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const;
+ TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;
/*!
* \brief access the internal node container
namespace tvm {
-/*! \brief Base node of all expressions. */
-class ExprNode : public Object {
+/*!
+ * \brief Base node of all primitive expressions.
+ *
+ * A primitive expression deals with low-level
+ * POD data types and handles without
+ * doing life-cycle management for objects.
+ *
+ * PrimExpr is used in the low-level code
+ * optimizations and integer analysis.
+ *
+ * \sa PrimExpr
+ */
+class PrimExprNode : public Object {
public:
/*! \brief The data type of the expression. */
DataType dtype;
- static constexpr const char* _type_key = "Expr";
- TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object);
+ static constexpr const char* _type_key = "PrimExpr";
+ TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object);
};
-/*! \brief Container of all expressions. */
-class Expr : public ObjectRef {
+/*!
+ * \brief Container of all primitive expressions.
+ * \sa PrimExprNode
+ */
+class PrimExpr : public ObjectRef {
public:
- Expr() {}
- explicit Expr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+ PrimExpr() {}
+ explicit PrimExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
- TVM_DLL Expr(int32_t value); // NOLINT(*)
+ TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
- TVM_DLL Expr(float value); // NOLINT(*)
+ TVM_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
- TVM_DLL Expr(std::string str); // NOLINT(*)
+ TVM_DLL PrimExpr(std::string str); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
- return static_cast<const ExprNode*>(get())->dtype;
+ return static_cast<const PrimExprNode*>(get())->dtype;
}
- /*! \brief type indicate the container type */
- using ContainerType = ExprNode;
+ using ContainerType = PrimExprNode;
};
/*! \brief Base node of all statements. */
* - Let
* - LetStmt
*/
-class VarNode : public ExprNode {
+class VarNode : public PrimExprNode {
public:
/*!
* \brief The hint to the variable name.
}
static constexpr const char* _type_key = "Variable";
- TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode);
};
/*! \brief a named variable in TVM */
-class Var : public Expr {
+class Var : public PrimExpr {
public:
- explicit Var(ObjectPtr<Object> n) : Expr(n) {}
+ explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32));
/*!
using ContainerType = VarNode;
};
-// Backward compatibility, will be removed later.
-using VarExpr = Var;
-using BaseExprNode = ExprNode;
-using ExprHash = ObjectHash;
-using ExprEqual = ObjectEqual;
-
class Integer;
/*! \brief ExprNode: constant integer. */
-class IntImmNode : public ExprNode {
+class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;
TVM_DLL static Integer make(DataType t, int64_t value);
static constexpr const char* _type_key = "IntImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
/*!
* This is used to store and automate type check
* attributes that must be constant integer.
*/
-class Integer : public Expr {
+class Integer : public PrimExpr {
public:
- Integer() : Expr() {}
+ Integer() : PrimExpr() {}
/*!
* \brief constructor from node.
*/
- explicit Integer(ObjectPtr<Object> node) : Expr(node) {}
+ explicit Integer(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Construct integer from int value.
*/
- Integer(int value) : Expr(value) {} // NOLINT(*)
+ Integer(int value) : PrimExpr(value) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
class RangeNode : public Object {
public:
/*! \brief beginning of the node */
- Expr min;
+ PrimExpr min;
/*! \brief the extend of range */
- Expr extent;
+ PrimExpr extent;
/*! \brief constructor */
RangeNode() {}
- RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
+ RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {}
void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min);
* \param begin The begin of the range.
* \param end The end of the range.
*/
- TVM_DLL Range(Expr begin, Expr end);
+ TVM_DLL Range(PrimExpr begin, PrimExpr end);
/*!
* \brief construct a new range with min and extent
* The corresponding constructor is removed,
* \param min The minimum range.
* \param extent The extent of the range.
*/
- static Range make_by_min_extent(Expr min, Expr extent);
+ static Range make_by_min_extent(PrimExpr min, PrimExpr extent);
// declare range.
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
/*!
* \return the corresponding var in the IterVar.
*/
- inline operator Expr() const;
+ inline operator PrimExpr() const;
/*! \brief specify container node */
using ContainerType = IterVarNode;
};
return static_cast<const IterVarNode*>(data_.get());
}
-inline IterVar::operator Expr() const {
+inline IterVar::operator PrimExpr() const {
return (*this)->var;
}
*/
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
-inline Expr make_const(DataType t, ValueType value);
+inline PrimExpr make_const(DataType t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
-inline Expr make_zero(DataType t);
+inline PrimExpr make_zero(DataType t);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
-inline Expr const_true(int lanes = 1) {
+inline PrimExpr const_true(int lanes = 1) {
return make_const(DataType::UInt(1, lanes), 1);
}
/*!
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
-inline Expr const_false(int lanes = 1) {
+inline PrimExpr const_false(int lanes = 1) {
return make_const(DataType::UInt(1, lanes), 0);
}
/*!
* \return the address to the int expression,
* return nullptr, if x is not IntImm.
*/
-inline const int64_t* as_const_int(const Expr& x) {
+inline const int64_t* as_const_int(const PrimExpr& x) {
if (!x.defined()) return nullptr;
if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
return &(op->value);
* \return the address to the int expression,
* return nullptr, if x is not UIntImm.
*/
-inline const uint64_t* as_const_uint(const Expr& x) {
+inline const uint64_t* as_const_uint(const PrimExpr& x) {
if (!x.defined()) return nullptr;
if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
return &(op->value);
* \param value the value to be compared against.
* \return whether x is constant expression.
*/
-inline bool is_const_int(const Expr& x, int64_t value);
+inline bool is_const_int(const PrimExpr& x, int64_t value);
/*!
* \brief Check whether stmt is nop.
* \note This only return true for integer types.
* \return whether x is constant 1
*/
-inline bool is_one(const Expr& x) {
+inline bool is_one(const PrimExpr& x) {
return is_const_int(x, 1);
}
* \return whether x is constant 0
* \note This only return true for integer types.
*/
-inline bool is_zero(const Expr& x) {
+inline bool is_zero(const PrimExpr& x) {
return is_const_int(x, 0);
}
* \note This only return true for integer types.
* \return whether x is constant
*/
-inline bool is_const(const Expr& x);
+inline bool is_const(const PrimExpr& x);
/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \return the maximum possible value in this format.
*/
-TVM_DLL Expr max_value(const DataType& dtype);
+TVM_DLL PrimExpr max_value(const DataType& dtype);
/*!
* Query the minimum possible value of dtype.
* \param dtype The data type.
* \return the minimum possible value in this format.
*/
-TVM_DLL Expr min_value(const DataType& dtype);
+TVM_DLL PrimExpr min_value(const DataType& dtype);
/*!
* \brief Check whether x is a constant power of two
* \param shift The output shift if x is power of two.
* \return whether x is constant power of two
*/
-TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
+TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
/*!
* \brief cast value to type.
* \return The result expression.
* \note This function may return value if the type is the same.
*/
-TVM_DLL Expr cast(const DataType& t, Expr value);
+TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value);
/*!
* \brief perform reinterpret cast value to type.
*
* \return The result expression.
* \note This function may return value if the type is the same.
*/
-TVM_DLL Expr reinterpret(const DataType& t, Expr value);
+TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value);
/*!
* \brief add operator
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator+(Expr a, Expr b);
+TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
/*!
* \brief subtraction operator
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator-(Expr a, Expr b);
+TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
/*!
* \brief negation.
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator-(Expr a);
+TVM_DLL PrimExpr operator-(PrimExpr a);
/*!
* \brief multiplication operator
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator*(Expr a, Expr b);
+TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
/*!
* \brief division operator
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator/(Expr a, Expr b);
+TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
/*!
* \brief left shift operator
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator<<(Expr a, Expr b);
+TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
/*!
* \brief right shift operator
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator>>(Expr a, Expr b);
+TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
/*!
* \brief greater
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator>(Expr a, Expr b);
+TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
/*!
* \brief greater_equal
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator>=(Expr a, Expr b);
+TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
/*!
* \brief less
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator<(Expr a, Expr b);
+TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
/*!
* \brief less_equal
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator<=(Expr a, Expr b);
+TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
/*!
* \brief equal
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator==(Expr a, Expr b);
+TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
/*!
* \brief not_equal
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator!=(Expr a, Expr b);
+TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
/*!
* \brief and
*
* \return The result expression.
* \note This operator does eager constant folding.
*/
-TVM_DLL Expr operator&&(Expr a, Expr b);
+TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
/*!
* \brief or
*
* \return The result expression.
* \note This operator does eager constant folding.
*/
-TVM_DLL Expr operator||(Expr a, Expr b);
+TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
/*!
* \brief not
*
* \return The result expression.
* \note This operator does eager constant folding.
*/
-TVM_DLL Expr operator!(Expr a);
+TVM_DLL PrimExpr operator!(PrimExpr a);
/*!
* \brief compute division in C semantics.
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr div(Expr a, Expr b);
+TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b);
/*!
* \brief compute trunc(a / b)
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr truncdiv(Expr a, Expr b);
+TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b);
/*!
* \brief compute the remainder of truncdiv
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr truncmod(Expr a, Expr b);
+TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b);
/*!
* \brief compute floor(a / b) where a and b are non-negative.
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr indexdiv(Expr a, Expr b);
+TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b);
/*!
* \brief compute the remainder floor(a / b) where a and b are non-negative.
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr indexmod(Expr a, Expr b);
+TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b);
/*!
* \brief compute floor(a / b)
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr floordiv(Expr a, Expr b);
+TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b);
/*!
* \brief compute the remainder of floordiv
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr floormod(Expr a, Expr b);
+TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b);
/*!
* \brief take maximum of two values
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr max(Expr a, Expr b);
+TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b);
/*!
* \brief take minimum of two values
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr min(Expr a, Expr b);
+TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise and of two values
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator&(Expr a, Expr b);
+TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise or of two values
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator|(Expr a, Expr b);
+TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise xor of two values
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator^(Expr a, Expr b);
+TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise negation of two values
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr operator~(Expr a);
+TVM_DLL PrimExpr operator~(PrimExpr a);
/*!
* \brief Conditional expression.
*
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
-TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value);
+TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value);
/*!
* \brief Mark condition as likely.
* \param cond The condition
* \return The marked expression.
*/
-TVM_DLL Expr likely(Expr cond);
+TVM_DLL PrimExpr likely(PrimExpr cond);
/*!
* \brief Calculate power(x, y)
* \param x The left operand.
* \param y The right operand.
*/
-TVM_DLL Expr pow(Expr x, Expr y);
+TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y);
/*!
* \brief Calculate absolute value of x.
* \param x The input data
*
* \return The aboslute value of input data x
*/
-TVM_DLL Expr abs(Expr x);
+TVM_DLL PrimExpr abs(PrimExpr x);
/*!
* \brief Check if x is NaN.
* \param x The input data
* \return The result expression.
*/
-TVM_DLL Expr isnan(Expr x);
+TVM_DLL PrimExpr isnan(PrimExpr x);
/*!
* \brief sum of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
-TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr sum(PrimExpr source, Array<IterVar> axis);
/*!
* \brief logical And of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
-TVM_DLL Expr all(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr all(PrimExpr source, Array<IterVar> axis);
/*!
* \brief logical Or of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
-TVM_DLL Expr any(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr any(PrimExpr source, Array<IterVar> axis);
/*!
* \brief max of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
-TVM_DLL Expr max(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr max(PrimExpr source, Array<IterVar> axis);
/*!
* \brief max of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
-TVM_DLL Expr min(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr min(PrimExpr source, Array<IterVar> axis);
/*!
* \brief product of of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
*/
-TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr prod(PrimExpr source, Array<IterVar> axis);
/*!
* \brief Calculate floor(x)
* \param x The input expression.
* \return The result expression.
*/
-TVM_DLL Expr floor(Expr x);
+TVM_DLL PrimExpr floor(PrimExpr x);
/*!
* \brief Calculate ceil(x)
* \param x The input expression.
* \return The result expression.
*/
-TVM_DLL Expr ceil(Expr x);
+TVM_DLL PrimExpr ceil(PrimExpr x);
/*!
* \brief Calculate round(x)
* \param x The input expression.
* \return The result expression.
*/
-TVM_DLL Expr round(Expr x);
+TVM_DLL PrimExpr round(PrimExpr x);
/*!
* \brief Calculates std::nearbyint(x)
* \return The result expression.
* This is a faster alternate to round.
*/
-TVM_DLL Expr nearbyint(Expr x);
+TVM_DLL PrimExpr nearbyint(PrimExpr x);
/*!
* \brief Calculate trunc(x)
* \param x The input expression.
* \return The result expression.
*/
-TVM_DLL Expr trunc(Expr x);
+TVM_DLL PrimExpr trunc(PrimExpr x);
// Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName) \
- inline Expr OpName(Expr x) { \
+#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+ inline PrimExpr OpName(PrimExpr x) { \
return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
- } \
+ } \
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(atan);
// Implementation details after this
-inline bool is_const(const Expr& x) {
+inline bool is_const(const PrimExpr& x) {
if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
return true;
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
- const Expr& val = op->value;
+ const PrimExpr& val = op->value;
if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
return true;
}
return false;
}
-inline bool is_positive_const(const Expr& a) {
+inline bool is_positive_const(const PrimExpr& a) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
} else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
}
}
-inline bool is_negative_const(const Expr& a) {
+inline bool is_negative_const(const PrimExpr& a) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value < 0;
} else {
}
}
-inline bool is_const_int(const Expr& x, int64_t value) {
+inline bool is_const_int(const PrimExpr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImmNode>()) {
return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
- const Expr& val = op->value;
+ const PrimExpr& val = op->value;
if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImmNode>()) {
}
template<typename ValueType>
-inline Expr MakeConstScalar(DataType t, ValueType value) {
+inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
return ir::FloatImmNode::make(t, static_cast<double>(value));
LOG(FATAL) << "cannot make const for type " << t;
- return Expr();
+ return PrimExpr();
}
template<typename ValueType, typename>
-inline Expr make_const(DataType t, ValueType value) {
+inline PrimExpr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
}
}
-inline Expr make_zero(DataType t) {
+inline PrimExpr make_zero(DataType t) {
if (t.is_handle()) {
return reinterpret(t, make_const(DataType::UInt(64), 0));
}
}
// additional const expression overloading
-#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
- inline Expr Name(Expr& a, Expr b) { \
- a = OpFunc(a, b); \
- return a; \
+#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
+ inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\
+ a = OpFunc(a, b); \
+ return a; \
}
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
- inline Expr Name(const Expr& a, float b) { \
- return Name(a, Expr(b)); \
+ inline PrimExpr Name(const PrimExpr& a, float b) { \
+ return Name(a, PrimExpr(b)); \
} \
- inline Expr Name(float a, const Expr& b) { \
- return Name(Expr(a), b); \
+ inline PrimExpr Name(float a, const PrimExpr& b) { \
+ return Name(PrimExpr(a), b); \
} \
- inline Expr Name(int a, const Expr& b) { \
+ inline PrimExpr Name(int a, const PrimExpr& b) { \
return Name(make_const(b.dtype(), a), b); \
} \
- inline Expr Name(const Expr& a, int b) { \
+ inline PrimExpr Name(const PrimExpr& a, int b) { \
return Name(a, make_const(a.dtype(), b)); \
} \
- inline Expr Name(const Expr& a, double b) { \
- return Name(a, make_const(DataType::Float(64), b)); \
+ inline PrimExpr Name(const PrimExpr& a, double b) {\
+ return Name(a, make_const(DataType::Float(64), b)); \
}
-#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
- inline Expr Name(const Expr& a, bool b) { \
- return Name(a, Expr(b)); \
- } \
- inline Expr Name(bool a, const Expr& b) { \
- return Name(Expr(a), b); \
+#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
+ inline PrimExpr Name(const PrimExpr& a, bool b) { \
+ return Name(a, PrimExpr(b)); \
+ } \
+ inline PrimExpr Name(bool a, const PrimExpr& b) { \
+ return Name(PrimExpr(a), b); \
}
-#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
- inline Expr Name(const Expr& a, int b) { \
- return Name(a, make_const(a.dtype(), b)); \
- } \
- inline Expr Name(int a, const Expr& b) { \
- return Name(make_const(b.dtype(), a), b); \
+#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
+ inline PrimExpr Name(const PrimExpr& a, int b) { \
+ return Name(a, make_const(a.dtype(), b)); \
+ } \
+ inline PrimExpr Name(int a, const PrimExpr& b) { \
+ return Name(make_const(b.dtype(), a), b); \
}
// The second template argument is necessary to make sure the
// code compiles lazily by the compiler during invocation.
template<typename TB>
-inline Expr operator/(const Expr& a, const TB& b) {
+inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
template<typename TB>
-inline Expr operator/=(const Expr& a, const TB& b) {
+inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
template<typename TB>
-inline Expr operator%(const Expr& a, const TB& b) {
+inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
DivAmbiguityError(a);
return a;
}
using VarNode = tvm::VarNode;
/*! \brief constant unsigned integer. */
-class UIntImmNode : public ExprNode {
+class UIntImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
uint64_t value;
v->Visit("value", &value);
}
- TVM_DLL static Expr make(DataType t, uint64_t value);
+ TVM_DLL static PrimExpr make(DataType t, uint64_t value);
static constexpr const char* _type_key = "UIntImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode);
};
/*! \brief Floating point constants. */
-class FloatImmNode : public ExprNode {
+class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
v->Visit("value", &value);
}
- TVM_DLL static Expr make(DataType t, double value);
+ TVM_DLL static PrimExpr make(DataType t, double value);
static constexpr const char* _type_key = "FloatImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*! \brief String constants, only used in asserts. */
-class StringImmNode : public ExprNode {
+class StringImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
std::string value;
v->Visit("value", &value);
}
- TVM_DLL Expr static make(std::string value);
+ TVM_DLL PrimExpr static make(std::string value);
static constexpr const char* _type_key = "StringImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
};
/*!
* \brief Cast value from one data type to another.
* \note The lanes of value should keep fixed.
*/
-class CastNode : public ExprNode {
+class CastNode : public PrimExprNode {
public:
/*! \brief Original data type. */
- Expr value;
+ PrimExpr value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
- TVM_DLL static Expr make(DataType t, Expr v);
+ TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
static constexpr const char* _type_key = "Cast";
- TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
};
/*!
* \tparam T The type of the child class.
*/
template<typename T>
-class BinaryOpNode : public ExprNode {
+class BinaryOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
- Expr a;
+ PrimExpr a;
/*! \brief The right operand. */
- Expr b;
+ PrimExpr b;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("b", &b);
}
- static Expr make(Expr a, Expr b) {
+ static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n";
node->dtype = a.dtype();
node->a = std::move(a);
node->b = std::move(b);
- return Expr(node);
+ return PrimExpr(node);
}
- TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
};
/*! \brief a + b */
* \tparam T The type of the child class.
*/
template<typename T>
-class CmpOpNode : public ExprNode {
+class CmpOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
- Expr a;
+ PrimExpr a;
/*! \brief The right operand. */
- Expr b;
+ PrimExpr b;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("b", &b);
}
- static Expr make(Expr a, Expr b) {
+ static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n";
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
- return Expr(node);
+ return PrimExpr(node);
}
- TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
};
/*! \brief a == b */
};
/*! \brief a && b */
-class AndNode : public ExprNode {
+class AndNode : public PrimExprNode {
public:
/*! \brief The left operand. */
- Expr a;
+ PrimExpr a;
/*! \brief The right operand. */
- Expr b;
+ PrimExpr b;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("b", &b);
}
- TVM_DLL static Expr make(Expr a, Expr b);
+ TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "And";
- TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
};
/*! \brief a || b */
-class OrNode : public ExprNode {
+class OrNode : public PrimExprNode {
public:
/*! \brief The left operand. */
- Expr a;
+ PrimExpr a;
/*! \brief The right operand. */
- Expr b;
+ PrimExpr b;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("b", &b);
}
- TVM_DLL static Expr make(Expr a, Expr b);
+ TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "Or";
- TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
};
/*! \brief !a */
-class NotNode : public ExprNode {
+class NotNode : public PrimExprNode {
public:
/*! \brief The input operand. */
- Expr a;
+ PrimExpr a;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("a", &a);
}
- TVM_DLL static Expr make(Expr a);
+ TVM_DLL static PrimExpr make(PrimExpr a);
static constexpr const char* _type_key = "Not";
- TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
};
/*!
* Do not use it to guard against out of bound access,
* please use if_then_else instead.
*/
-class SelectNode : public ExprNode {
+class SelectNode : public PrimExprNode {
public:
/*! \brief The condition */
- Expr condition;
+ PrimExpr condition;
/*! \brief value to be returned when condition is true. */
- Expr true_value;
+ PrimExpr true_value;
/*! \brief value to be returned when condition is false. */
- Expr false_value;
+ PrimExpr false_value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("false_value", &false_value);
}
- TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
+ TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
static constexpr const char* _type_key = "Select";
- TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
};
/*!
*
* \endcode
*/
-class LoadNode : public ExprNode {
+class LoadNode : public PrimExprNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The index locations to be loaded. */
- Expr index;
+ PrimExpr index;
/*! \brief The predicate to mask which lanes would be loaded. */
- Expr predicate;
+ PrimExpr predicate;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("predicate", &predicate);
}
- TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate);
+ TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
static constexpr const char* _type_key = "Load";
- TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
};
/*!
* - ramp(0, 1, 3) = [0, 1, 2]
* - ramp(1, 2, 4) = [1, 3, 5, 7]
*/
-class RampNode : public ExprNode {
+class RampNode : public PrimExprNode {
public:
/*! \brief The base value. */
- Expr base;
+ PrimExpr base;
/*! \brief The stride of each step. */
- Expr stride;
+ PrimExpr stride;
/*! \brief Total number of lanes. */
int lanes;
v->Visit("lanes", &lanes);
}
- TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
+ TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
static constexpr const char* _type_key = "Ramp";
- TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
};
/*! \brief Create a vector where all the elements are value. */
-class BroadcastNode : public ExprNode {
+class BroadcastNode : public PrimExprNode {
public:
/*! \brief The base value. */
- Expr value;
+ PrimExpr value;
/*! \brief The number of lanes. */
int lanes;
v->Visit("lanes", &lanes);
}
- TVM_DLL static Expr make(Expr value, int lanes);
+ TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
static constexpr const char* _type_key = "Broadcast";
- TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
};
/*!
* \brief Let binding. Bind var to value then evaluate body.
*/
-class LetNode : public ExprNode {
+class LetNode : public PrimExprNode {
public:
/*! \brief The variable. */
Var var;
/*! \brief The value to be binded. */
- Expr value;
+ PrimExpr value;
/*! \brief The result expression. */
- Expr body;
+ PrimExpr body;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("body", &body);
}
- TVM_DLL static Expr make(Var var, Expr value, Expr body);
+ TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
static constexpr const char* _type_key = "Let";
- TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
};
// Call node, represent a function call or a multi-dimensional array load.
/*!
* \brief Call node.
*/
-class CallNode : public ExprNode {
+class CallNode : public PrimExprNode {
public:
/*! \brief Possible types of calls. */
enum CallType : int {
/*! \brief The name of the function/intrinsic. */
std::string name;
/*! \brief The arguments. */
- Array<Expr> args;
+ Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;
/*! \brief The function to be called. */
v->Visit("value_index", &value_index);
}
- TVM_DLL static Expr make(DataType dtype,
+ TVM_DLL static PrimExpr make(DataType dtype,
std::string name,
- Array<Expr> args,
+ Array<PrimExpr> args,
CallType call_type,
FunctionRef func = FunctionRef(),
int value_index = 0);
bool is_vectorizable() const;
static constexpr const char* _type_key = "Call";
- TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
// Build-in intrinsics
static constexpr const char* reinterpret = "reinterpret";
* vec = concat(vectors)
* result = (vec[indices[0]], vec[indices[1]] ...)
*/
-class ShuffleNode : public ExprNode {
+class ShuffleNode : public PrimExprNode {
public:
/*! \brief the input vectors. */
- Array<Expr> vectors;
+ Array<PrimExpr> vectors;
/*! \brief The indices of each element. */
- Array<Expr> indices;
+ Array<PrimExpr> indices;
void VisitAttrs(AttrVisitor* v) {
v->Visit("vectors", &vectors);
v->Visit("indices", &indices);
}
- TVM_DLL static Expr make(Array<Expr> vectors, Array<Expr> indices);
- TVM_DLL static Expr make_concat(Array<Expr> vectors);
- TVM_DLL static Expr make_extract_element(Expr vector, int index);
+ TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices);
+ TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
+ TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
static constexpr const char* _type_key = "Shuffle";
- TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
};
// Reduce operator
/*! \brief The right argument of reducer */
Array<Var> rhs;
/*! \brief The result of reducer */
- Array<Expr> result;
+ Array<PrimExpr> result;
/*!
* \brief The identity element of reducer, which leaves other
* elements unchanged when combined with it, with respect to
* the binary operation of this reducer uses.
*/
- Array<Expr> identity_element;
+ Array<PrimExpr> identity_element;
/*! \brief Function call operator to combine a and b */
- Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
+ Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
/*! \brief construct CommReducer from args, result and identity_element */
TVM_DLL static CommReducer make(Array<Var> lhs,
Array<Var> rhs,
- Array<Expr> result,
- Array<Expr> identity_element);
+ Array<PrimExpr> result,
+ Array<PrimExpr> identity_element);
void VisitAttrs(AttrVisitor* v) {
v->Visit("lhs", &lhs);
}
/*! \brief Reduction operator operator */
-class ReduceNode : public ExprNode {
+class ReduceNode : public PrimExprNode {
public:
/*! \brief The commutative combiner */
CommReducer combiner;
/*! \brief The source operand */
- Array<Expr> source;
+ Array<PrimExpr> source;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
* \brief Predicate on the reduction
* Only add the body to reduction if condition is true.
*/
- Expr condition;
+ PrimExpr condition;
/*! \brief the index of this reduce node */
int value_index;
/*! \brief construct expr from op and rdom */
- TVM_DLL static Expr make(CommReducer combiner,
- Array<Expr> src,
+ TVM_DLL static PrimExpr make(CommReducer combiner,
+ Array<PrimExpr> src,
Array<IterVar> rdom,
- Expr condition,
+ PrimExpr condition,
int value_index);
void VisitAttrs(AttrVisitor* v) {
}
static constexpr const char* _type_key = "Reduce";
- TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
/*! \brief Any shape. */
-class AnyNode : public ExprNode {
+class AnyNode : public PrimExprNode {
public:
void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */
return VarNode::make(DataType::Int(32), "any_dim");
}
- TVM_DLL static Expr make();
+ TVM_DLL static PrimExpr make();
static constexpr const char* _type_key = "Any";
- TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, ExprNode);
+ TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
};
// Statements
/*! \brief The variable. */
Var var;
/*! \brief The value to be binded. */
- Expr value;
+ PrimExpr value;
/*! \brief The body block. */
Stmt body;
v->Visit("body", &body);
}
- TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
+ TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
/*! \brief the type key of the attribute */
std::string attr_key;
/*! \brief The attribute value, value is well defined at current scope. */
- Expr value;
+ PrimExpr value;
/*! \brief The body statement to be executed */
Stmt body;
TVM_DLL static Stmt make(ObjectRef node,
std::string type_key,
- Expr value,
+ PrimExpr value,
Stmt body);
static constexpr const char* _type_key = "AttrStmt";
class AssertStmtNode : public StmtNode {
public:
/*! \brief Condition to be checked. */
- Expr condition;
+ PrimExpr condition;
/*! \brief Error message when assertion failed. */
- Expr message;
+ PrimExpr message;
/*!
* \brief Body which this assertion holds true.
* Will be executed after the assertion.
v->Visit("body", &body);
}
- TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
+ TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The value to be stored. */
- Expr value;
+ PrimExpr value;
/*! \brief The index locations to be stored. */
- Expr index;
+ PrimExpr index;
/*! \brief The predicate to mask which lanes would be stored. */
- Expr predicate;
+ PrimExpr predicate;
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
}
TVM_DLL static Stmt make(Var buffer_var,
- Expr value,
- Expr index,
- Expr predicate);
+ PrimExpr value,
+ PrimExpr index,
+ PrimExpr predicate);
static constexpr const char* _type_key = "Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
/*! \brief The output value index if func's value is a tuple. */
int value_index{0};
/*! \brief The value to be stored. */
- Expr value;
+ PrimExpr value;
/*! \brief The index arguments of the function. */
- Array<Expr> args;
+ Array<PrimExpr> args;
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
- Expr value,
- Array<Expr> args);
+ PrimExpr value,
+ Array<PrimExpr> args);
static constexpr const char* _type_key = "Provide";
TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
/*! \brief The type of the buffer. */
DataType dtype;
/*! \brief The extents of the buffer. */
- Array<Expr> extents;
+ Array<PrimExpr> extents;
/*! \brief Only allocate buffer when condition is satisfied. */
- Expr condition;
+ PrimExpr condition;
/*! \brief The body to be executed. */
Stmt body;
// The following two fields are deprecated
// kept for backward compatibility and will be refactored later.
- Expr new_expr;
+ PrimExpr new_expr;
std::string free_function;
void VisitAttrs(AttrVisitor* v) {
TVM_DLL static Stmt make(Var buffer_var,
DataType dtype,
- Array<Expr> extents,
- Expr condition,
+ Array<PrimExpr> extents,
+ PrimExpr condition,
Stmt body,
- Expr new_expr = Expr(),
+ PrimExpr new_expr = PrimExpr(),
std::string free_function = std::string());
/*!
* \return The result.
*/
TVM_DLL static int32_t constant_allocation_size(
- const Array<Expr>& extents);
+ const Array<PrimExpr>& extents);
static constexpr const char* _type_key = "Allocate";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
/*! \brief Bounds to be realized. */
Region bounds;
/*! \brief Only realize if condition holds. */
- Expr condition;
+ PrimExpr condition;
/*! \brief The body of realization. */
Stmt body;
int value_index,
DataType dtype,
Region bounds,
- Expr condition,
+ PrimExpr condition,
Stmt body);
static constexpr const char* _type_key = "Realize";
class IfThenElseNode : public StmtNode {
public:
/*! \brief The condition. */
- Expr condition;
+ PrimExpr condition;
/*! \brief The branch to be executed when condition is true. */
Stmt then_case;
/*! \brief The branch to be executed when condition is false, can be null. */
v->Visit("else_case", &else_case);
}
- TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
+ TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
static constexpr const char* _type_key = "IfThenElse";
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
class EvaluateNode : public StmtNode {
public:
/*! \brief The expression to be evaluated. */
- Expr value;
+ PrimExpr value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
}
- TVM_DLL static Stmt make(Expr v);
+ TVM_DLL static Stmt make(PrimExpr v);
static constexpr const char* _type_key = "Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
/*! \brief The loop variable. */
Var loop_var;
/*! \brief The minimum value of iteration. */
- Expr min;
+ PrimExpr min;
/*! \brief The extent of the iteration. */
- Expr extent;
+ PrimExpr extent;
/*! \brief The type of the for loop. */
ForType for_type;
/*!
Stmt body;
TVM_DLL static Stmt make(Var loop_var,
- Expr min,
- Expr extent,
+ PrimExpr min,
+ PrimExpr extent,
ForType for_type,
DeviceAPI device_api,
Stmt body);
* \param dtype The data type
* \return Expr a expression with dtype.
*/
-inline Expr TypeAnnotation(DataType dtype) {
+inline PrimExpr TypeAnnotation(DataType dtype) {
return ir::CallNode::make(dtype,
"type_annotation", {},
ir::CallNode::PureIntrinsic);
}); \
template<typename R, typename ...Args>
-class ExprFunctor<R(const Expr& n, Args...)> {
+class ExprFunctor<R(const PrimExpr& n, Args...)> {
private:
- using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
+ using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
public:
* \param args Additional arguments.
* \return The result of the call
*/
- R operator()(const Expr& n, Args... args) {
+ R operator()(const PrimExpr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...);
}
/*!
* \param args Additional arguments.
* \return The result of the call
*/
- virtual R VisitExpr(const Expr& n, Args... args) {
+ virtual R VisitExpr(const PrimExpr& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
* \brief ExprVisitor
*/
class TVM_DLL ExprVisitor :
- public ExprFunctor<void(const Expr&)> {
+ public ExprFunctor<void(const PrimExpr&)> {
public:
using ExprFunctor::operator();
* \brief ExprMutator that mutates expressions.
*/
class TVM_DLL ExprMutator :
- protected ExprFunctor<Expr(const Expr&)> {
+ protected ExprFunctor<PrimExpr(const PrimExpr&)> {
public:
using ExprFunctor::operator();
protected:
using ExprFunctor::VisitExpr;
// list of functions to override.
- Expr VisitExpr_(const VarNode* op) override;
- Expr VisitExpr_(const LoadNode* op) override;
- Expr VisitExpr_(const LetNode* op) override;
- Expr VisitExpr_(const CallNode* op) override;
- Expr VisitExpr_(const AddNode* op) override;
- Expr VisitExpr_(const SubNode* op) override;
- Expr VisitExpr_(const MulNode* op) override;
- Expr VisitExpr_(const DivNode* op) override;
- Expr VisitExpr_(const ModNode* op) override;
- Expr VisitExpr_(const FloorDivNode* op) override;
- Expr VisitExpr_(const FloorModNode* op) override;
- Expr VisitExpr_(const MinNode* op) override;
- Expr VisitExpr_(const MaxNode* op) override;
- Expr VisitExpr_(const EQNode* op) override;
- Expr VisitExpr_(const NENode* op) override;
- Expr VisitExpr_(const LTNode* op) override;
- Expr VisitExpr_(const LENode* op) override;
- Expr VisitExpr_(const GTNode* op) override;
- Expr VisitExpr_(const GENode* op) override;
- Expr VisitExpr_(const AndNode* op) override;
- Expr VisitExpr_(const OrNode* op) override;
- Expr VisitExpr_(const ReduceNode* op) override;
- Expr VisitExpr_(const CastNode* op) override;
- Expr VisitExpr_(const NotNode* op) override;
- Expr VisitExpr_(const SelectNode* op) override;
- Expr VisitExpr_(const RampNode* op) override;
- Expr VisitExpr_(const BroadcastNode* op) override;
- Expr VisitExpr_(const ShuffleNode* op) override;
- Expr VisitExpr_(const IntImmNode* op) override;
- Expr VisitExpr_(const UIntImmNode* op) override;
- Expr VisitExpr_(const FloatImmNode* op) override;
- Expr VisitExpr_(const StringImmNode* op) override;
+ PrimExpr VisitExpr_(const VarNode* op) override;
+ PrimExpr VisitExpr_(const LoadNode* op) override;
+ PrimExpr VisitExpr_(const LetNode* op) override;
+ PrimExpr VisitExpr_(const CallNode* op) override;
+ PrimExpr VisitExpr_(const AddNode* op) override;
+ PrimExpr VisitExpr_(const SubNode* op) override;
+ PrimExpr VisitExpr_(const MulNode* op) override;
+ PrimExpr VisitExpr_(const DivNode* op) override;
+ PrimExpr VisitExpr_(const ModNode* op) override;
+ PrimExpr VisitExpr_(const FloorDivNode* op) override;
+ PrimExpr VisitExpr_(const FloorModNode* op) override;
+ PrimExpr VisitExpr_(const MinNode* op) override;
+ PrimExpr VisitExpr_(const MaxNode* op) override;
+ PrimExpr VisitExpr_(const EQNode* op) override;
+ PrimExpr VisitExpr_(const NENode* op) override;
+ PrimExpr VisitExpr_(const LTNode* op) override;
+ PrimExpr VisitExpr_(const LENode* op) override;
+ PrimExpr VisitExpr_(const GTNode* op) override;
+ PrimExpr VisitExpr_(const GENode* op) override;
+ PrimExpr VisitExpr_(const AndNode* op) override;
+ PrimExpr VisitExpr_(const OrNode* op) override;
+ PrimExpr VisitExpr_(const ReduceNode* op) override;
+ PrimExpr VisitExpr_(const CastNode* op) override;
+ PrimExpr VisitExpr_(const NotNode* op) override;
+ PrimExpr VisitExpr_(const SelectNode* op) override;
+ PrimExpr VisitExpr_(const RampNode* op) override;
+ PrimExpr VisitExpr_(const BroadcastNode* op) override;
+ PrimExpr VisitExpr_(const ShuffleNode* op) override;
+ PrimExpr VisitExpr_(const IntImmNode* op) override;
+ PrimExpr VisitExpr_(const UIntImmNode* op) override;
+ PrimExpr VisitExpr_(const FloatImmNode* op) override;
+ PrimExpr VisitExpr_(const StringImmNode* op) override;
};
/*!
* or have a class sub-class both StmtVisitor and ExprVisitor
* and redirect Visit to ExprMutator::VisitExpr(Expr)
*/
- virtual void VisitExpr(const Expr& e) {}
+ virtual void VisitExpr(const PrimExpr& e) {}
// statement visitor
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
* or have a class sub-class both StmtMutator and ExprMutator
* and redirect Mutate to ExprMutator::Mutate(Expr)
*/
- virtual Expr VisitExpr(const Expr& e) {
+ virtual PrimExpr VisitExpr(const PrimExpr& e) {
return e;
}
// statement visitor
using StmtVisitor::VisitStmt;
using ExprVisitor::VisitExpr;
- void VisitExpr(const Expr& e) override {
+ void VisitExpr(const PrimExpr& e) override {
return ExprVisitor::VisitExpr(e);
}
};
using StmtMutator::VisitExpr;
using ExprMutator::VisitExpr;
- Expr VisitExpr(const Expr& e) override {
+ PrimExpr VisitExpr(const PrimExpr& e) override {
return ExprMutator::VisitExpr(e);
}
};
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
- const Array<Expr>& only_enable = {});
+ const Array<PrimExpr>& only_enable = {});
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
-TVM_DLL Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
+TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify the statement.
* \param vrange The range information about the variable.
* \return Canonicalized expression.
*/
-TVM_DLL Expr CanonicalSimplify(Expr expr,
+TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \param rhs The right operand
* \return The comparison result.
*/
-TVM_DLL bool Equal(const Expr& lhs, const Expr& rhs);
+TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs);
/*!
* \brief Deep compare lhs and rhs
* \param rhs The right operand
* \return The comparison result.
*/
-int Compare(const Expr& lhs, const Expr& rhs);
+int Compare(const PrimExpr& lhs, const PrimExpr& rhs);
/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* \brief Whether the expression have side effect.
* \return whether expression have side effect
*/
-TVM_DLL bool HasSideEffect(const Expr& e);
+TVM_DLL bool HasSideEffect(const PrimExpr& e);
/*!
* \brief Whether e expression used var.
* \param v The variable.
* \return Whether e uses v.
*/
-bool ExprUseVar(const Expr& e, const Var& v);
+bool ExprUseVar(const PrimExpr& e, const Var& v);
/*!
* \brief Whether e expression used any var in variable set..
* \param vset The variable set.
* \return Whether e uses vset.
*/
-bool ExprUseVar(const Expr& e, const std::unordered_set<const VarNode*>& vset);
+bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vset);
/*!
* \brief Convert a IR node to be SSA form.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
- const std::unordered_map<const VarNode*, Expr>& value_map);
+ const std::unordered_map<const VarNode*, PrimExpr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param value_map The map of new values.
* \return The converted expression.
*/
-Expr Substitute(Expr expr,
- const std::unordered_map<const VarNode*, Expr>& value_map);
+PrimExpr Substitute(PrimExpr expr,
+ const std::unordered_map<const VarNode*, PrimExpr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param value_map The map of new values.
* \return The converted form.
*/
-Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
+Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param value_map The map of new values.
* \return The converted expression.
*/
-Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
+PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
/*!
* \brief inline all calls of f in stmt.
Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args,
- Expr body);
+ PrimExpr body);
/*!
* \brief Flatten the multi-dimensional read/write
* \param axis_map The map from StringImm -> ItrVar
* \return Transformed function.
*/
-LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
+LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
/*!
* \brief Lower packed function call.
*
*/
bool VerifyGPUCode(Stmt stmt,
- Map<std::string, Expr> constraints);
+ Map<std::string, PrimExpr> constraints);
} // namespace ir
* \note Expr is used instead Type, because Type cannot be hold by Map.
* constant Expr of given type is used.
*/
- Map<Var, Expr> handle_data_type;
+ Map<Var, PrimExpr> handle_data_type;
/*! \brief The type of the function */
LoweredFuncType func_type{kMixedFunc};
/*! \brief Whether this function is packed function */
* \param i The output index.
* \return shape of i-th output.
*/
- virtual Array<Expr> output_shape(size_t i) const = 0;
+ virtual Array<PrimExpr> output_shape(size_t i) const = 0;
/*!
* \brief List all the input Tensors.
* \return List of input tensors.
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
- Array<Expr> shape;
+ Array<PrimExpr> shape;
/*! \brief The data type of the input. */
DataType dtype;
// override behavior.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
- Array<Expr> output_shape(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
v->Visit("dtype", &dtype);
}
static Operation make(std::string name,
- Array<Expr> shape,
+ Array<PrimExpr> shape,
DataType dtype);
static constexpr const char* _type_key = "PlaceholderOp";
Array<IterVar> reduce_axis;
// override functions
Array<IterVar> root_iter_vars() const final;
- Array<Expr> output_shape(size_t idx) const final;
+ Array<PrimExpr> output_shape(size_t idx) const final;
void GatherBound(
const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
public:
/*! \brief the compute expression */
- Array<Expr> body;
+ Array<PrimExpr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions
std::string tag,
Map<std::string, ObjectRef> attrs,
Array<IterVar> axis,
- Array<Expr> body);
+ Array<PrimExpr> body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
/*! \brief region of input tensors */
Array<Region> input_regions;
/*! \brief scalar expression inputs */
- Array<Expr> scalar_inputs;
+ Array<PrimExpr> scalar_inputs;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
- Array<Expr> scalar_inputs);
+ Array<PrimExpr> scalar_inputs);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
- Array<Expr> output_shape(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
- Array<Expr> output_shape(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
- Array<Expr> output_shape(size_t i) const final;
+ Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(
const Operation& self,
};
/*! \brief The compute function to specify the input source of a Tensor */
-using FCompute = std::function<Expr (const Array<Var>& i)>;
+using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
/*! \brief The compute function to specify the inputs source of Tensors */
-using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
+using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
/*!
* \brief create a place holder tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
-TVM_DLL Tensor placeholder(Array<Expr> shape,
+TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
DataType dtype = DataType::Float(32),
std::string name = "placeholder");
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
-TVM_DLL Tensor compute(Array<Expr> shape,
+TVM_DLL Tensor compute(Array<PrimExpr> shape,
FCompute fcompute,
std::string name = "tensor",
std::string tag = "",
* \param tag The optional tag of the tensor.
* \param attrs Optional additional attributes of the compute.
*/
-TVM_DLL Array<Tensor> compute(Array<Expr> shape,
+TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
FBatchCompute fcompute,
std::string name = "tensor",
std::string tag = "",
Map<std::string, ObjectRef> attrs = {});
// same as compute, specialized for different fcompute function
-inline Tensor compute(Array<Expr> shape,
- std::function<Expr(Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
return compute(shape, fc, name, tag, attrs);
}
-inline Tensor compute(Array<Expr> shape,
- std::function<Expr(Var, Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var, Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
return compute(shape, fc, name, tag, attrs);
}
-inline Tensor compute(Array<Expr> shape,
- std::function<Expr(Var, Var, Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var, Var, Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
return compute(shape, fc, name, tag, attrs);
}
-inline Tensor compute(Array<Expr> shape,
- std::function<Expr(Var, Var, Var, Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+ std::function<PrimExpr(Var, Var, Var, Var)> f,
std::string name = "tensor",
std::string tag = "",
Map<std::string, ObjectRef> attrs = {}) {
};
// extensions for tvm arg value
-inline TVMPODValue_::operator tvm::Expr() const {
- if (type_code_ == kNull) return Expr();
+inline TVMPODValue_::operator tvm::PrimExpr() const {
+ if (type_code_ == kNull) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
- return Expr(static_cast<int>(value_.v_int64));
+ return PrimExpr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kDLFloat) {
- return Expr(static_cast<float>(value_.v_float64));
+ return PrimExpr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
if (ptr->IsInstance<TensorNode>()) {
return Tensor(ObjectPtr<Object>(ptr))();
}
- CHECK(ObjectTypeChecker<Expr>::Check(ptr))
- << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
+ CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
+ << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
- return Expr(ObjectPtr<Object>(ptr));
+ return PrimExpr(ObjectPtr<Object>(ptr));
}
inline TVMPODValue_::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
- << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
+ << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Object>(ptr));
}
#include <tvm/attrs.h>
#include <tvm/relay/base.h>
+#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
/*!
* \brief Symbolic expression for tensor shape.
*/
-using IndexExpr = ::tvm::Expr;
+using IndexExpr = ::tvm::PrimExpr;
using SourceName = tvm::SourceName;
using Span = tvm::Span;
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
- tvm::Array<tvm::Expr> required_pass;
+ tvm::Array<tvm::PrimExpr> required_pass;
/*! \brief The list of disabled passes. */
- tvm::Array<tvm::Expr> disabled_pass;
+ tvm::Array<tvm::PrimExpr> disabled_pass;
PassContextNode() = default;
std::string name;
/*! \brief The passes that are required to perform the current pass. */
- tvm::Array<tvm::Expr> required;
+ tvm::Array<tvm::PrimExpr> required;
PassInfoNode() = default;
TVM_DLL static PassInfo make(int opt_level,
std::string name,
- tvm::Array<tvm::Expr> required);
+ tvm::Array<tvm::PrimExpr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode);
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::Expr>& required);
+ const tvm::Array<tvm::PrimExpr>& required);
/*
* \brief Create a function pass.
Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::Expr>& required);
+ const tvm::Array<tvm::PrimExpr>& required);
/*! \brief Remove expressions which does not effect the program result.
*
namespace tvm {
// forward declarations
class Integer;
-class Expr;
+class PrimExpr;
namespace runtime {
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// ObjectRef Specializations
- inline operator tvm::Expr() const;
+ inline operator tvm::PrimExpr() const;
inline operator tvm::Integer() const;
protected:
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
- using TVMPODValue_::operator tvm::Expr;
+ using TVMPODValue_::operator tvm::PrimExpr;
using TVMPODValue_::operator tvm::Integer;
// conversion operator.
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef;
- using TVMPODValue_::operator tvm::Expr;
+ using TVMPODValue_::operator tvm::PrimExpr;
using TVMPODValue_::operator tvm::Integer;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
* \param predicate The condition to be checked.
* \return reference to self.
*/
- TVM_DLL Stage& set_store_predicate(Expr predicate);
+ TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
/*!
* \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage.
* \param p_inner The result inner domain.
* \return reference to self.
*/
- TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Split the iteration with given number of parts.
*
* \param p_inner The result inner domain.
* \return reference to self.
*/
- TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
+ TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
* \return reference to self.
*/
TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
- Expr x_factor, Expr y_factor,
+ PrimExpr x_factor, PrimExpr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
/*!
*/
TVM_DLL Stage& pragma(IterVar var,
const std::string& pragma_type,
- const Expr& pragma_value = Expr()); // NOLINT(*)
+ const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
/*!
* \brief Fetch data in advance.
* \param domain the tensor to be prefetched
* \param offset the number of iterations be to fetched in advance
* \return reference to self
*/
- TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
+ TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
/*!
* \brief Set alignment requirement for specific dimension.
*
* Use this when there can be duplicated threads doing the same store.
* \note Experimental primitive: used by cross thread-reduction.
*/
- Expr store_predicate;
+ PrimExpr store_predicate;
/*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
/*! \brief List of tensor to be prefetched in this loop */
Array<Tensor> prefetch_data;
/*! \brief The offset used in each prefetch */
- Array<Expr> prefetch_offset;
+ Array<PrimExpr> prefetch_offset;
/*!
* \brief Tensor intrinsic used in tensorization,
* when the axis is marked as Tensorized
/*!
* \brief Additional pragma keys, array of StringImm
*/
- Array<Expr> pragma_keys;
+ Array<PrimExpr> pragma_keys;
/*!
* \brief Additional values of pragma, if any
*/
- Array<Expr> pragma_values;
+ Array<PrimExpr> pragma_values;
void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_type", &iter_type);
/*! \brief The inner domain */
IterVar inner;
/*! \brief The split factor */
- Expr factor;
+ PrimExpr factor;
/*! \brief Number of parts, only factor or nparts can be given */
- Expr nparts;
+ PrimExpr nparts;
void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
static IterVarRelation make(IterVar parent,
IterVar outer,
IterVar inner,
- Expr factor,
- Expr nparts);
+ PrimExpr factor,
+ PrimExpr nparts);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
* \brief head address of the buffer, if visible to CPU
* This address can be None.
*/
- Expr head_address;
+ PrimExpr head_address;
void VisitAttrs(AttrVisitor* v) {
v->Visit("unit_bits", &unit_bits);
* \return the result expression representing tensor read.
*/
template<typename... Args>
- inline Expr operator()(Args&& ...args) const {
- Array<Expr> indices{std::forward<Args>(args)...};
+ inline PrimExpr operator()(Args&& ...args) const {
+ Array<PrimExpr> indices{std::forward<Args>(args)...};
return operator()(indices);
}
/*!
* \param indices the indices.
* \return the result expression representing tensor read.
*/
- TVM_DLL Expr operator()(Array<Expr> indices) const;
+ TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \return the result expression representing tensor read.
*/
- TVM_DLL Expr operator()(Array<Var> indices) const;
+ TVM_DLL PrimExpr operator()(Array<Var> indices) const;
/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
class Slice {
public:
// construct via tensor and indices
- Slice(const Tensor& tensor, std::vector<Expr> indices)
+ Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
: tensor_(tensor), indices_(indices) {}
/*!
* \brief get i-th slice from the current slice.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
- inline Slice operator[](Expr i) {
- std::vector<Expr> other = indices_;
+ inline Slice operator[](PrimExpr i) {
+ std::vector<PrimExpr> other = indices_;
other.emplace_back(i);
return Slice(tensor_, other);
}
* This is only valid when all the coordinates are fully specified.
* \return the corresponding expression of this slice.
*/
- inline operator Expr() const {
+ inline operator PrimExpr() const {
return tensor_(indices_);
}
private:
const Tensor& tensor_;
- std::vector<Expr> indices_;
+ std::vector<PrimExpr> indices_;
};
/*!
* \brief get i-th slice from the current Tensor.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
- inline Slice operator[](Expr i) const {
+ inline Slice operator[](PrimExpr i) const {
return Slice(*this, {i});
}
/*! \brief specify container node */
class TensorNode : public Object {
public:
/*! \brief The shape of the tensor */
- Array<Expr> shape;
+ Array<PrimExpr> shape;
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief the source operation, can be None */
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
- TVM_DLL static Tensor make(Array<Expr> shape,
+ TVM_DLL static Tensor make(Array<PrimExpr> shape,
DataType dtype,
Operation op,
int value_index);
// macro to turn every operation of slice to expression
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
- inline Expr operator Op (const Tensor::Slice& a) { \
- return Op a.operator Expr() ; \
+ inline PrimExpr operator Op (const Tensor::Slice& a) { \
+ return Op a.operator PrimExpr() ; \
} \
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
template<typename T> \
- inline Expr operator Op (const Tensor::Slice& a, const T& b) { \
- return a.operator Expr() Op b; \
+ inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
+ return a.operator PrimExpr() Op b; \
} \
template<typename T> \
- inline Expr operator Op (const T& a, const Tensor::Slice& b) { \
- return a Op b.operator Expr(); \
- } \
- inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
- return a.operator Expr() Op b.operator Expr(); \
+ inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \
+ return a Op b.operator PrimExpr(); \
+ } \
+ inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
+ return a.operator PrimExpr() Op b.operator PrimExpr(); \
}
DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
Array<IterVar> reduce_axis;
/*! \brief scalar expression inputs */
- Array<Expr> scalar_inputs;
+ Array<PrimExpr> scalar_inputs;
void VisitAttrs(AttrVisitor* v) {
v->Visit("intrin", &intrin);
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis,
- Array<Expr> scalar_inputs);
+ Array<PrimExpr> scalar_inputs);
static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
tensor: Tensor
The created tensor
"""
- shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+ shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
dtype = float32 if dtype is None else dtype
return _api_internal._Placeholder(
shape, dtype, name)
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
- shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+ shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
ndim = len(shape)
if tag != "":
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
- shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
- if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
+ shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
+ if shape == () or isinstance(shape[0], (_expr.PrimExpr, _Integral)):
shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
for shp, dt in zip(shape, dtype):
output_placeholders.append(decl_buffer(shp, dt, name))
body = fcompute(input_placeholders, output_placeholders)
- if isinstance(body, _expr.Expr):
+ if isinstance(body, _expr.PrimExpr):
body = _make.Evaluate(body)
op = _api_internal._ExternOp(name, tag, attrs,
If user pass a fully generic symbolic array to the strides,
then the resulting function becomes fully generic.
"""
- shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
+ shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
result = fcombine(lhs, rhs)
id_elem = fidentity(*dtypes)
else:
- assert isinstance(expr, _expr.Expr)
+ assert isinstance(expr, _expr.PrimExpr)
size = 1
dtype = expr.dtype
lvar = var(code.co_varnames[0], dtype)
tensor: SparsePlaceholderOp
The created sparse tensor placeholder
"""
- shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+ shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
nonzeros = 0 if nonzeros is None else nonzeros
dtype = float32 if dtype is None else dtype
stype = 'csr' if stype is None else stype
return _make._OpNE(self.a, self.b)
-class Expr(ExprOp, NodeBase):
+class PrimExpr(ExprOp, NodeBase):
"""Base class of all tvm Expressions"""
# In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__
-class ConstExpr(Expr):
+class ConstExpr(PrimExpr):
pass
-class BinaryOpExpr(Expr):
+class BinaryOpExpr(PrimExpr):
pass
-class CmpExpr(Expr):
+class CmpExpr(PrimExpr):
pass
-class LogicalExpr(Expr):
+class LogicalExpr(PrimExpr):
pass
@register_node("Variable")
-class Var(Expr):
+class Var(PrimExpr):
"""Symbolic variable.
Parameters
@register_node
-class Reduce(Expr):
+class Reduce(PrimExpr):
"""Reduce node.
Parameters
@register_node
-class Cast(Expr):
+class Cast(PrimExpr):
"""Cast expression.
Parameters
@register_node
-class Select(Expr):
+class Select(PrimExpr):
"""Select node.
Note
@register_node
-class Load(Expr):
+class Load(PrimExpr):
"""Load node.
Parameters
@register_node
-class Ramp(Expr):
+class Ramp(PrimExpr):
"""Ramp node.
Parameters
@register_node
-class Broadcast(Expr):
+class Broadcast(PrimExpr):
"""Broadcast node.
Parameters
@register_node
-class Shuffle(Expr):
+class Shuffle(PrimExpr):
"""Shuffle node.
Parameters
@register_node
-class Call(Expr):
+class Call(PrimExpr):
"""Call node.
Parameters
@register_node
-class Let(Expr):
+class Let(PrimExpr):
"""Let node.
Parameters
"allocate's first argument should be a tuple of shape!")
shape = args[0]
for i in shape:
- _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
+ _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression")
if n > 1:
_internal_assert(isinstance(args[1], str),
"The data type should be an str")
def _cast(func_id, args):
- _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \
+ _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \
"Only one expression can be cast")
return _make.Cast(func_id, args[0])
def ceil_div(func_id, args):
_internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!")
_internal_assert(args.__len__() == 2, "2 arguments expected for division!")
- _internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
- _internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
+ _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div")
+ _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div")
a, b = args[0], args[1]
return (a + b - 1) // b
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0]
- if isinstance(rhs, _expr.Expr):
+ if isinstance(rhs, _expr.PrimExpr):
rhs = _ir_pass.Simplify(rhs)
if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later
load : Expr
The corresponding load expression.
"""
- begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
+ begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
dtype = dtype if dtype else self.dtype
return _api_internal._BufferVLoad(self, begin, dtype)
store : Stmt
The corresponding store stmt.
"""
- begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
+ begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
return _api_internal._BufferVStore(self, begin, value)
indices = convert_to_node(indices)
args = []
for x in indices:
- if isinstance(x, _expr.Expr):
+ if isinstance(x, _expr.PrimExpr):
args.append(x)
elif isinstance(x, iter_var_cls):
args.append(x.var)
else:
body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
scalar_params = []
- if isinstance(body, (_expr.Expr, _stmt.Stmt)):
+ if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
body = [body]
- body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
+ body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed([](
- Expr v, Expr cond,
+ PrimExpr v, PrimExpr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
) {
if (args[1].IsObjectRef<Range>()) {
self->Bind(args[0], args[1].operator Range());
} else {
- self->Bind(args[0], args[1].operator Expr());
+ self->Bind(args[0], args[1].operator PrimExpr());
}
});
} else if (name == "enter_constraint_context") {
TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([](
- VarExpr loop_var, Expr min, Expr extent,
+ Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
return ForNode::make(loop_var,
min,
TVM_REGISTER_GLOBAL("make.Store")
.set_body([](TVMArgs args, TVMRetValue *ret) {
- Expr value = args[1];
+ PrimExpr value = args[1];
if (args.size() == 3) {
*ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
} else {
TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed([](
DataType type, std::string name,
- Array<Expr> args, int call_type,
+ Array<PrimExpr> args, int call_type,
FunctionRef func, int value_index
) {
return CallNode::make(type,
// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed([](
- VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
+ Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
){
return AllocateNode::make(buffer_var, type, extents, condition, body);
});
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
- .set_body_typed([](Expr a, Expr b) { \
+ .set_body_typed([](PrimExpr a, PrimExpr b) { \
return (Func(a, b)); \
})
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
- *ret = (Func(args[0].operator int(), args[1].operator Expr())); \
+ *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
} else if (rhs_is_int) { \
- *ret = (Func(args[0].operator Expr(), args[1].operator int())); \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
} else { \
- *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
+ *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
} \
})
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("make._OpIfThenElse")
-.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) {
+.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value);
});
});
TVM_REGISTER_GLOBAL("_Placeholder")
-.set_body_typed([](Array<Expr> shape, DataType dtype, std::string name) {
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});
.set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("_StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, Expr factor) {
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, Expr nparts) {
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
- Expr x_factor, Expr y_factor
+ PrimExpr x_factor, PrimExpr y_factor
) {
IterVar x_outer, y_outer, x_inner, y_inner;
stage.tile(x_parent, y_parent,
}
} else {
if (args.size() > 1) {
- *ret = Simplify(args[0].operator Expr(), args[1]);
+ *ret = Simplify(args[0].operator PrimExpr(), args[1]);
} else {
- *ret = Simplify(args[0].operator Expr());
+ *ret = Simplify(args[0].operator PrimExpr());
}
}
});
}
} else {
if (args.size() > 1) {
- *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
+ *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]);
} else {
- *ret = CanonicalSimplify(args[0].operator Expr());
+ *ret = CanonicalSimplify(args[0].operator PrimExpr());
}
}
});
TVM_REGISTER_GLOBAL("ir_pass.Substitute")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<Stmt>()) {
- *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>());
+ *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>());
} else {
- *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>());
+ *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>());
}
});
if (args[0].IsObjectRef<Stmt>()) {
*ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
} else {
- *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
+ *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr());
}
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
- *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
+ *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
});
TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
- Array<Expr> padding;
+ Array<PrimExpr> padding;
TypedEnvFunc<int(int)> func;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
.describe("name");
TVM_ATTR_FIELD(padding)
.describe("padding of input")
- .set_default(Array<Expr>({0, 0}));
+ .set_default(Array<PrimExpr>({0, 0}));
TVM_ATTR_FIELD(func)
.describe("some random env function")
.set_default(TypedEnvFunc<int(int)>(nullptr));
int_set(this) {
}
-void Analyzer::Bind(const VarExpr& var, const Expr& expr) {
- Expr new_expr = expr;
+void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
+ PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);
this->canonical_simplify.Update(var, new_expr);
}
-void Analyzer::Bind(const VarExpr& var, const Range& range) {
+void Analyzer::Bind(const Var& var, const Range& range) {
CHECK(range.defined());
if (is_one(range->extent)) {
this->Bind(var, range->min);
exit_();
}
-bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
+bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImmNode>()) {
return ptr->value >= lower_bound;
}
return false;
}
-bool Analyzer::CanProve(const Expr& expr) {
+bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
return ptr->value != 0;
}
return false;
}
-Expr Analyzer::Simplify(const Expr& expr) {
+PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
if (is_const(expr)) return expr;
auto res = this->rewrite_simplify(expr);
if (is_const(res)) return res;
// from a expression.
class VariablePathFinder: public ExprVisitor {
public:
- explicit VariablePathFinder(Expr target) : target_(target) {}
+ explicit VariablePathFinder(PrimExpr target) : target_(target) {}
- void VisitExpr(const Expr& node) final {
+ void VisitExpr(const PrimExpr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
private:
bool found_{false};
- Expr target_;
+ PrimExpr target_;
std::unordered_set<const Object*> visited_;
};
// get the path to the variable,
// return empty vector to represent failure
-std::vector<const Object*> GetPath(Expr target, Expr expr) {
+std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
VariablePathFinder v(target);
v(expr);
return v.path_;
public:
friend class BoundDeduceInputChecker;
friend class Converter;
- BoundDeducer(Expr target, Expr expr,
+ BoundDeducer(PrimExpr target, PrimExpr expr,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
: target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
void Deduce();
- void VisitExpr(const Expr& e) final {
+ void VisitExpr(const PrimExpr& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
ExprVisitor::VisitExpr(e);
void VisitExpr_(const MulNode* op) final {
bool left = op->a.get() == path_[iter_];
- Expr operand = left ? op->b : op->a;
- Expr target_var = left ? op->a : op->b;
+ PrimExpr operand = left ? op->b : op->a;
+ PrimExpr target_var = left ? op->a : op->b;
SignType sign_operand;
if (operand.dtype().is_uint()) {
this->VisitExpr(left ? op->a : op->b);
}
- Expr result_;
+ PrimExpr result_;
CompareOp comp_op{kGreater};
bool success_{true};
void Transform();
void Relax();
CompareOp ReverseOp(CompareOp comp_op);
- Expr target_;
- Expr expr_;
+ PrimExpr target_;
+ PrimExpr expr_;
const std::unordered_map<const VarNode*, IntSet>& hint_map_;
const std::unordered_map<const VarNode*, IntSet>& relax_map_;
ExprIntSetMap expr_map_;
return target_count == 1;
}
- void VisitExpr(const Expr& e) final {
+ void VisitExpr(const PrimExpr& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
ExprVisitor::VisitExpr(e);
}
result_ = (comp_op == kGreater) ? b.max() : b.min();
}
-IntSet DeduceBound(Expr v, Expr e,
+IntSet DeduceBound(PrimExpr v, PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success_) return IntSet::nothing();
- Expr min = neg_inf(), max = pos_inf();
+ PrimExpr min = neg_inf(), max = pos_inf();
if (d.comp_op == kEqual) {
min = d.result_;
max = d.result_;
// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
-IntSet DeduceBound(Expr v, Expr e,
+IntSet DeduceBound(PrimExpr v, PrimExpr e,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map) {
std::unordered_map<const VarNode*, IntSet> hmap;
* \brief Base class of all temporary expression introduced
* for canonicalization.
*/
-class CanonicalExprNode : public BaseExprNode {
+class CanonicalExprNode : public PrimExprNode {
public:
virtual ~CanonicalExprNode() {}
/*!
* \note Can mutate the internal data structure.
* \return The normal expression.
*/
- virtual Expr Normalize() const = 0;
+ virtual PrimExpr Normalize() const = 0;
// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
}
static constexpr const char* _type_key = "arith.CanonicalExpr";
- TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode);
+ TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};
enum DivMode {
kFloorDiv
};
-inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
if (mode == kTruncDiv) {
return truncmod(a, b);
} else {
}
}
-inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
if (mode == kTruncDiv) {
return truncdiv(a, b);
} else {
class SplitExprNode : public CanonicalExprNode {
public:
/*! \brief The base index expression. */
- Expr index;
+ PrimExpr index;
/*! \brief The division factor ratio. */
int64_t lower_factor{1};
/*!
CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0);
}
- Expr NormalizeWithScale(int64_t sscale) const {
- Expr res = this->index;
+ PrimExpr NormalizeWithScale(int64_t sscale) const {
+ PrimExpr res = this->index;
DataType dtype = this->dtype;
if (this->scale == 0) {
return make_const(dtype, 0);
return res;
}
- Expr Normalize() const final {
+ PrimExpr Normalize() const final {
return NormalizeWithScale(1);
}
TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode);
};
-class SplitExpr : public Expr {
+class SplitExpr : public PrimExpr {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode);
+ TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};
* \brief Return the normal Expr that is equivalent to self.
* \return The normal expression.
*/
- Expr Normalize() const final {
+ PrimExpr Normalize() const final {
// quick path 1.
if (this->args.size() == 0) {
return make_const(this->dtype, this->base);
std::stable_sort(args.begin(), args.end(), fcompare);
return args;
}
- static Expr Normalize_(DataType dtype,
+ static PrimExpr Normalize_(DataType dtype,
const std::vector<SplitExpr>& args,
int64_t base) {
// Positive scales first
- Expr res = make_const(dtype, 0);
+ PrimExpr res = make_const(dtype, 0);
for (size_t i = 0; i < args.size(); ++i) {
if (args[i]->scale > 0) {
res = res + args[i]->Normalize();
}
};
-class SumExpr : public Expr {
+class SumExpr : public PrimExpr {
public:
- TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode);
+ TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};
: Rewriter(parent) {}
- Expr CanonicalSimplify(Expr expr) {
+ PrimExpr CanonicalSimplify(PrimExpr expr) {
expr = operator()(expr);
return expr;
}
// override the original mutate function.
- Expr VisitExpr(const Expr& input_expr) final {
+ PrimExpr VisitExpr(const PrimExpr& input_expr) final {
auto expr = Rewriter::VisitExpr(input_expr);
return Normalize(expr);
}
// Normal mutation without normalization.
- Expr CanonicalMutate(Expr expr) {
+ PrimExpr CanonicalMutate(PrimExpr expr) {
return Rewriter::VisitExpr(expr);
}
using Rewriter::VisitExpr_;
- Expr VisitExpr_(const AddNode* op) final;
- Expr VisitExpr_(const SubNode* op) final;
- Expr VisitExpr_(const MulNode* op) final;
- Expr VisitExpr_(const DivNode* op) final;
- Expr VisitExpr_(const ModNode* op) final;
- Expr VisitExpr_(const FloorDivNode* op) final;
- Expr VisitExpr_(const FloorModNode* op) final;
- Expr VisitExpr_(const ReduceNode* op) final;
+ PrimExpr VisitExpr_(const AddNode* op) final;
+ PrimExpr VisitExpr_(const SubNode* op) final;
+ PrimExpr VisitExpr_(const MulNode* op) final;
+ PrimExpr VisitExpr_(const DivNode* op) final;
+ PrimExpr VisitExpr_(const ModNode* op) final;
+ PrimExpr VisitExpr_(const FloorDivNode* op) final;
+ PrimExpr VisitExpr_(const FloorModNode* op) final;
+ PrimExpr VisitExpr_(const ReduceNode* op) final;
private:
/*!
* \param expr The input expression.
* \return Normalized expr.
*/
- Expr Normalize(Expr expr) {
+ PrimExpr Normalize(PrimExpr expr) {
if (const auto* op = expr.as<CanonicalExprNode>()) {
return op->Normalize();
} else {
* \param expr The input expr.
* \return The transformed SplitExpr.
*/
- SplitExpr ToSplitExpr(Expr expr) {
+ SplitExpr ToSplitExpr(PrimExpr expr) {
if (const auto* op = expr.as<SplitExprNode>()) {
return GetRef<SplitExpr>(op);
}
* \param expr The input expr.
* \return The transformed SumExpr.
*/
- SumExpr ToSumExpr(Expr expr) {
+ SumExpr ToSumExpr(PrimExpr expr) {
if (const auto* op = expr.as<SumExprNode>()) {
return GetRef<SumExpr>(op);
}
}
}
// Simplify the combiner used in reduce.
- Expr SimplifyReduceCombiner(const ReduceNode* op);
+ PrimExpr SimplifyReduceCombiner(const ReduceNode* op);
};
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const AddNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<AddNode>(a, b);
+ PrimExpr const_res = TryConstFold<AddNode>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
return std::move(ret);
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const SubNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<SubNode>(a, b);
+ PrimExpr const_res = TryConstFold<SubNode>(a, b);
if (const_res.defined()) return const_res;
// canonical form simplification.
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const MulNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<MulNode>(a, b);
+ PrimExpr const_res = TryConstFold<MulNode>(a, b);
if (const_res.defined()) return const_res;
// x * c
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return MulNode::make(a, b);
}
return lhs;
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const DivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<DivNode>(a, b);
+ PrimExpr const_res = TryConstFold<DivNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
lhs.CopyOnWrite()->DivideBy(cval);
- Expr temp = Normalize(extra);
+ PrimExpr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return DivNode::make(a, b);
}
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const FloorDivNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<FloorDivNode>(a, b);
+ PrimExpr const_res = TryConstFold<FloorDivNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// x / c1
}
// continue simplification.
lhs.CopyOnWrite()->DivideBy(cval);
- Expr temp = Normalize(extra);
+ PrimExpr temp = Normalize(extra);
if (const auto* pconst = temp.as<IntImmNode>()) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return FloorDivNode::make(a, b);
}
return lhs;
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const ModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<ModNode>(a, b);
+ PrimExpr const_res = TryConstFold<ModNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
// both lhs and extra are non-negative
if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
- Expr temp = Normalize(extra);
+ PrimExpr temp = Normalize(extra);
if (temp.as<IntImmNode>()) {
return truncmod(temp, c1.Eval());
} else {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return ModNode::make(a, b);
}
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const FloorModNode* op) {
if (!IsIndexType(op->dtype)) {
return Rewriter::VisitExpr_(op);
}
// normalize
- Expr a = this->CanonicalMutate(op->a);
- Expr b = this->CanonicalMutate(op->b);
+ PrimExpr a = this->CanonicalMutate(op->a);
+ PrimExpr b = this->CanonicalMutate(op->b);
// const folding
- Expr const_res = TryConstFold<FloorModNode>(a, b);
+ PrimExpr const_res = TryConstFold<FloorModNode>(a, b);
if (const_res.defined()) return const_res;
PVar<Integer> c1;
if (const auto* psum = a.as<SumExprNode>()) {
SumExpr lhs, extra;
SeparateDivisibleParts(psum, cval, &lhs, &extra);
- Expr temp = Normalize(extra);
+ PrimExpr temp = Normalize(extra);
if (temp.as<IntImmNode>()) {
return floormod(temp, c1.Eval());
} else {
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return FloorModNode::make(a, b);
}
}
// Simplify reduce expression.
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
SimplifyReduceCombiner(const ReduceNode* op) {
// First simplify the results
- Array<Expr> simplified_result;
+ Array<PrimExpr> simplified_result;
for (const auto& res : op->combiner->result) {
- Expr new_res = this->VisitExpr(res);
+ PrimExpr new_res = this->VisitExpr(res);
simplified_result.push_back(new_res);
}
}
int new_value_index = op->value_index;
- Array<Expr> new_result;
- Array<Expr> new_identity;
+ Array<PrimExpr> new_result;
+ Array<PrimExpr> new_identity;
Array<Var> new_lhs;
Array<Var> new_rhs;
- Array<Expr> new_source;
+ Array<PrimExpr> new_source;
// new stuff is old stuff which is used
for (size_t i = 0; i < used.size(); ++i) {
new_combiner, new_source, op->axis, op->condition, new_value_index);
}
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
VisitExpr_(const ReduceNode* op) {
// Recursively call simplification when necessary.
- Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
+ PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op);
op = ret.as<ReduceNode>();
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
return ret;
}
-Expr CanonicalSimplifier::operator()(const Expr& expr) {
+PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
return impl_->CanonicalSimplify(expr);
}
void CanonicalSimplifier::Update(const Var& var,
- const Expr& info,
+ const PrimExpr& info,
bool override) {
impl_->Update(var, info, override);
}
* \return The result.
*/
template<typename OP>
-inline Expr Compute(Expr lhs, Expr rhs) {
+inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) {
return OP::make(lhs, rhs);
}
* \return The result.
*/
template<typename Op>
-inline Expr ComputeReduce(
- const Array<Expr>& values, Expr empty_value);
+inline PrimExpr ComputeReduce(
+ const Array<PrimExpr>& values, PrimExpr empty_value);
-inline bool GetConst(Expr e, int64_t* out) {
+inline bool GetConst(PrimExpr e, int64_t* out) {
if (e.dtype().is_vector()) return false;
const int64_t* v = as_const_int(e);
if (v) {
}
// get a small constant int
-inline bool GetConstInt(Expr e, int* out) {
+inline bool GetConstInt(PrimExpr e, int* out) {
int64_t v1 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
}
template<>
-inline Expr Compute<ir::AddNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b;
}
template<>
-inline Expr Compute<ir::SubNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::SubNode>(PrimExpr a, PrimExpr b) {
return a - b;
}
template<>
-inline Expr Compute<ir::MulNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::MulNode>(PrimExpr a, PrimExpr b) {
return a * b;
}
template<>
-inline Expr Compute<ir::DivNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::DivNode>(PrimExpr a, PrimExpr b) {
return truncdiv(a, b);
}
template<>
-inline Expr Compute<ir::ModNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::ModNode>(PrimExpr a, PrimExpr b) {
return truncmod(a, b);
}
template<>
-inline Expr Compute<ir::MaxNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::MaxNode>(PrimExpr a, PrimExpr b) {
return max(a, b);
}
template<>
-inline Expr Compute<ir::MinNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::MinNode>(PrimExpr a, PrimExpr b) {
return min(a, b);
}
template<typename Op>
-inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
+inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
if (values.size() == 0U) {
CHECK(empty_value.defined());
return empty_value;
}
- Expr res = values[0];
+ PrimExpr res = values[0];
for (size_t i = 1; i < values.size(); ++i) {
res = Compute<Op>(res, values[i]);
}
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
-inline Expr TryConstFold(Expr a, Expr b) {
- return Expr();
+inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) {
+ return PrimExpr();
}
/*!
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template<typename Op>
-inline Expr TryConstFold(Expr a);
+inline PrimExpr TryConstFold(PrimExpr a);
/*!
* \brief Check whether type is used to represent index.
// specialization of constant folders.
template<>
-inline Expr TryConstFold<ir::AddNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value);
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::SubNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value);
if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return a;
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::MulNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value);
if (fb->value == 0) return b;
}
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::DivNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::ModNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::FloorDivNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(fb->value, 0) << "Divide by zero";
}
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::FloorModNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::FloorModNode>(PrimExpr a, PrimExpr b) {
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
}
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::MinNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::MaxNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::GTNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::GENode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::LTNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::LENode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::EQNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::NENode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
});
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::AndNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
using ir::UIntImmNode;
const UIntImmNode* pa = a.as<UIntImmNode>();
const UIntImmNode* pb = b.as<UIntImmNode>();
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
if (pb && !pb->value) return b;
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::OrNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
using ir::UIntImmNode;
const UIntImmNode* pa = a.as<UIntImmNode>();
const UIntImmNode* pb = b.as<UIntImmNode>();
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
if (pb && !pb->value) return a;
- return Expr();
+ return PrimExpr();
}
template<>
-inline Expr TryConstFold<ir::NotNode>(Expr a) {
+inline PrimExpr TryConstFold<ir::NotNode>(PrimExpr a) {
using ir::UIntImmNode;
const UIntImmNode* pa = a.as<UIntImmNode>();
if (pa) {
return UIntImmNode::make(DataType::UInt(1), !(pa->value));
}
- return Expr();
+ return PrimExpr();
}
/*! \brief Helper namespace for symbolic value limits */
struct SymbolicLimits {
/*! \brief positive infinity */
- static Expr pos_inf_;
+ static PrimExpr pos_inf_;
/*! \brief negative infinity */
- static Expr neg_inf_;
+ static PrimExpr neg_inf_;
};
/*!
*
* \return positive infinity.
*/
-inline Expr pos_inf() {
+inline PrimExpr pos_inf() {
return SymbolicLimits::pos_inf_;
}
*
* \return The check result.
*/
-inline bool is_pos_inf(const Expr& value) {
+inline bool is_pos_inf(const PrimExpr& value) {
return value.same_as(SymbolicLimits::pos_inf_);
}
*
* \return negative infinity.
*/
-inline Expr neg_inf() {
+inline PrimExpr neg_inf() {
return SymbolicLimits::neg_inf_;
}
*
* \return The check result.
*/
-inline bool is_neg_inf(const Expr& value) {
+inline bool is_neg_inf(const PrimExpr& value) {
return value.same_as(SymbolicLimits::neg_inf_);
}
};
class ConstIntBoundAnalyzer::Impl :
- public ExprFunctor<ConstIntBoundAnalyzer::Entry(const Expr&)> {
+ public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
public:
/*! \brief additional bound info about expr \in bound */
struct BoundInfo {
/*! \brief The expr */
- Expr expr;
+ PrimExpr expr;
/*! \brief The additional bound */
Entry bound;
BoundInfo() {}
- BoundInfo(Expr expr, Entry bound)
+ BoundInfo(PrimExpr expr, Entry bound)
: expr(expr), bound(bound) {
}
};
// Override visitor behaviors
Entry VisitExprDefault_(const Object* op) final {
return Everything(
- static_cast<const ExprNode*>(op)->dtype);
+ static_cast<const PrimExprNode*>(op)->dtype);
}
- Entry VisitExpr(const Expr& expr) final {
+ Entry VisitExpr(const PrimExpr& expr) final {
Entry res = ExprFunctor::VisitExpr(expr);
// a linear search over additional info
// assume we won't have a lot of conditions
}
}
- std::function<void()> EnterConstraint(const Expr& constraint) {
+ std::function<void()> EnterConstraint(const PrimExpr& constraint) {
std::vector<BoundInfo> info = DetectBoundInfo(constraint);
if (info.size() == 0) return nullptr;
size_t old_size = additional_info_.size();
private:
// internal variable map
- std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
+ std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// constants: the limit value means umlimited
* \param cond The constraint condition.
* \return List of detected bounds.
*/
- static std::vector<BoundInfo> DetectBoundInfo(const Expr& cond) {
- PVar<Expr> x, y;
+ static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
+ PVar<PrimExpr> x, y;
PVar<Integer> c;
// NOTE: canonical form always use <= or <
if ((c <= x).Match(cond)) {
}
};
-ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
+ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
Entry ret = impl_->VisitExpr(expr);
return ConstIntBound(ret.min_value, ret.max_value);
}
impl_->Bind(var, range);
}
-std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) {
+std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
return impl_->EnterConstraint(constraint);
}
// Linear equation, the components can be undefined.
struct LinearEqEntry {
- Expr base;
- Expr coeff;
+ PrimExpr base;
+ PrimExpr coeff;
};
struct IntervalEntry {
- Expr min_value;
- Expr max_value;
+ PrimExpr min_value;
+ PrimExpr max_value;
};
class LinearEqDetector
- : public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
+ : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr &)> {
public:
explicit LinearEqDetector(Var var)
: var_(var) {}
- bool Detect(const Expr& e, LinearEqEntry* ret) {
+ bool Detect(const PrimExpr& e, LinearEqEntry* ret) {
*ret = VisitExpr(e, e);
if (fail_) return false;
if (!ret->base.defined()) {
return true;
}
- LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const AddNode* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
return ret;
}
- LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const SubNode* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
return ret;
}
- LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const MulNode* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
LinearEqEntry a = VisitExpr(op->a, op->a);
LinearEqEntry b = VisitExpr(op->b, op->b);
ret.coeff = MulCombine(a.base, b.coeff);
return ret;
}
- LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final {
+ LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final {
LinearEqEntry ret;
if (op == var_.get()) {
ret.coeff = make_const(op->dtype, 1);
}
return ret;
}
- LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final {
+ LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
fail_ = true;
Var var_;
bool fail_{false};
// Combine by add
- Expr AddCombine(Expr a, Expr b) {
+ PrimExpr AddCombine(PrimExpr a, PrimExpr b) {
if (!a.defined()) return b;
if (!b.defined()) return a;
return a + b;
}
- Expr SubCombine(Expr a, Expr b) {
+ PrimExpr SubCombine(PrimExpr a, PrimExpr b) {
// Check b first in case they are both undefined
if (!b.defined()) return a;
if (!a.defined()) return -b;
return a - b;
}
- Expr MulCombine(Expr a, Expr b) {
+ PrimExpr MulCombine(PrimExpr a, PrimExpr b) {
if (!a.defined()) return a;
if (!b.defined()) return b;
return a * b;
}
};
-Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
- Expr base = e;
- Array<Expr> coeff;
+Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
+ const Array<Var>& vars) {
+ PrimExpr base = e;
+ Array<PrimExpr> coeff;
for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
- return Array<Expr>();
+ return Array<PrimExpr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
- return Array<Expr>();
+ return Array<PrimExpr>();
}
}
coeff.push_back(base);
// Detect clip condition as min max value
bool DetectClipBound(
- const Expr& cond,
+ const PrimExpr& cond,
std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
int flag = 0;
Var var;
PostOrderVisit(cond, fvisit);
if (flag != 1) return false;
// canonical form: exp >= 0
- Expr canonical;
+ PrimExpr canonical;
if (const LTNode* op = cond.as<LTNode>()) {
if (!op->a.dtype().is_int()) return false;
canonical = op->b - op->a - make_const(op->a.dtype(), 1);
template<typename OP>
-void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
+void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) {
if (const OP* op = e.as<OP>()) {
SplitCommExpr<OP>(op->a, ret);
SplitCommExpr<OP>(op->b, ret);
// Detect the lower and upper bound from the expression.
// e must be connected by and.
-Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
- std::vector<Expr> splits;
+Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
+ std::vector<PrimExpr> splits;
SplitCommExpr<ir::AndNode>(e, &splits);
std::unordered_map<const VarNode*, IntervalEntry> rmap;
for (Var v : vars) {
rmap[v.get()] = IntervalEntry();
}
- for (Expr cond : splits) {
- if (!DetectClipBound(cond, &rmap)) return Array<Expr>();
+ for (PrimExpr cond : splits) {
+ if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
}
- Array<Expr> ret;
+ Array<PrimExpr> ret;
for (Var v : vars) {
IntervalEntry e = rmap[v.get()];
if (e.min_value.defined()) {
}
private:
- void Touch(const Array<Expr>& args) {
+ void Touch(const Array<PrimExpr>& args) {
if (args.size() > bounds_.size()) {
bounds_.resize(args.size());
}
namespace tvm {
namespace arith {
-Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
-Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
+PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
+PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
-IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
+IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) {
auto node = make_object<IntervalSetNode>();
node->min_value = std::move(min_value);
node->max_value = std::move(max_value);
data_ = std::move(node);
}
-IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
+IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) {
return IntervalSet(min_value, max_value);
}
IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
- Expr max_value = min(a->max_value, b->max_value);
- Expr min_value = max(a->min_value, b->min_value);
+ PrimExpr max_value = min(a->max_value, b->max_value);
+ PrimExpr min_value = max(a->min_value, b->min_value);
if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) &&
(min_value.dtype().is_int() || min_value.dtype().is_uint()) &&
analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
}
IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
- Expr max_value = max(a->max_value, b->max_value);
- Expr min_value = min(a->min_value, b->min_value);
+ PrimExpr max_value = max(a->max_value, b->max_value);
+ PrimExpr min_value = min(a->min_value, b->min_value);
return IntervalSet(min_value, max_value);
}
IntervalSet a,
IntervalSet b) {
if (a->IsSinglePoint() && b->IsSinglePoint()) {
- Expr res = TryConstFold<Op>(a->min_value, b->min_value);
+ PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
if (!res.defined()) res = Op::make(a->min_value, b->min_value);
return IntervalSet::SinglePoint(res);
}
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
- Expr min_value =
+ PrimExpr min_value =
a->HasLowerBound() && b->HasLowerBound() ?
a->min_value + b->min_value : neg_inf();
- Expr max_value =
+ PrimExpr max_value =
a->HasUpperBound() && b->HasUpperBound() ?
a->max_value + b->max_value : pos_inf();
return IntervalSet(min_value, max_value);
}
if (a->IsEmpty()) return a;
if (b->IsEmpty()) return b;
- Expr min_value =
+ PrimExpr min_value =
a->HasLowerBound() && b->HasUpperBound() ?
a->min_value - b->max_value : neg_inf();
- Expr max_value =
+ PrimExpr max_value =
a->HasUpperBound() && b->HasLowerBound() ?
a->max_value - b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
if (is_zero(b->min_value)) return b;
if (is_one(b->min_value)) return a;
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
- Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
- Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
+ PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
+ PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
- Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
- Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
+ PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
+ PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::SelectNode;
- Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
- Expr e1 = a->min_value * b->min_value;
- Expr e2 = a->max_value * b->min_value;
+ PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
+ PrimExpr e1 = a->min_value * b->min_value;
+ PrimExpr e2 = a->max_value * b->min_value;
return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
}
}
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
- Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
- Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
+ PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
+ PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
- Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
- Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
+ PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
+ PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::SelectNode;
- Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
- Expr e1 = a->min_value / b->min_value;
- Expr e2 = a->max_value / b->min_value;
+ PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
+ PrimExpr e1 = a->min_value / b->min_value;
+ PrimExpr e2 = a->max_value / b->min_value;
return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
}
}
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
- const Expr& divisor = b->min_value;
+ const PrimExpr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
- Expr bound = abs(divisor) - 1;
+ PrimExpr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
}
}
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
- Expr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
- Expr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
+ PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
+ PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
- Expr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
- Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
+ PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
+ PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
return IntervalSet(min_value, max_value);
} else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::SelectNode;
- Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
- Expr e1 = floordiv(a->min_value, b->min_value);
- Expr e2 = floordiv(a->max_value, b->min_value);
+ PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
+ PrimExpr e1 = floordiv(a->min_value, b->min_value);
+ PrimExpr e2 = floordiv(a->max_value, b->min_value);
return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
}
}
if (b->IsEmpty()) return b;
if (b->IsSinglePoint()) {
- const Expr& divisor = b->min_value;
+ const PrimExpr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
- Expr bound = abs(divisor) - 1;
+ PrimExpr bound = abs(divisor) - 1;
return IntervalSet(-bound, bound);
}
}
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
class IntervalSetEvaluator :
- public ExprFunctor<IntervalSet(const Expr&)> {
+ public ExprFunctor<IntervalSet(const PrimExpr&)> {
public:
IntervalSetEvaluator(Analyzer* analyzer,
const Map<Var, IntSet>& dom_map,
eval_vec_(eval_vec) {
}
- IntervalSet Eval(const Expr& val) {
+ IntervalSet Eval(const PrimExpr& val) {
return this->VisitExpr(val);
}
// evaluate and relax the set
}
IntervalSet VisitExpr_(const IntImmNode* op) final {
- return IntervalSet::SinglePoint(GetRef<Expr>(op));
+ return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
}
IntervalSet VisitExpr_(const UIntImmNode* op) final {
- return IntervalSet::SinglePoint(GetRef<Expr>(op));
+ return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
}
IntervalSet VisitExpr_(const VarNode* op) final {
IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
}
}
- DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<Expr>(op);
+ DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op);
return IntervalSet::Everything();
}
private:
// whether set is exactly single point that equals value.
bool MatchPoint(const IntervalSet& set,
- const Expr& value) const {
+ const PrimExpr& value) const {
return set->min_value.same_as(value) && set->max_value.same_as(value);
}
IntervalSet a = this->Eval(op->a);
IntervalSet b = this->Eval(op->b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
- return IntervalSet::SinglePoint(GetRef<Expr>(op));
+ return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
}
return Combine<T>(analyzer_, a, b);
}
: analyzer_(analyzer) {
}
- IntSet Eval(const Expr& expr, const Map<Var, IntSet>& dom_map) const {
+ IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
}
delete impl_;
}
-IntSet IntSetAnalyzer::operator()(const Expr& expr,
+IntSet IntSetAnalyzer::operator()(const PrimExpr& expr,
const Map<Var, IntSet>& dom_map) {
return impl_->Eval(expr, dom_map);
}
return max_range;
}
-Expr IntSet::min() const {
+PrimExpr IntSet::min() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int);
return s_int->min_value;
}
-Expr IntSet::max() const {
+PrimExpr IntSet::max() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int);
return s_int->max_value;
return kUnknown;
}
}
-Expr IntSet::point_value() const {
+PrimExpr IntSet::point_value() const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
CHECK(s_int && s_int->IsSinglePoint());
return s_int->min_value;
return IntervalSet::Everything();
}
-IntSet IntSet::single_point(Expr x) {
+IntSet IntSet::single_point(PrimExpr x) {
return IntervalSet::SinglePoint(x);
}
-IntSet IntSet::interval(Expr min, Expr max) {
+IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
if (min.same_as(max)) {
return IntSet::single_point(min);
}
}
// Range related code
-inline bool ProveEqual(Expr lhs, Expr rhs) {
+inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
return dmap;
}
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
const Map<Var, IntSet>& dom_map) {
Analyzer ana;
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
}
-IntSet IntSet::vector(Expr x) {
+IntSet IntSet::vector(PrimExpr x) {
Analyzer ana;
Map<Var, IntSet> dmap;
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
}
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map) {
return EvalSet(e, ConvertDomMap(dom_map));
}
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map) {
return EvalSet(e, ConvertDomMap(dom_map));
}
Analyzer ana;
IntervalSetEvaluator m(&ana, dom_map);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
- Expr sum = r->min + r->extent - 1;
+ PrimExpr sum = r->min + r->extent - 1;
auto res = m.Eval(IntervalSet(r->min, Simplify(sum)));
return std::move(res);
}
auto dmap = ConvertDomMap(dom_map);
IntervalSetEvaluator m(&ana, dmap);
const IntervalSetNode* s_int = s.as<IntervalSetNode>();
- Expr vmax = s_int->HasUpperBound() ?
+ PrimExpr vmax = s_int->HasUpperBound() ?
m.Eval(s_int->max_value).max() : s_int->max_value;
- Expr vmin = s_int->HasLowerBound() ?
+ PrimExpr vmin = s_int->HasLowerBound() ?
m.Eval(s_int->min_value).min() : s_int->min_value;
return IntervalSet(vmin, vmax);
}
const Map<Var, IntSet>& dom_map)
: IntervalSetEvaluator(analyzer, dom_map) {}
- IntervalSet VisitExpr(const Expr& n) final {
+ IntervalSet VisitExpr(const PrimExpr& n) final {
IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
expr_map[n] = ret;
return ret;
};
ExprIntSetMap EvalSetForEachSubExpr(
- Expr e,
+ PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map) {
Analyzer ana;
auto dmap = ConvertDomMap(dom_map);
class IntervalSetNode : public IntSetNode {
public:
/*! \brief Minimum value in the interval. */
- Expr min_value;
+ PrimExpr min_value;
/*! \brief Maximum value in the interval. */
- Expr max_value;
+ PrimExpr max_value;
// visitor overload.
void VisitAttrs(tvm::AttrVisitor* v) {
* \param max_value The maximum value in the interval.
* \return The created set.
*/
- TVM_DLL IntervalSet(Expr min_value, Expr max_value);
+ TVM_DLL IntervalSet(PrimExpr min_value, PrimExpr max_value);
/*!
* \brief Create an IntervalSet that represents a single point.
* \param value The value to be represented.
* \return The result set.
*/
- static IntervalSet SinglePoint(Expr value) {
+ static IntervalSet SinglePoint(PrimExpr value) {
return IntervalSet(value, value);
}
/*!
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const LetStmtNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const IfThenElseNode* op) {
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, condition);
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AssertStmtNode* op) {
- Expr condition = this->VisitExpr(op->condition);
- Expr message = this->VisitExpr(op->message);
+ PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr message = this->VisitExpr(op->message);
With<ConstraintContext> ctx(analyzer_, condition);
Stmt body = this->VisitStmt(op->body);
}
}
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
- Expr cond = this->VisitExpr(op->args[0]);
- Expr true_value, false_value;
+ PrimExpr cond = this->VisitExpr(op->args[0]);
+ PrimExpr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = this->VisitExpr(op->args[1]);
if (cond.same_as(op->args[0]) &&
true_value.same_as(op->args[1]) &&
false_value.same_as(op->args[2])) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return CallNode::make(op->dtype, op->name,
{cond, true_value, false_value},
return StmtExprMutator::VisitExpr_(op);
}
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
VisitExpr_(const LetNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
// as sub-class may or maynot choose to replace it.
- Expr body = this->VisitExpr(op->body);
+ PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
}
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
VisitExpr_(const SelectNode* op) {
- Expr cond = this->VisitExpr(op->condition);
- Expr true_value, false_value;
+ PrimExpr cond = this->VisitExpr(op->condition);
+ PrimExpr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = VisitExpr(op->true_value);
if (cond.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return SelectNode::make(cond, true_value, false_value);
}
}
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
Stmt VisitStmt_(const ir::IfThenElseNode* op) override;
Stmt VisitStmt_(const ir::AttrStmtNode* op) override;
Stmt VisitStmt_(const ir::AssertStmtNode* op) override;
- Expr VisitExpr_(const ir::LetNode* op) override;
- Expr VisitExpr_(const ir::SelectNode* op) override;
- Expr VisitExpr_(const ir::CallNode* op) override;
- Expr VisitExpr_(const ir::ReduceNode* op) override;
+ PrimExpr VisitExpr_(const ir::LetNode* op) override;
+ PrimExpr VisitExpr_(const ir::SelectNode* op) override;
+ PrimExpr VisitExpr_(const ir::CallNode* op) override;
+ PrimExpr VisitExpr_(const ir::ReduceNode* op) override;
protected:
/*! \brief internal analyzer field. */
class IRVisitorWithAnalyzer final : public StmtExprVisitor {
public:
- Expr Simplify(const Expr& expr) {
+ PrimExpr Simplify(const PrimExpr& expr) {
return analyzer_.Simplify(expr);
}
};
class ModularSetAnalyzer::Impl :
- public ExprFunctor<ModularSetAnalyzer::Entry(const Expr&)> {
+ public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
public:
explicit Impl(Analyzer* parent)
: parent_(parent) {}
}
// Detect useful constraints and use them in the analysis scope.
- std::function<void()> EnterConstraint(const Expr& constraint) {
+ std::function<void()> EnterConstraint(const PrimExpr& constraint) {
PVar<Var> var;
PVar<Integer> coeff, base;
// pattern match interesting constraints
return Entry(coeff, a.base * b.base);
}
- Entry DivByConst(const Expr& lhs,
+ Entry DivByConst(const PrimExpr& lhs,
int64_t val,
bool round_down) {
Entry a = VisitExpr(lhs);
/*! \brief pointer to parent. */
Analyzer* parent_{nullptr};
// internal variable map
- std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
+ std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
/*!
* \brief Update var by intersecting entry with var's current set.
* \param var The variable.
}
};
-ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
+ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
Entry ret = impl_->VisitExpr(expr);
return ModularSet(ret.coeff, ret.base);
}
impl_->Update(var, info, override);
}
-std::function<void()> ModularSetAnalyzer::EnterConstraint(const Expr& constraint) {
+std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
return impl_->EnterConstraint(constraint);
}
};
template<>
-class PEqualChecker<Expr> {
+class PEqualChecker<PrimExpr> {
public:
- bool operator()(const Expr& lhs, const Expr& rhs) const {
+ bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs.same_as(rhs)) return true;
return ir::Equal(lhs, rhs);
}
}
}
- Expr Eval() const {
- Expr lhs = a_.Eval();
- Expr rhs = b_.Eval();
- Expr ret = TryConstFold<NodeType>(lhs, rhs);
+ PrimExpr Eval() const {
+ PrimExpr lhs = a_.Eval();
+ PrimExpr rhs = b_.Eval();
+ PrimExpr ret = TryConstFold<NodeType>(lhs, rhs);
if (ret.defined()) return ret;
return NodeType::make(lhs, rhs);
}
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
return make_const(ref_.Eval().dtype(), value_);
}
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
return ir::NotNode::make(value_.Eval());
}
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
return ir::SelectNode::make(
condition_.Eval(), true_value_.Eval(), false_value_.Eval());
}
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
return ir::CastNode::make(dtype_.Eval(), value_.Eval());
}
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
}
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
}
};
struct PCallExprEvalArgsFunctor {
- Array<Expr> args_;
+ Array<PrimExpr> args_;
template<typename T>
void operator()(size_t i, const T& pattern) {
}
}
- Expr Eval() const {
+ PrimExpr Eval() const {
detail::PCallExprEvalArgsFunctor feval_args;
detail::tuple_for_each(feval_args, args_);
return Op::Eval(feval_args.args_);
// arithemetic intrinsics
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
- static Expr Eval(Array<Expr> args) { \
+ static PrimExpr Eval(Array<PrimExpr> args) { \
return ir::CallNode::make(args[0].dtype(), kName, args, \
ir::CallNode::PureIntrinsic); \
} \
// unary intrinsics
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
struct OpName { \
- static Expr Eval(Array<Expr> args) { \
+ static PrimExpr Eval(Array<PrimExpr> args) { \
return ir::CallNode::make(args[0].dtype(), kName, args, \
ir::CallNode::PureIntrinsic); \
} \
// if_then_else
struct PIfThenElseOp {
- static Expr Eval(Array<Expr> args) {
+ static PrimExpr Eval(Array<PrimExpr> args) {
return ir::CallNode::make(
args[1].dtype(), kName, args,
ir::CallNode::PureIntrinsic);
// try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
-TryCompare(const Expr& x, int64_t val) {
- Expr diff = this->VisitExpr(x);
+TryCompare(const PrimExpr& x, int64_t val) {
+ PrimExpr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImmNode>()) {
if (ptr->value == val) {
return kEQ;
}
void RewriteSimplifier::Impl::
-Update(const Var& var, const Expr& info, bool override) {
+Update(const Var& var, const PrimExpr& info, bool override) {
if (!override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
var_map_[var] = info;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const AddNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AddNode>();
- Expr const_res = TryConstFold<AddNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<AddNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1, b2, s1, s2;
+ PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
return ret;
}
-std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) {
+std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
size_t old_literal_size = literal_constraints_.size();
literal_constraints_.push_back(constraint);
size_t new_literal_size = literal_constraints_.size();
return frecover;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const SubNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<SubNode>();
- Expr const_res = TryConstFold<SubNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<SubNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1, b2, s1, s2;
+ PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const MulNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MulNode>();
- Expr const_res = TryConstFold<MulNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<MulNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1, b2, s1, s2;
+ PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const DivNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<DivNode>();
- Expr const_res = TryConstFold<DivNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<DivNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1;
+ PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const ModNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<ModNode>();
- Expr const_res = TryConstFold<ModNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<ModNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1;
+ PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
// NOTE: trunc div required
TVM_TRY_RECURSIVE_REWRITE_IF(
truncmod(x, c1),
- truncmod(x, PConst<Expr>(make_const(op->dtype, -c1.Eval()->value))),
+ truncmod(x, PConst<PrimExpr>(make_const(op->dtype, -c1.Eval()->value))),
c1.Eval()->value < 0);
// try modular analysis
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const FloorDivNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
- Expr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1;
+ PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const FloorModNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorModNode>();
- Expr const_res = TryConstFold<FloorModNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<FloorModNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, b1;
+ PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
PVar<Integer> c1, c2;
// Pattern var for lanes in broadcast and ramp
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const MinNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MinNode>();
- Expr const_res = TryConstFold<MinNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<MinNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, s1, s2;
+ PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const MaxNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<MaxNode>();
- Expr const_res = TryConstFold<MaxNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<MaxNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, s1, s2;
+ PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const EQNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQNode>();
- Expr const_res = TryConstFold<EQNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<EQNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y;
+ PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<Integer> c1;
PVar<int> lanes;
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const NENode* op) {
return this->VisitExpr(NotNode::make(op->a == op->b));
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const LENode* op) {
return this->VisitExpr(NotNode::make(op->b < op->a));
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const GTNode* op) {
return this->VisitExpr(op->b < op->a);
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const GENode* op) {
return this->VisitExpr(NotNode::make(op->a < op->b));
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const LTNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LTNode>();
- Expr const_res = TryConstFold<LTNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<LTNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y, z, s1, s2;
+ PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const NotNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NotNode>();
- Expr const_res = TryConstFold<NotNode>(op->a);
+ PrimExpr const_res = TryConstFold<NotNode>(op->a);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y;
+ PVar<PrimExpr> x, y;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const AndNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AndNode>();
- Expr const_res = TryConstFold<AndNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<AndNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y;
+ PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
broadcast(x && y, lanes));
}
- auto cfalse = PConst<Expr>(make_const(op->dtype, false));
+ auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const OrNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<OrNode>();
- Expr const_res = TryConstFold<OrNode>(op->a, op->b);
+ PrimExpr const_res = TryConstFold<OrNode>(op->a, op->b);
if (const_res.defined()) return const_res;
// Pattern var to match any expression
- PVar<Expr> x, y;
+ PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<Integer> c1, c2;
PVar<int> lanes;
broadcast(x || y, lanes));
}
- auto ctrue = PConst<Expr>(make_const(op->dtype, true));
+ auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));
TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const SelectNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<SelectNode>();
if (op == nullptr) return ret;
// Pattern var to match any expression
- PVar<Expr> x, y;
+ PVar<PrimExpr> x, y;
TVM_TRY_REWRITE(select(x, y, y), y);
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CallNode>();
if (op == nullptr) return ret;
if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) {
return ret;
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const VarNode* op) {
Var var = GetRef<Var>(op);
auto it = var_map_.find(var);
if (it != var_map_.end()) {
return it->second;
}
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const CastNode* op) {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CastNode>();
return cast(op->dtype, op->value);
}
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
VisitExpr_(const LetNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
return this->VisitExpr(op->body);
}
- Expr body = this->VisitExpr(op->body);
+ PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
}
-Expr RewriteSimplifier::operator()(const Expr& expr) {
+PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) {
// Run simplification in post order
- Expr res = expr;
+ PrimExpr res = expr;
int max_iter = 2;
for (int i = 0; i < max_iter; ++i) {
- Expr new_expr = impl_->operator()(res);
+ PrimExpr new_expr = impl_->operator()(res);
if (new_expr.same_as(res)) return res;
res = new_expr;
}
}
void RewriteSimplifier::Update(const Var& var,
- const Expr& info,
+ const PrimExpr& info,
bool override) {
impl_->Update(var, info, override);
}
-std::function<void()> RewriteSimplifier::EnterConstraint(const Expr& constraint) {
+std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) {
return impl_->EnterConstraint(constraint);
}
explicit Impl(Analyzer* parent)
: IRMutatorWithAnalyzer(parent) {}
- void Update(const Var& var, const Expr& info, bool override_info);
- Expr VisitExpr_(const AddNode* op) override;
- Expr VisitExpr_(const SubNode* op) override;
- Expr VisitExpr_(const MulNode* op) override;
- Expr VisitExpr_(const DivNode* op) override;
- Expr VisitExpr_(const ModNode* op) override;
- Expr VisitExpr_(const FloorDivNode* op) override;
- Expr VisitExpr_(const FloorModNode* op) override;
- Expr VisitExpr_(const MinNode* op) override;
- Expr VisitExpr_(const MaxNode* op) override;
- Expr VisitExpr_(const EQNode* op) override;
- Expr VisitExpr_(const NENode* op) override;
- Expr VisitExpr_(const LTNode* op) override;
- Expr VisitExpr_(const LENode* op) override;
- Expr VisitExpr_(const GTNode* op) override;
- Expr VisitExpr_(const GENode* op) override;
- Expr VisitExpr_(const AndNode* op) override;
- Expr VisitExpr_(const OrNode* op) override;
- Expr VisitExpr_(const NotNode* op) override;
- Expr VisitExpr_(const SelectNode* op) override;
- Expr VisitExpr_(const CallNode* op) override;
- Expr VisitExpr_(const VarNode* op) override;
- Expr VisitExpr_(const CastNode* op) override;
- Expr VisitExpr_(const LetNode* op) override;
-
- std::function<void()> EnterConstraint(const Expr& constraint);
+ void Update(const Var& var, const PrimExpr& info, bool override_info);
+ PrimExpr VisitExpr_(const AddNode* op) override;
+ PrimExpr VisitExpr_(const SubNode* op) override;
+ PrimExpr VisitExpr_(const MulNode* op) override;
+ PrimExpr VisitExpr_(const DivNode* op) override;
+ PrimExpr VisitExpr_(const ModNode* op) override;
+ PrimExpr VisitExpr_(const FloorDivNode* op) override;
+ PrimExpr VisitExpr_(const FloorModNode* op) override;
+ PrimExpr VisitExpr_(const MinNode* op) override;
+ PrimExpr VisitExpr_(const MaxNode* op) override;
+ PrimExpr VisitExpr_(const EQNode* op) override;
+ PrimExpr VisitExpr_(const NENode* op) override;
+ PrimExpr VisitExpr_(const LTNode* op) override;
+ PrimExpr VisitExpr_(const LENode* op) override;
+ PrimExpr VisitExpr_(const GTNode* op) override;
+ PrimExpr VisitExpr_(const GENode* op) override;
+ PrimExpr VisitExpr_(const AndNode* op) override;
+ PrimExpr VisitExpr_(const OrNode* op) override;
+ PrimExpr VisitExpr_(const NotNode* op) override;
+ PrimExpr VisitExpr_(const SelectNode* op) override;
+ PrimExpr VisitExpr_(const CallNode* op) override;
+ PrimExpr VisitExpr_(const VarNode* op) override;
+ PrimExpr VisitExpr_(const CastNode* op) override;
+ PrimExpr VisitExpr_(const LetNode* op) override;
+
+ std::function<void()> EnterConstraint(const PrimExpr& constraint);
protected:
/*! \brief internal structure for comparison. */
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
- std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
+ std::unordered_map<Var, PrimExpr, ObjectHash, ObjectEqual> var_map_;
- std::vector<Expr> literal_constraints_;
+ std::vector<PrimExpr> literal_constraints_;
// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;
* \param val The constant value.
* \return comparison result.
*/
- CompareResult TryCompare(const Expr& x, int64_t val);
+ CompareResult TryCompare(const PrimExpr& x, int64_t val);
private:
// Whether x >= val
- bool CanProveGreaterEqual(const Expr& x, int64_t val) {
+ bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
return analyzer_->CanProveGreaterEqual(x, val);
}
// Whether x == val
- bool CanProveEqual(const Expr& x, int64_t val) {
+ bool CanProveEqual(const PrimExpr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == kEQ;
}
// Recursive rewrite x
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
- Expr RecursiveRewrite(const Expr& x) {
+ PrimExpr RecursiveRewrite(const PrimExpr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
- Expr res = this->VisitExpr(x);
+ PrimExpr res = this->VisitExpr(x);
--recur_depth_;
return res;
}
using Parent::VisitStmt;
using Parent::VisitStmt_;
- Expr VisitExpr(const Expr& expr) final {
+ PrimExpr VisitExpr(const PrimExpr& expr) final {
return analyzer_->Simplify(expr);
}
}
Stmt VisitStmt_(const LetStmtNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (!ir::HasSideEffect(value)) {
// it is fine to discard the let binding
// because the call to simplify will always inline the var.
return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));
}
-Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
+PrimExpr CanonicalSimplify(PrimExpr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
return analyzer.canonical_simplify(expr);
}
-Expr Simplify(Expr expr, Map<Var, Range> vrange) {
+PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
arith::Analyzer analyzer;
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
- VarExpr var = op->node.as<tvm::IterVarNode>()->var;
+ Var var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImmNode>();
CHECK(extent);
* \param ann_type The type for the for loop
* \return skip Whether skip this node
*/
- virtual bool EnterItervar_(tvm::VarExpr var, int64_t length, AnnotationType ann_type) = 0;
+ virtual bool EnterItervar_(tvm::Var var, int64_t length, AnnotationType ann_type) = 0;
/*! \brief Exit a for loop subtree */
virtual void ExitItervar_() = 0;
/*!
* \param buffer_var The buffer to access.
* \param index Index expression
*/
- virtual void EnterMem_(tvm::VarExpr buffer_var, tvm::Expr index) = 0;
+ virtual void EnterMem_(tvm::Var buffer_var, tvm::PrimExpr index) = 0;
/*! \brief Exit a memory access node */
virtual void ExitMem_() = 0;
};
// get touch pattern from index expression
class IndexParser: public ExprVisitor {
public:
- void Parse(Expr expr) {
+ void Parse(PrimExpr expr) {
pattern_map.clear();
this->VisitExpr(expr);
}
};
// extract iter vars and their touch pattern from ir
-bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type) {
+bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_type) {
// do not insert duplicated occurrences of virtual thread
if (ann_type == kVirtualThread && itervar_map.count(var) != 0) {
skip_stack_size_.push_back(itervar_stack_.size());
// these happens when we create tvm.thread_axis("threadIdx.x") once and
// bind it twice. Here we treat them as two axes
// so we create a snapshot for the old one and freeze it
- VarExpr old = VarExpr(var.get()->name_hint);
+ Var old = Var(var.get()->name_hint);
itervar_map.insert({old, itervar_map[var]});
itervar_map.erase(var);
}
skip_stack_size_.pop_back();
return;
}
- VarExpr var = itervar_stack_.back();
+ Var var = itervar_stack_.back();
// update count and reuse ratio for upper iter vars (includes self)
for (auto kv : itervar_map[var].touch_feature) {
}
}
-void TouchExtractor::EnterMem_(VarExpr buffer_var, Expr index) {
+void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) {
std::string name = buffer_var.get()->name_hint;
TouchedBuffer buf = name + "_" + std::to_string(buffer_counter_[name]++);
* \note If you want to flatten these features as the input of your model,
* You can use the faster one GetItervarFeatureFlatten below.
*/
-void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *ret_feature) {
+void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > > *ret_feature) {
// extract
TouchExtractor touch_analyzer;
touch_analyzer.Analyze(stmt);
// sort according to order
- std::vector<VarExpr> vars;
+ std::vector<Var> vars;
for (auto kv : touch_analyzer.itervar_map) {
vars.push_back(kv.first);
}
- std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
+ std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
});
// serialize for front end
for (auto var : vars) {
- Array<Array<Expr> > feature_row;
+ Array<Array<PrimExpr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
- feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
+ feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
- Array<Expr> attr{std::string("_attr_"),
+ Array<PrimExpr> attr{std::string("_attr_"),
FloatImmNode::make(DataType::Float(32), trans(fea.length)),
IntImmNode::make(DataType::Int(32), fea.nest_level),
FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
feature_row.push_back(attr);
// arithmetic
- feature_row.push_back(Array<Expr>{std::string("_arith_"),
+ feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
- Array<Expr>{k,
+ Array<PrimExpr>{k,
FloatImmNode::make(DataType::Float(32), trans(v.stride)),
FloatImmNode::make(DataType::Float(32), trans(v.mod)),
FloatImmNode::make(DataType::Float(32), trans(v.count)),
touch_analyzer.Analyze(stmt);
// sort according to order
- std::vector<VarExpr> vars;
+ std::vector<Var> vars;
for (auto kv : touch_analyzer.itervar_map) {
vars.push_back(kv.first);
}
- std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
+ std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
});
touch_ext.Analyze(stmt);
// sort according to order
- std::vector<VarExpr> vars;
+ std::vector<Var> vars;
for (auto kv : touch_ext.itervar_map) {
vars.push_back(kv.first);
}
- std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
+ std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order;
});
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
bool take_log = args[1];
- Array<Array<Array<Expr > > > ret_feature;
+ Array<Array<Array<PrimExpr > > > ret_feature;
GetItervarFeature(stmt, take_log, &ret_feature);
// all the feature of an iter var
struct ItervarFeature {
- ItervarFeature(VarExpr var,
+ ItervarFeature(Var var,
int64_t extent,
int nest,
AnnotationType ann_type,
FeatureVisitor::VisitExpr_(op);
}
- std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
+ std::unordered_map<Var, ItervarFeature, tvm::ObjectHash, tvm::ObjectEqual> itervar_map;
private:
- bool EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type);
+ bool EnterItervar_(Var var, int64_t length, AnnotationType ann_type);
void ExitItervar_();
- void EnterMem_(VarExpr buffer_var, Expr index);
+ void EnterMem_(Var buffer_var, PrimExpr index);
void ExitMem_();
int64_t topdown_product_{1};
std::map<std::string, size_t> buffer_counter_;
size_t itervar_counter_{0};
- std::deque<VarExpr> itervar_stack_; // use deque instead of stack for indexing
+ std::deque<Var> itervar_stack_; // use deque instead of stack for indexing
std::deque<size_t> skip_stack_size_;
using FeatureVisitor::VisitExpr_;
}
}
-Buffer BufferWithOffsetAlignment(Array<Expr> shape,
+Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape,
DataType dtype,
std::string name,
int data_alignment,
}
BufferType buffer_type = has_any ? kAutoBroadcast : kDefault;
- Expr elem_offset;
+ PrimExpr elem_offset;
if (offset_factor != 0) {
elem_offset = Var(name + "_elem_offset", shape[0].dtype());
} else {
- elem_offset = Expr();
+ elem_offset = PrimExpr();
}
- return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
+ return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
data_alignment, offset_factor, buffer_type);
}
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
- Array<Expr> tags = args[2];
+ Array<PrimExpr> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
return decl_stream.str() + stream.str();
}
-void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
+void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
if (print_ssa_form_) {
std::ostringstream temp;
VisitExpr(n, temp);
// Print a reference expression to a buffer.
std::string CodeGenC::GetBufferRef(
- DataType t, const VarNode* buffer, Expr index) {
+ DataType t, const VarNode* buffer, PrimExpr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
// Print a reference expression to a buffer.
std::string CodeGenC::GetStructRef(
- DataType t, const Expr& buffer, const Expr& index, int kind) {
+ DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) {
if (kind < intrinsic::kArrKindBound_) {
std::ostringstream os;
os << "(((TVMArray*)";
}
std::string CodeGenC::GetVecLoad(
- DataType t, const VarNode* buffer, Expr base) {
+ DataType t, const VarNode* buffer, PrimExpr base) {
return GetBufferRef(t, buffer, base);
}
void CodeGenC::PrintVecStore(const VarNode* buffer,
- DataType t, Expr base,
+ DataType t, PrimExpr base,
const std::string& value) {
std::string ref = GetBufferRef(t, buffer, base);
this->PrintIndent();
void CodeGenC::PrintVecBinaryOp(
const std::string& op, DataType t,
- Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
+ PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
if (isalpha(op[0])) {
os << op << "(";
this->PrintExpr(lhs, os);
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
- Expr base;
+ PrimExpr base;
if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
os << ref;
} else {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
- Expr base;
+ PrimExpr base;
if (GetRamp1Base(op->index, t.lanes(), &base)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value);
* a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`.
*/
class CodeGenC :
- public ExprFunctor<void(const Expr&, std::ostream&)>,
+ public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)>,
public CodeGenSourceBase {
public:
* \param n The expression to be printed.
* \param os The output stream
*/
- void PrintExpr(const Expr& n, std::ostream& os);
+ void PrintExpr(const PrimExpr& n, std::ostream& os);
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
- std::string PrintExpr(const Expr& n) {
+ std::string PrintExpr(const PrimExpr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, DataType op_type,
- Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
+ PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*)
// print vector load
- virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base);
+ virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base);
// print vector store
virtual void PrintVecStore(const VarNode* buffer,
- DataType t, Expr base,
+ DataType t, PrimExpr base,
const std::string& value); // NOLINT(*)
// print load of single element
virtual void PrintVecElemLoad(
protected:
// Print reference to struct location
std::string GetStructRef(
- DataType t, const Expr& buffer, const Expr& index, int kind);
+ DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
// print reference to a buffer as type t in index.
virtual std::string GetBufferRef(
- DataType t, const VarNode* buffer, Expr index);
+ DataType t, const VarNode* buffer, PrimExpr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
void CodeGenCUDA::PrintVecBinaryOp(
const std::string&op, DataType t,
- Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
+ PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
// unpacking operations.
int lanes = t.lanes();
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, DataType t,
- Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*)
+ PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(
const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*)
}
void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
- Expr base, std::ostream& os) { // NOLINT(*)
+ PrimExpr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
auto it = alloc_storage_scope_.find(buffer);
PrintExpr(base, os);
}
std::string CodeGenOpenCL::GetVecLoad(
- DataType t, const VarNode* buffer, Expr base) {
+ DataType t, const VarNode* buffer, PrimExpr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
}
void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
- DataType t, Expr base,
+ DataType t, PrimExpr base,
const std::string& value) {
this->PrintIndent();
stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
std::string GetVecLoad(DataType t, const VarNode* buffer,
- Expr base) final;
+ PrimExpr base) final;
void PrintVecStore(const VarNode* buffer,
- DataType t, Expr base,
+ DataType t, PrimExpr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const VarNode* buffer, DataType t,
- Expr base, std::ostream& os); // NOLINT(*)
+ PrimExpr base, std::ostream& os); // NOLINT(*)
std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
// overload visitor
}
// texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
-std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
+std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, PrimExpr index) {
std::ostringstream os;
os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
PrintExpr(index, os);
// Print a reference expression to a buffer.
// Format: texelFetch(buffer, index, 0).r
std::string CodeGenOpenGL::GetBufferRef(
- DataType t, const VarNode* buffer, Expr index) {
+ DataType t, const VarNode* buffer, PrimExpr index) {
CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
void InitFuncState(LoweredFunc f) final;
void BindThreadIndex(const IterVar& iv) final;
void VisitStmt_(const StoreNode* op) final;
- std::string TexelFetch(const VarNode* buffer, Expr index);
- std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final;
+ std::string TexelFetch(const VarNode* buffer, PrimExpr index);
+ std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
// Codegen for immediate values
std::string whole_code = cg.Finish();
// Generate source code for compilation.
- Array<Array<Expr> > kernel_info;
+ Array<Array<PrimExpr> > kernel_info;
for (LoweredFunc f : funcs) {
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code).operator std::string();
}
- kernel_info.push_back(Array<Expr>({f->name, code}));
+ kernel_info.push_back(Array<PrimExpr>({f->name, code}));
}
std::string xclbin;
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
- Expr e = args[0];
+ PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
- Expr e = args[0];
+ PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
// Call pure extern function.
template<typename T>
inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
- Expr e = args[0];
+ PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
std::string name = T()(call->dtype, call->name);
const auto *find_rocm_bitcodes =
tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
- Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
+ Array<PrimExpr> bitcode_files = (*find_rocm_bitcodes)();
for (auto &bitcode : bitcode_files) {
std::string path = bitcode.as<StringImmNode>()->value;
llvm::Value* CreateIntrinsic(const CallNode* op) override;
private:
- Expr ARMPopcount(const CallNode* op);
+ PrimExpr ARMPopcount(const CallNode* op);
};
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
op->args[0].as<UIntImmNode>()->value);
if (id == ::llvm::Intrinsic::ctpop) {
- Expr e = ARMPopcount(op);
+ PrimExpr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
}
}
return CodeGenCPU::CreateIntrinsic(op);
}
-Expr CodeGenARM::ARMPopcount(const CallNode *call) {
+PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
using namespace ir;
- const Expr& e = call->args[2];
+ const PrimExpr& e = call->args[2];
::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
int total_size = call->dtype.bits() * call->dtype.lanes();
if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
- Array<Expr> vcnt_args;
+ Array<PrimExpr> vcnt_args;
vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt_args.push_back(e);
uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
// Interpret input as vector of 8bit values
- Expr input8 = reinterpret(uint8_type, e);
+ PrimExpr input8 = reinterpret(uint8_type, e);
// Popcount 8bit->8bit
const CallNode* c0 = input8.as<CallNode>();
CHECK(c0 != nullptr);
- Array<Expr> vcnt8_args;
+ Array<PrimExpr> vcnt8_args;
vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
- Expr vcnt8 = ir::CallNode::make(
+ PrimExpr vcnt8 = ir::CallNode::make(
uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit
- Array<Expr> vcnt16_args;
+ Array<PrimExpr> vcnt16_args;
vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
- Expr vcnt16 = ir::CallNode::make(
+ PrimExpr vcnt16 = ir::CallNode::make(
uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 16) {
return vcnt16;
}
// Accumulation 16->32bit
- Array<Expr> vcnt32_args;
+ Array<PrimExpr> vcnt32_args;
vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
- Expr vcnt32 = ir::CallNode::make(
+ PrimExpr vcnt32 = ir::CallNode::make(
uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
if (call->dtype.bits() == 32) {
return vcnt32;
}
// Accumulation 32->64bit
- Array<Expr> vcnt64_args;
+ Array<PrimExpr> vcnt64_args;
vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
}
llvm::BasicBlock *
-CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
+CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end) {
using llvm::BasicBlock;
CHECK(parallel_env_.num_task.defined());
CHECK(parallel_env_.penv != nullptr);
DataType t = op->extent.dtype();
- Expr num_task = cast(t, parallel_env_.num_task);
- Expr task_id = cast(t, parallel_env_.task_id);
+ PrimExpr num_task = cast(t, parallel_env_.num_task);
+ PrimExpr task_id = cast(t, parallel_env_.task_id);
CHECK(!parallel_env_.in_parallel_loop)
<< "Nested parallel loop is not supported by threadpool, try fuse them instead";
parallel_env_.in_parallel_loop = true;
op->loop_var,
op->body);
} else {
- Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
- Expr begin = MinNode::make(task_id * step, op->extent);
- Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
+ PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
+ PrimExpr begin = MinNode::make(task_id * step, op->extent);
+ PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
ConstInt32(1),
private:
// the parallel group information
struct ParallelEnv {
- VarExpr task_id;
- VarExpr num_task;
+ Var task_id;
+ Var num_task;
bool stride_pattern{false};
bool in_parallel_loop{false};
int parallel_loop_count{0};
const Array<Var>& fields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
- llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
+ llvm::BasicBlock *MakeCallPacked(const Array<PrimExpr> &args,
llvm::Value **rvalue,
llvm::Value **ret_tcode, const DataType &r_type,
const int64_t begin, const int64_t end);
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
const VarNode* buffer,
- Expr index,
+ PrimExpr index,
DataType type) {
if (alias_var_set_.count(buffer) != 0) {
// Mark all possibly aliased pointer as same type.
void CodeGenLLVM::GetAlignment(DataType t,
const VarNode* buf_var,
- const Expr& index,
+ const PrimExpr& index,
int* p_alignment,
int* p_native_bits) {
int max_align_bits = t.bits();
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
- const VarExpr& loop_var,
+ const Var& loop_var,
const Stmt& body) {
using llvm::BasicBlock;
BasicBlock* pre_block = builder_->GetInsertBlock();
addrspace = llvm::dyn_cast<llvm::PointerType>(
ptr->getType())->getAddressSpace();
} else {
- Expr index = r->base / make_const(DataType::Int(32), r->lanes);
+ PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
ptr = CreateBufferVecPtr(
l->dtype, MakeValue(l->buffer_var), MakeValue(index));
addrspace = llvm::dyn_cast<llvm::PointerType>(
}
}
-void CodeGenLLVM::Scalarize(const Expr& e,
+void CodeGenLLVM::Scalarize(const PrimExpr& e,
std::function<void(int i, llvm::Value* v)> f) {
if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
- Expr offset = ramp->base + (ramp->stride * i);
+ PrimExpr offset = ramp->base + (ramp->stride * i);
f(i, MakeValue(offset));
}
} else {
llvm::LoadInst* load = builder_->CreateAlignedLoad(
ptr, basic_align, is_volatile);
ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
- AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
+ AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t);
};
this->Scalarize(op->index, f);
return ret;
llvm::StoreInst* store = builder_->CreateAlignedStore(
builder_->CreateExtractElement(value, i),
ptr, basic_align, is_volatile);
- AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.dtype());
+ AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype());
};
this->Scalarize(op->index, f);
}
* \brief A base class to generate a LLVM.
*/
class CodeGenLLVM :
- public ExprFunctor<llvm::Value* (const Expr&)>,
+ public ExprFunctor<llvm::Value* (const PrimExpr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \param e The expression to be created value for.
* \return created value.
*/
- llvm::Value* MakeValue(const Expr& e) {
+ llvm::Value* MakeValue(const PrimExpr& e) {
return VisitExpr(e);
}
// Short hande code to get a constant int 32
virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
// Scalarize by iterating elements of e.
// f is a callback that takes index and v.
- virtual void Scalarize(const Expr& e,
+ virtual void Scalarize(const PrimExpr& e,
std::function<void(int i, llvm::Value* v)> f);
// Initialize target
virtual void InitTarget(llvm::TargetMachine* tm);
void InitFuncState();
// Get alignment given index.
void GetAlignment(
- DataType t, const VarNode* buf_var, const Expr& index,
+ DataType t, const VarNode* buf_var, const PrimExpr& index,
int* p_alignment, int* p_native_bits);
// Get constant string
llvm::Value* GetConstString(const std::string& str);
void CreateSerialFor(llvm::Value* begin,
llvm::Value* end,
llvm::Value* stride,
- const VarExpr& loop_var, const Stmt& body);
+ const Var& loop_var, const Stmt& body);
// add alias information.
- void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type);
+ void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type);
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
- Expr e = targs[0];
+ PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
- const Expr& x = call->args[0];
- Expr one = make_const(x.dtype(), 1);
- Expr two = make_const(x.dtype(), 2);
- Expr neg_two = make_const(x.dtype(), -2);
+ const PrimExpr& x = call->args[0];
+ PrimExpr one = make_const(x.dtype(), 1);
+ PrimExpr two = make_const(x.dtype(), 2);
+ PrimExpr neg_two = make_const(x.dtype(), -2);
- Expr exp_neg2x = ir::CallNode::make(
+ PrimExpr exp_neg2x = ir::CallNode::make(
x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic);
- Expr exp_pos2x = ir::CallNode::make(
+ PrimExpr exp_pos2x = ir::CallNode::make(
x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic);
- Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
- Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
+ PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
+ PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
*rv = ir::SelectNode::make(
x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
});
// num_signature means number of arguments used to query signature
template<unsigned id, int num_signature>
inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
- Expr e = targs[0];
+ PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
- Array<Expr> cargs;
+ Array<PrimExpr> cargs;
// intrin id.
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
- for (Expr arg : call->args) {
+ for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::CallNode::make(
template<unsigned id, int num_signature>
inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
- Expr e = targs[0];
+ PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
- Array<Expr> cargs;
+ Array<PrimExpr> cargs;
// intrin id.
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
- for (Expr arg : call->args) {
+ for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::CallNode::make(
namespace codegen {
inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
- Expr e = args[0];
+ PrimExpr e = args[0];
using namespace ir;
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
namespace codegen {
inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
- Expr e = args[0];
+ PrimExpr e = args[0];
using namespace ir;
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
}
spirv::Value CodeGenSPIRV::GetThreadIndex(
- const IterVar& iv, const Expr& extent) {
+ const IterVar& iv, const PrimExpr& extent) {
runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
spirv::Value v;
if (ts.rank == 1) {
CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
- Expr vec_index = ir::Simplify(
+ PrimExpr vec_index = ir::Simplify(
ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
return spirv::Value();
}
-void CodeGenSPIRV::Scalarize(const Expr& e,
+void CodeGenSPIRV::Scalarize(const PrimExpr& e,
std::function<void(int i, spirv::Value v)> f) {
if (const RampNode* ramp = e.as<RampNode>()) {
for (int i = 0; i < ramp->dtype.lanes(); ++i) {
- Expr offset = ramp->base + ramp->stride * i;
+ PrimExpr offset = ramp->base + ramp->stride * i;
f(i, MakeValue(offset));
}
} else {
CHECK((me->coeff % ramp->lanes) == 0 &&
(me->base % ramp->lanes) == 0)
<< "Only aligned vector access is allowed in SPIRV";
- Expr vec_index = ir::Simplify(
+ PrimExpr vec_index = ir::Simplify(
ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
spirv::Value ptr = builder_->StructArrayAccess(
ptr_type, buffer, MakeValue(vec_index));
* \brief Code generator into SPIRV
*/
class CodeGenSPIRV:
- public ExprFunctor<spirv::Value(const Expr&)>,
+ public ExprFunctor<spirv::Value(const PrimExpr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \param e The expression to be created value for.
* \return created value.
*/
- spirv::Value MakeValue(const Expr& e) {
+ spirv::Value MakeValue(const PrimExpr& e) {
return VisitExpr(e);
}
// override codegen
// Reset the state so it works for a new function.
void InitFuncState();
// Get the thread index
- spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
+ spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);
spirv::Value CreateStorageSync(const CallNode* op);
- void Scalarize(const Expr& e,
+ void Scalarize(const PrimExpr& e,
std::function<void(int i, spirv::Value v)> f);
// The builder
std::unique_ptr<spirv::IRBuilder> builder_;
// num_signature means number of arguments used to query signature
template<unsigned id>
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
- Expr e = targs[0];
+ PrimExpr e = targs[0];
const ir::CallNode* call = e.as<ir::CallNode>();
CHECK(call != nullptr);
- Array<Expr> cargs;
+ Array<PrimExpr> cargs;
// intrin id.
cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
- for (Expr arg : call->args) {
+ for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
*rv = ir::CallNode::make(
}
void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
- const Expr& a,
- const Expr& b) {
+ const PrimExpr& a,
+ const PrimExpr& b) {
this->Push(a);
this->Push(b);
DataType t = a.dtype();
* into device function when only device JIT is available.
*/
class CodeGenStackVM
- : public ExprFunctor<void(const Expr&)>,
+ : public ExprFunctor<void(const PrimExpr&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
- void Push(const Expr& n) {
+ void Push(const PrimExpr& n) {
VisitExpr(n);
}
/*!
int GetVarID(const VarNode* v) const;
// Push binary operator
void PushBinary(StackVM::OpCode op_int64,
- const Expr& a,
- const Expr& b);
+ const PrimExpr& a,
+ const PrimExpr& b);
// push cast;
void PushCast(DataType dst, DataType src);
// overloadable functions
* For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``.
*/
class CodeGenHybrid :
- public ExprFunctor<void(const Expr&, std::ostream&)>,
+ public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
public StmtFunctor<void(const Stmt&)> {
public:
/*!
* \param n The expression to be printed.
* \param os The output stream
*/
- void PrintExpr(const Expr &n, std::ostream &os) {
+ void PrintExpr(const PrimExpr &n, std::ostream &os) {
this->VisitExpr(n, os);
}
/*!
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
- std::string PrintExpr(const Expr &n) {
+ std::string PrintExpr(const PrimExpr &n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kStr) {
- dict.Set(key, Expr(val.operator std::string()));
+ dict.Set(key, PrimExpr(val.operator std::string()));
} else {
- dict.Set(key, val.operator Expr());
+ dict.Set(key, val.operator PrimExpr());
}
}
}
using IndexMod = ir::FloorModNode;
using IndexDiv = ir::FloorDivNode;
-Array<Expr> SimplifyArray(Array<Expr> array) {
+Array<PrimExpr> SimplifyArray(Array<PrimExpr> array) {
for (size_t i = 0; i < array.size(); ++i) {
array.Set(i, ir::Simplify(array[i]));
}
return array;
}
-Buffer decl_buffer(Array<Expr> shape,
+Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype,
std::string name) {
return BufferNode::make(
Var(name, DataType::Handle()),
dtype,
shape,
- Array<Expr>(),
- Expr(),
+ Array<PrimExpr>(),
+ PrimExpr(),
name,
"",
0, 0,
}
// Split the given expression w.r.t the add operator
-inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
+inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr &expr) {
using namespace ir;
- std::vector<const Expr*> ret;
- std::stack<const Expr*> split_buffer;
+ std::vector<const PrimExpr*> ret;
+ std::stack<const PrimExpr*> split_buffer;
split_buffer.push(&expr);
while (!split_buffer.empty()) {
- const Expr* top_ele = split_buffer.top();
+ const PrimExpr* top_ele = split_buffer.top();
split_buffer.pop();
auto expr_add_match = top_ele->as<AddNode>();
if (expr_add_match) {
// If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
// Currently the we will not search the add/mult combinations exhaustively
// as it will take too much computation.
-inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
- const Expr &mod_l_expr,
- const Expr &mod_r_expr) {
+inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr &mult_expr,
+ const PrimExpr &mod_l_expr,
+ const PrimExpr &mod_r_expr) {
using namespace ir;
const MulNode* mult_ptr = mult_expr.as<MulNode>();
- if (!mult_ptr) return std::make_pair(false, Expr());
- Expr mult_outer = mult_ptr->b;
- const Expr* inner = &(mult_ptr->a);
+ if (!mult_ptr) return std::make_pair(false, PrimExpr());
+ PrimExpr mult_outer = mult_ptr->b;
+ const PrimExpr* inner = &(mult_ptr->a);
// 1. Calculate the outer multiplier
while (true) {
mult_ptr = inner->as<MulNode>();
// If Mult is found, we will expand the inner multiplication factor
// If Div is found, we will go on testing whether lhs matches the lhs of mod expr
// and returns the optimization result.
- const Expr* search_ptr = inner;
- Expr mult_inner; // The inner multiplication factor
- Expr no_opt_sum; // Sum of the exprs that cannot be optimized
+ const PrimExpr* search_ptr = inner;
+ PrimExpr mult_inner; // The inner multiplication factor
+ PrimExpr no_opt_sum; // Sum of the exprs that cannot be optimized
while (true) {
auto inner_div_ptr = search_ptr->as<IndexDiv>();
auto inner_mult_ptr = search_ptr->as<MulNode>();
auto inner_add_ptr = search_ptr->as<AddNode>();
if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
- return std::make_pair(false, Expr());
+ return std::make_pair(false, PrimExpr());
} else if (inner_div_ptr) {
- Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
+ PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
if (Equal(overall_mult, inner_div_ptr->b)
&& Equal(overall_mult, mod_r_expr)
&& Equal(inner_div_ptr->a, mod_l_expr)) {
// Found!
- Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
+ PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
return std::make_pair(true, ret);
} else {
- return std::make_pair(false, Expr());
+ return std::make_pair(false, PrimExpr());
}
} else if (inner_mult_ptr) {
mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
search_ptr = &(inner_mult_ptr->a);
} else if (inner_add_ptr) {
if (mult_inner.get()) {
- return std::make_pair(false, Expr());
+ return std::make_pair(false, PrimExpr());
}
no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
search_ptr = &(inner_add_ptr->b);
break;
}
}
- return std::make_pair(false, Expr());
+ return std::make_pair(false, PrimExpr());
}
// Insert the elements into the corresponding mult_exprs and mod_exprs.
// If the element is found to match Mul, it will be pushed to the mult_exprs.
// If the element it found to match Mod, it will be pused to the mod_exprs.
// Otherwise, the elements will be added to the no_opt_sum variable
-inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
- std::list<Expr>* mult_exprs,
- std::list<std::pair<Expr, Expr> >* mod_exprs,
- Expr* no_opt_sum,
+inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles,
+ std::list<PrimExpr>* mult_exprs,
+ std::list<std::pair<PrimExpr, PrimExpr> >* mod_exprs,
+ PrimExpr* no_opt_sum,
bool* has_mult,
bool* has_mod) {
using namespace ir;
*has_mult = false;
*has_mod = false;
- for (const Expr* ele : eles) {
+ for (const PrimExpr* ele : eles) {
auto mod_ptr = ele->as<IndexMod>();
auto mult_ptr = ele->as<MulNode>();
if (mod_ptr) {
// The search will be performed repeatively until no pattern is found.
// Return: a pair with (false, Expr()) if cannot be optimized.
// a pair with (true, optimized_expr) if can be optimized
-inline Expr MergeMulMod(const Expr &base) {
+inline PrimExpr MergeMulMod(const PrimExpr &base) {
using namespace ir;
// 1. Prepare the lists.
// We store two lists, a list that contain all the elements that match Mul and
// a list that contain all the elements that match Mod.
// The elements in the Mod will be used to match against the elements in Mul.
// The result will then be split and pushed back to these two lists.
- Expr simplified_base = Simplify(base);
- std::vector<const Expr*> eles = ExprSplitAddition(simplified_base);
- std::list<Expr> mult_exprs;
- std::list<std::pair<Expr, Expr> > mod_exprs;
- Expr no_opt_sum;
+ PrimExpr simplified_base = Simplify(base);
+ std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base);
+ std::list<PrimExpr> mult_exprs;
+ std::list<std::pair<PrimExpr, PrimExpr> > mod_exprs;
+ PrimExpr no_opt_sum;
bool has_mult;
bool has_mod;
MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
&no_opt_sum, &has_mult, &has_mod);
bool find_opt = false;
- std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin();
+ std::list<std::pair<PrimExpr, PrimExpr> >::iterator search_mod_it = mod_exprs.begin();
// 2. Exhaustive Search
while (search_mod_it != mod_exprs.end()) {
- std::list<Expr>::iterator mult_it = mult_exprs.begin();
+ std::list<PrimExpr>::iterator mult_it = mult_exprs.begin();
bool inner_find_opt = false;
while (mult_it != mult_exprs.end()) {
- std::pair<bool, Expr> ret = MergeMulModInner(*mult_it,
+ std::pair<bool, PrimExpr> ret = MergeMulModInner(*mult_it,
search_mod_it->first,
search_mod_it->second);
if (ret.first) {
++search_mod_it;
mod_exprs.erase(temp_mod_it);
mult_exprs.erase(mult_it);
- std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second);
+ std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second);
MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
&no_opt_sum, &has_mult, &has_mod);
if (has_mult) {
if (!find_opt) {
return simplified_base;
}
- for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
+ for (std::list<PrimExpr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
}
- for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
+ for (std::list<std::pair<PrimExpr, PrimExpr> >::iterator it = mod_exprs.begin();
it != mod_exprs.end(); ++it) {
no_opt_sum = no_opt_sum.get() ?
no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second);
// The buffer offset in convention of number of elements of
// original data ignoring number of lanes.
// We also perform optimization to simplify the indexing expression.
-inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
- Expr base = n->elem_offset;
+inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
+ PrimExpr base = n->elem_offset;
if (n->strides.size() == 0) {
// Scalar case
if (n->shape.size() == 0 && index.size() == 1) {
} else {
CHECK_EQ(n->shape.size(), index.size());
if (index.size() > 0) {
- Expr offset = index[0];
+ PrimExpr offset = index[0];
for (size_t i = 1; i < index.size(); ++i) {
offset = MergeMulMod(offset * n->shape[i] + index[i]);
}
return base;
}
-inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype) {
- Expr offset = ElemOffset(n, index);
+inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
+ PrimExpr offset = ElemOffset(n, index);
if (n->dtype.lanes() != 1) {
offset = offset * make_const(offset.dtype(), dtype.lanes());
}
}
}
-Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
+PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType dtype) const {
// specially handle bool, stored asDataType::Int(8)
const BufferNode* n = operator->();
CHECK(dtype.element_of() == n->dtype.element_of() &&
}
}
-Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
+Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
// specially handle bool, stored asDataType::Int(8)
const BufferNode* n = operator->();
DataType dtype = value.dtype();
Buffer Buffer::MakeStrideView() const {
if ((*this)->strides.size() != 0) return *this;
if ((*this)->shape.size() == 0) return *this;
- std::vector<Expr> temp;
+ std::vector<PrimExpr> temp;
auto n = make_object<BufferNode>(*operator->());
- Expr acc = make_const(n->DefaultIndexType(), 1);
+ PrimExpr acc = make_const(n->DefaultIndexType(), 1);
for (size_t i = n->shape.size(); i != 0 ; --i) {
temp.push_back(acc);
acc = acc * n->shape[i - 1];
return Buffer(n);
}
-Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
+Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const {
const BufferNode* n = operator->();
begins = SimplifyArray(begins);
- Expr elem_offset = ir::Simplify(ElemOffset(n, begins));
- Array<Expr> strides = n->strides;
+ PrimExpr elem_offset = ir::Simplify(ElemOffset(n, begins));
+ Array<PrimExpr> strides = n->strides;
if (strides.size() == 0) {
bool can_relax = true;
bool need_stride = false;
n->buffer_type);
}
-Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, Expr offset) const {
+PrimExpr Buffer::access_ptr(int access_mask,
+ DataType ptr_type,
+ int content_lanes,
+ PrimExpr offset) const {
const BufferNode* self = operator->();
- Expr e_dtype;
- Expr extent;
+ PrimExpr e_dtype;
+ PrimExpr extent;
if (self->shape.size() == 0) {
extent = make_const(self->DefaultIndexType(), 1);
} else if (self->strides.size() == self->shape.size()) {
int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
- extent = arith::ComputeReduce<ir::MulNode>(self->shape, Expr()) - offset;
+ extent = arith::ComputeReduce<ir::MulNode>(self->shape, PrimExpr()) - offset;
}
- Expr elem_offset = self->elem_offset + offset;
+ PrimExpr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.dtype(), content_lanes);
} else {
e_dtype = ir::TypeAnnotation(self->dtype);
}
- Array<Expr> acc_args{
+ Array<PrimExpr> acc_args{
e_dtype, self->data, elem_offset,
extent, make_const(DataType::Int(32), access_mask)};
return ir::CallNode::make(
Buffer BufferNode::make(Var data,
DataType dtype,
- Array<Expr> shape,
- Array<Expr> strides,
- Expr elem_offset,
+ Array<PrimExpr> shape,
+ Array<PrimExpr> strides,
+ PrimExpr elem_offset,
std::string name,
std::string scope,
int data_alignment,
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
- IterVar axis = IterVarNode::make(Range(Expr(0), Var(shape_name)),
+ IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)),
Var(std::string(1, c)), kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
<< factor << " for dimension " << c;
- IterVar axis = IterVarNode::make(Range(Expr(0), Expr(factor)),
+ IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)),
Var(std::string(1, c)), kDataPar);
node->axes.push_back(axis);
factor = 0;
Array<IterVar> new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
- new_layout.push_back(IterVarNode::make(Range(Expr(0), Expr(factor)),
+ new_layout.push_back(IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)),
Var(axis.ToSubordinate().name()), kDataPar));
}
if (i == this->ndim()) break;
p->stream << "Layout(" << l->name << ")";
});
-inline bool GetStoreRule(Array<Expr>* rule,
+inline bool GetStoreRule(Array<PrimExpr>* rule,
const Layout& src_layout,
const Layout& dst_layout) {
if (!src_layout.defined() || src_layout.name().empty() ||
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
- Expr store(0);
+ PrimExpr store(0);
for (size_t j = 0; j < src_layout.ndim(); ++j) {
const auto& orig_axis = src_layout[j];
const IterVar& orig_axis_impl = src_layout->axes[j];
if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
if (orig_axis.IsPrimal()) {
- Expr orig_var = orig_axis_impl->var;
+ PrimExpr orig_var = orig_axis_impl->var;
const int32_t factor = src_layout.FactorOf(orig_axis);
if (factor > 0) {
- orig_var = orig_var * Expr(factor);
+ orig_var = orig_var * PrimExpr(factor);
}
store = store + orig_var;
} else {
if (store_axis.IsPrimal()) {
const int32_t factor = dst_layout.FactorOf(store_axis);
if (factor > 0) {
- store = indexdiv(store, Expr(factor));
+ store = indexdiv(store, PrimExpr(factor));
}
} else {
store = indexmod(store, store_axis_impl->dom->extent);
return true;
}
-inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
+inline Array<PrimExpr> TransformIndex(const Array<PrimExpr>& src_index,
const Array<IterVar>& src_axis,
- const Array<Expr>& transform_rule) {
- Array<Expr> result;
- std::unordered_map<const VarNode*, Expr> bind_map;
+ const Array<PrimExpr>& transform_rule) {
+ Array<PrimExpr> result;
+ std::unordered_map<const VarNode*, PrimExpr> bind_map;
for (size_t i = 0; i < src_index.size(); ++i) {
bind_map[src_axis[i]->var.get()] = src_index[i];
}
- for (Expr rule : transform_rule) {
+ for (PrimExpr rule : transform_rule) {
result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
}
return result;
}
-Array<Expr> BijectiveLayout::ForwardIndex(const Array<Expr>& src_index) const {
+Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(src_index.size(), self->src_layout->axes.size())
}
-Array<Expr> BijectiveLayout::BackwardIndex(const Array<Expr>& dst_index) const {
+Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
CHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
}
-inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
+inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
const Array<IterVar>& src_axis,
const Array<IterVar>& target_axis,
- const Array<Expr>& transform_rule) {
+ const Array<PrimExpr>& transform_rule) {
CHECK_EQ(src_shape.size(), src_axis.size());
// bind variables for original axes
// for major-axis, bind the corresponding size
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
// e.g., (C * 16 + c) / 32
- std::unordered_map<const VarNode*, Expr> bind_map;
+ std::unordered_map<const VarNode*, PrimExpr> bind_map;
std::unordered_set<size_t> symbolic_var_set;
for (size_t i = 0; i < src_shape.size(); ++i) {
- Expr orig_shape = src_shape[i];
+ PrimExpr orig_shape = src_shape[i];
IterVar orig_axis = src_axis[i];
if (orig_shape.as<ir::AnyNode>()) {
symbolic_var_set.insert(i);
<< orig_axis->dom->extent << ", get " << orig_shape;
}
}
- bind_map[orig_axis->var.get()] = Expr(0);
+ bind_map[orig_axis->var.get()] = PrimExpr(0);
} else {
bind_map[orig_axis->var.get()] = orig_shape;
}
// infer the target shape,
// for major-axis, use the forward/backward_rule directly,
// for minor-axis, simply use the extent.
- Array<Expr> result;
+ Array<PrimExpr> result;
CHECK_EQ(transform_rule.size(), target_axis.size());
for (size_t i = 0; i < transform_rule.size(); ++i) {
- Expr rule = transform_rule[i];
+ PrimExpr rule = transform_rule[i];
IterVar axis = target_axis[i];
if (!LayoutAxis::Get(axis).IsPrimal()) {
result.push_back(axis->dom->extent);
return result;
}
-Array<Expr> BijectiveLayout::ForwardShape(const Array<Expr>& shape) const {
+Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->src_layout->axes,
self->dst_layout->axes, self->forward_rule);
}
-Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
+Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
const BijectiveLayoutNode* self = operator->();
return TransformShape(shape, self->dst_layout->axes,
namespace tvm {
-Expr::Expr(int32_t value)
- : Expr(IntImmNode::make(DataType::Int(32), value)) {}
+PrimExpr::PrimExpr(int32_t value)
+ : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {}
-Expr::Expr(float value)
- : Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
+PrimExpr::PrimExpr(float value)
+ : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
-Expr::Expr(std::string str)
- : Expr(ir::StringImmNode::make(str)) {}
+PrimExpr::PrimExpr(std::string str)
+ : PrimExpr(ir::StringImmNode::make(str)) {}
Var::Var(std::string name_hint, DataType t)
: Var(VarNode::make(t, name_hint)) {}
return Var(node);
}
-Range::Range(Expr begin, Expr end)
+Range::Range(PrimExpr begin, PrimExpr end)
: Range(make_object<RangeNode>(
begin,
is_zero(begin) ? end : (end - begin))) {
return Integer(node);
}
-Range Range::make_by_min_extent(Expr min, Expr extent) {
+Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
namespace tvm {
// simple cast that only checks if type matches and cast
-inline Expr SimpleCast(const DataType& t, Expr value) {
+inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
return ir::CastNode::make(t, value);
}
// The public function with a quick checking path.
-void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
+void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*)
if (lhs.dtype() == rhs.dtype()) return;
DataType ltype = lhs.dtype();
DataType rtype = rhs.dtype();
// maximum and min limits
-Expr max_value(const DataType& dtype) {
+PrimExpr max_value(const DataType& dtype) {
using namespace ir;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_int()) {
}
}
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
- return Expr();
+ return PrimExpr();
}
-Expr min_value(const DataType& dtype) {
+PrimExpr min_value(const DataType& dtype) {
using namespace ir;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_int()) {
}
}
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
- return Expr();
+ return PrimExpr();
}
template<typename ValueType>
return true;
}
-bool is_const_power_of_two_integer(const Expr& x, int* shift) {
+bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return ConstPowerHelper(op->value, shift);
} else if (const auto* op = x.as<ir::UIntImmNode>()) {
}
}
-Expr cast(const DataType& t, Expr value) {
+PrimExpr cast(const DataType& t, PrimExpr value) {
using ir::IntImmNode;
using ir::UIntImmNode;
using ir::FloatImmNode;
}
}
-Expr reinterpret(const DataType& t, Expr value) {
+PrimExpr reinterpret(const DataType& t, PrimExpr value) {
if (value.dtype() == t) return value;
return ir::CallNode::make(
t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic);
}
-Expr operator+(Expr a, Expr b) {
+PrimExpr operator+(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::AddNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::AddNode>(a, b);
if (ret.defined()) return ret;
return ir::AddNode::make(a, b);
}
// negation
-Expr operator-(Expr a) {
+PrimExpr operator-(PrimExpr a) {
using ir::IntImmNode;
using ir::FloatImmNode;
const IntImmNode* pa = a.as<IntImmNode>();
return make_zero(a.dtype()) - a;
}
-Expr operator-(Expr a, Expr b) {
+PrimExpr operator-(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::SubNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::SubNode>(a, b);
if (ret.defined()) return ret;
return ir::SubNode::make(a, b);
}
-Expr operator*(Expr a, Expr b) {
+PrimExpr operator*(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::MulNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::MulNode>(a, b);
if (ret.defined()) return ret;
return ir::MulNode::make(a, b);
}
-Expr div(Expr a, Expr b) {
+PrimExpr div(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::DivNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::DivNode>(a, b);
if (ret.defined()) return ret;
return ir::DivNode::make(a, b);
}
-Expr truncdiv(Expr a, Expr b) {
+PrimExpr truncdiv(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
return div(a, b);
}
-Expr truncmod(Expr a, Expr b) {
+PrimExpr truncmod(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::ModNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::ModNode>(a, b);
if (ret.defined()) return ret;
return ir::ModNode::make(a, b);
}
-Expr operator/(Expr a, Expr b) {
+PrimExpr operator/(PrimExpr a, PrimExpr b) {
return div(a, b);
}
-Expr operator%(Expr a, Expr b) {
+PrimExpr operator%(PrimExpr a, PrimExpr b) {
return truncmod(a, b);
}
// TODO(tqchen): switch to floordiv
-Expr indexdiv(Expr a, Expr b) {
+PrimExpr indexdiv(PrimExpr a, PrimExpr b) {
return floordiv(a, b);
}
-Expr indexmod(Expr a, Expr b) {
+PrimExpr indexmod(PrimExpr a, PrimExpr b) {
return floormod(a, b);
}
-Expr floordiv(Expr a, Expr b) {
+PrimExpr floordiv(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
if (ret.defined()) return ret;
return ir::FloorDivNode::make(a, b);
}
-Expr floormod(Expr a, Expr b) {
+PrimExpr floormod(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
if (ret.defined()) return ret;
return ir::FloorModNode::make(a, b);
}
-Expr min(Expr a, Expr b) {
+PrimExpr min(PrimExpr a, PrimExpr b) {
// inf-aware simplificaiton
using arith::is_pos_inf;
using arith::is_neg_inf;
if (is_pos_inf(b)) return a;
if (is_neg_inf(b)) return b;
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::MinNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::MinNode>(a, b);
if (ret.defined()) return ret;
return ir::MinNode::make(a, b);
}
-Expr max(Expr a, Expr b) {
+PrimExpr max(PrimExpr a, PrimExpr b) {
// inf-aware simplificaiton
using arith::is_pos_inf;
using arith::is_neg_inf;
if (is_pos_inf(b)) return b;
if (is_neg_inf(b)) return a;
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::MaxNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::MaxNode>(a, b);
if (ret.defined()) return ret;
return ir::MaxNode::make(a, b);
}
-Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
+PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
using ir::IntImmNode;
using ir::UIntImmNode;
CHECK(cond.dtype() == DataType::Bool(1))
ir::CallNode::PureIntrinsic);
}
-Expr likely(Expr cond) {
+PrimExpr likely(PrimExpr cond) {
if (is_const(cond)) return cond;
return ir::CallNode::make(cond.dtype(),
ir::CallNode::likely,
ir::CallNode::PureIntrinsic);
}
-Expr operator>(Expr a, Expr b) {
+PrimExpr operator>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::GTNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::GTNode>(a, b);
if (ret.defined()) return ret;
return ir::GTNode::make(a, b);
}
-Expr operator>=(Expr a, Expr b) {
+PrimExpr operator>=(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::GENode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::GENode>(a, b);
if (ret.defined()) return ret;
return ir::GENode::make(a, b);
}
-Expr operator<(Expr a, Expr b) {
+PrimExpr operator<(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::LTNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::LTNode>(a, b);
if (ret.defined()) return ret;
return ir::LTNode::make(a, b);
}
-Expr operator<=(Expr a, Expr b) {
+PrimExpr operator<=(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::LENode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::LENode>(a, b);
if (ret.defined()) return ret;
return ir::LENode::make(a, b);
}
-Expr operator==(Expr a, Expr b) {
+PrimExpr operator==(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::EQNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::EQNode>(a, b);
if (ret.defined()) return ret;
return ir::EQNode::make(a, b);
}
-Expr operator!=(Expr a, Expr b) {
+PrimExpr operator!=(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
- Expr ret = arith::TryConstFold<ir::NENode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::NENode>(a, b);
if (ret.defined()) return ret;
return ir::NENode::make(a, b);
}
-Expr operator&&(Expr a, Expr b) {
+PrimExpr operator&&(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
- Expr ret = arith::TryConstFold<ir::AndNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::AndNode>(a, b);
if (ret.defined()) return ret;
return ir::AndNode::make(a, b);
}
-Expr operator||(Expr a, Expr b) {
+PrimExpr operator||(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
- Expr ret = arith::TryConstFold<ir::OrNode>(a, b);
+ PrimExpr ret = arith::TryConstFold<ir::OrNode>(a, b);
if (ret.defined()) return ret;
return ir::OrNode::make(a, b);
}
-Expr operator!(Expr a) {
+PrimExpr operator!(PrimExpr a) {
CHECK(a.dtype().is_bool());
- Expr ret = arith::TryConstFold<ir::NotNode>(a);
+ PrimExpr ret = arith::TryConstFold<ir::NotNode>(a);
if (ret.defined()) return ret;
return ir::NotNode::make(a);
}
-Expr operator>>(Expr a, Expr b) {
+PrimExpr operator>>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic);
}
-Expr operator<<(Expr a, Expr b) {
+PrimExpr operator<<(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
a.dtype(), ir::CallNode::shift_left, { a, b }, ir::CallNode::PureIntrinsic);
}
-Expr operator&(Expr a, Expr b) {
+PrimExpr operator&(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
}
-Expr operator|(Expr a, Expr b) {
+PrimExpr operator|(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
}
-Expr operator^(Expr a, Expr b) {
+PrimExpr operator^(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
}
-Expr operator~(Expr a) {
+PrimExpr operator~(PrimExpr a) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
return ir::CallNode::make(
a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic);
}
-Expr pow(Expr x, Expr y) {
+PrimExpr pow(PrimExpr x, PrimExpr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "power only applies to float";
return ir::CallNode::make(
x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic);
}
-Expr abs(Expr x) {
+PrimExpr abs(PrimExpr x) {
if (x.dtype().is_int()) {
using ir::IntImmNode;
const IntImmNode* px = x.as<IntImmNode>();
}
}
-Expr isnan(Expr x) {
+PrimExpr isnan(PrimExpr x) {
DataType t = DataType::Bool(x.dtype().lanes());
if (x.dtype().is_int() || x.dtype().is_uint()) {
return make_const(t, false);
}
}
-Expr sum(Expr source, Array<IterVar> rdom) {
+PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
- Expr result = ir::AddNode::make(x, y);
- Expr identity_element = make_zero(source.dtype());
+ PrimExpr result = ir::AddNode::make(x, y);
+ PrimExpr identity_element = make_zero(source.dtype());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
-Expr all(Expr source, Array<IterVar> rdom) {
+PrimExpr all(PrimExpr source, Array<IterVar> rdom) {
CHECK(source.dtype().is_bool());
Var x("x", source.dtype()), y("y", source.dtype());
- Expr result = ir::AndNode::make(x, y);
- Expr identity_element = make_const(source.dtype(), true);
+ PrimExpr result = ir::AndNode::make(x, y);
+ PrimExpr identity_element = make_const(source.dtype(), true);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
-Expr any(Expr source, Array<IterVar> rdom) {
+PrimExpr any(PrimExpr source, Array<IterVar> rdom) {
CHECK(source.dtype().is_bool());
Var x("x", source.dtype()), y("y", source.dtype());
- Expr result = ir::OrNode::make(x, y);
- Expr identity_element = make_const(source.dtype(), false);
+ PrimExpr result = ir::OrNode::make(x, y);
+ PrimExpr identity_element = make_const(source.dtype(), false);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
-Expr max(Expr source, Array<IterVar> rdom) {
+PrimExpr max(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
- Expr result = ir::MaxNode::make(x, y);
- Expr identity_element = min_value(source.dtype());
+ PrimExpr result = ir::MaxNode::make(x, y);
+ PrimExpr identity_element = min_value(source.dtype());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
-Expr min(Expr source, Array<IterVar> rdom) {
+PrimExpr min(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
- Expr result = ir::MinNode::make(x, y);
- Expr identity_element = max_value(source.dtype());
+ PrimExpr result = ir::MinNode::make(x, y);
+ PrimExpr identity_element = max_value(source.dtype());
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
-Expr prod(Expr source, Array<IterVar> rdom) {
+PrimExpr prod(PrimExpr source, Array<IterVar> rdom) {
Var x("x", source.dtype()), y("y", source.dtype());
- Expr result = ir::MulNode::make(x, y);
- Expr identity_element = make_const(source.dtype(), 1);
+ PrimExpr result = ir::MulNode::make(x, y);
+ PrimExpr identity_element = make_const(source.dtype(), 1);
ir::CommReducer combiner =
ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
}
-Expr fmod(Expr x, Expr y) {
+PrimExpr fmod(PrimExpr x, PrimExpr y) {
BinaryOpMatchTypes(x, y);
CHECK(x.dtype().is_float()) << "fmod only applies to float";
return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic);
}
-Expr floor(Expr x) {
+PrimExpr floor(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value));
return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
}
-Expr ceil(Expr x) {
+PrimExpr ceil(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value));
return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
}
-Expr round(Expr x) {
+PrimExpr round(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
}
-Expr nearbyint(Expr x) {
+PrimExpr nearbyint(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
}
-Expr trunc(Expr x) {
+PrimExpr trunc(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
namespace ir {
// constructors
-Expr UIntImmNode::make(DataType t, uint64_t value) {
+PrimExpr UIntImmNode::make(DataType t, uint64_t value) {
CHECK(t.is_uint() && t.lanes() == 1)
<< "ValueError: UIntImm can only take scalar";
ObjectPtr<UIntImmNode> node = make_object<UIntImmNode>();
node->dtype = t;
node->value = value;
- return Expr(node);
+ return PrimExpr(node);
}
-Expr FloatImmNode::make(DataType t, double value) {
+PrimExpr FloatImmNode::make(DataType t, double value) {
CHECK_EQ(t.lanes(), 1)
<< "ValueError: FloatImm can only take scalar";
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = t;
node->value = value;
- return Expr(node);
+ return PrimExpr(node);
}
-Expr StringImmNode::make(std::string value) {
+PrimExpr StringImmNode::make(std::string value) {
ObjectPtr<StringImmNode> node = make_object<StringImmNode>();
node->dtype = DataType::Handle();
node->value = std::move(value);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr CastNode::make(DataType t, Expr value) {
+PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
ObjectPtr<CastNode> node = make_object<CastNode>();
node->dtype = t;
node->value = std::move(value);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr AndNode::make(Expr a, Expr b) {
+PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.dtype().is_bool());
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr OrNode::make(Expr a, Expr b) {
+PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.dtype().is_bool());
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr NotNode::make(Expr a) {
+PrimExpr NotNode::make(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
ObjectPtr<NotNode> node = make_object<NotNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) {
+PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
CHECK(false_value.defined()) << "ValueError: true_value is undefined";
node->condition = std::move(condition);
node->true_value = std::move(true_value);
node->false_value = std::move(false_value);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
+PrimExpr LoadNode::make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) {
CHECK(buffer_var.defined());
CHECK(predicate.defined());
CHECK(index.defined());
node->index = std::move(index);
node->predicate = std::move(predicate);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr RampNode::make(Expr base, Expr stride, int lanes) {
+PrimExpr RampNode::make(PrimExpr base, PrimExpr stride, int lanes) {
CHECK(base.defined());
CHECK(stride.defined());
CHECK(base.dtype().is_scalar());
node->base = base;
node->stride = stride;
node->lanes = lanes;
- return Expr(node);
+ return PrimExpr(node);
}
-Expr BroadcastNode::make(Expr value, int lanes) {
+PrimExpr BroadcastNode::make(PrimExpr value, int lanes) {
CHECK(value.defined());
CHECK(value.dtype().is_scalar());
CHECK_GT(lanes, 1);
node->dtype = value.dtype().with_lanes(lanes);
node->value = std::move(value);
node->lanes = lanes;
- return Expr(node);
+ return PrimExpr(node);
}
-Expr LetNode::make(Var var, Expr value, Expr body) {
+PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.dtype(), var.dtype());
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
- return Expr(node);
+ return PrimExpr(node);
}
const char* CallNode::vectorizable_intrinsics[] = {
return false;
}
-Expr CallNode::make(DataType dtype,
+PrimExpr CallNode::make(DataType dtype,
std::string name,
- Array<Expr> args,
+ Array<PrimExpr> args,
CallType call_type,
FunctionRef func,
int value_index) {
node->call_type = call_type;
node->func = std::move(func);
node->value_index = value_index;
- return Expr(node);
+ return PrimExpr(node);
}
-Expr ShuffleNode::make(Array<Expr> vectors,
- Array<Expr> indices) {
+PrimExpr ShuffleNode::make(Array<PrimExpr> vectors,
+ Array<PrimExpr> indices) {
CHECK_NE(vectors.size(), 0U);
CHECK_NE(indices.size(), 0U);
DataType base_type = vectors[0].dtype().element_of();
int total_lanes = 0;
- for (Expr val : vectors) {
+ for (PrimExpr val : vectors) {
CHECK(val.dtype().element_of() == base_type);
total_lanes += val.dtype().lanes();
}
node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
node->vectors = std::move(vectors);
node->indices = std::move(indices);
- return Expr(node);
+ return PrimExpr(node);
}
-Expr ShuffleNode::make_concat(Array<Expr> vectors) {
+PrimExpr ShuffleNode::make_concat(Array<PrimExpr> vectors) {
CHECK_NE(vectors.size(), 0);
if (vectors.size() == 1) {
return vectors[0];
}
- Array<Expr> indices;
+ Array<PrimExpr> indices;
int index = 0;
- for (const Expr& e : vectors) {
+ for (const PrimExpr& e : vectors) {
for (int i = 0; i < e.dtype().lanes(); ++i) {
indices.push_back(IntImmNode::make(DataType::Int(32), index++));
}
return make(vectors, indices);
}
-Expr ShuffleNode::make_extract_element(Expr vector, int index) {
+PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) {
return make({vector}, {Integer(index)});
}
CommReducer CommReducerNode::make(Array<Var> lhs,
Array<Var> rhs,
- Array<Expr> result,
- Array<Expr> identity_element) {
+ Array<PrimExpr> result,
+ Array<PrimExpr> identity_element) {
auto node = make_object<CommReducerNode>();
node->lhs = lhs;
node->rhs = rhs;
return CommReducer(node);
}
-Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
+Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b) const {
CHECK_EQ(a.size(), b.size());
CHECK_EQ(lhs.size(), a.size());
CHECK_EQ(rhs.size(), b.size());
- Map<Var, Expr> value_map;
+ Map<Var, PrimExpr> value_map;
for (size_t i = 0; i < a.size(); ++i) {
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
- return UpdateArray(result, [&value_map] (const Expr& e) {
+ return UpdateArray(result, [&value_map] (const PrimExpr& e) {
return Substitute(e, value_map);
});
}
-Expr ReduceNode::make(CommReducer combiner, Array<Expr> source,
- Array<IterVar> axis, Expr condition, int value_index) {
+PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
+ Array<IterVar> axis, PrimExpr condition, int value_index) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce)
<< "Can only take axis created by reduce_axis";
n->axis = std::move(axis);
n->condition = condition;
n->value_index = value_index;
- return Expr(n);
+ return PrimExpr(n);
}
-Expr AnyNode::make() {
+PrimExpr AnyNode::make() {
auto n = make_object<AnyNode>();
- return Expr(n);
+ return PrimExpr(n);
}
-Stmt LetStmtNode::make(Var var, Expr value, Stmt body) {
+Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.dtype(), var.dtype());
Stmt AttrStmtNode::make(ObjectRef node,
std::string attr_key,
- Expr value,
+ PrimExpr value,
Stmt body) {
auto n = make_object<AttrStmtNode>();
n->node = node;
return Stmt(n);
}
-Stmt AssertStmtNode::make(Expr condition, Expr message, Stmt body) {
+Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
message.as<StringImmNode>())
}
Stmt ForNode::make(Var loop_var,
- Expr min,
- Expr extent,
+ PrimExpr min,
+ PrimExpr extent,
ForType for_type,
DeviceAPI device_api,
Stmt body) {
return Stmt(node);
}
-Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
+Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
CHECK(value.defined());
CHECK(index.defined());
CHECK(predicate.defined());
return Stmt(node);
}
-Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
+Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
CHECK(value_index >=0 && value_index < func->num_outputs())
<< "value index output function return value bound";
CHECK(value.defined()) << "Provide of undefined value\n";
Stmt AllocateNode::make(Var buffer_var,
DataType dtype,
- Array<Expr> extents,
- Expr condition,
+ Array<PrimExpr> extents,
+ PrimExpr condition,
Stmt body,
- Expr new_expr,
+ PrimExpr new_expr,
std::string free_function) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
return Stmt(node);
}
-int32_t AllocateNode::constant_allocation_size(const Array<Expr>& extents) {
+int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
int value_index,
DataType dtype,
Region bounds,
- Expr condition,
+ PrimExpr condition,
Stmt body) {
for (size_t i = 0; i < bounds.size(); ++i) {
CHECK(bounds[i]->min.defined());
data_ = std::move(node);
}
-Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) {
+Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
CHECK(condition.defined());
CHECK(then_case.defined());
// else_case may be null.
return Stmt(node);
}
-Stmt EvaluateNode::make(Expr value) {
+Stmt EvaluateNode::make(PrimExpr value) {
CHECK(value.defined());
ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
namespace tvm {
// Tensor
-Expr Tensor::operator()(Array<Var> indices) const {
- Array<Expr> arr(indices.begin(), indices.end());
+PrimExpr Tensor::operator()(Array<Var> indices) const {
+ Array<PrimExpr> arr(indices.begin(), indices.end());
return operator()(arr);
}
-Expr Tensor::operator()(Array<Expr> indices) const {
+PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
using ir::CallNode;
if (ndim() != 0) {
CHECK_EQ(ndim(), indices.size())
return Tensor(node);
}
-Tensor TensorNode::make(Array<Expr> shape,
+Tensor TensorNode::make(Array<PrimExpr> shape,
DataType dtype,
Operation op,
int value_index) {
Array<Tensor> tensors,
Array<Region> regions,
Array<IterVar> reduce_axis,
- Array<Expr> scalar_inputs) {
+ Array<PrimExpr> scalar_inputs) {
auto n = make_object<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
n->tensors = std::move(tensors);
return body[idx].dtype();
}
-Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
+Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
CHECK_LT(idx, num_outputs());
// for now, all outputs of a BaseComputeOp have the same shape
- Array<Expr> shape;
+ Array<PrimExpr> shape;
for (const auto& ivar : this->axis) {
const Range& r = ivar->dom;
shape.push_back(r->extent);
return shape;
}
-Tensor compute(Array<Expr> shape,
+Tensor compute(Array<PrimExpr> shape,
FCompute fcompute,
std::string name,
std::string tag,
name, tag, attrs, axis, {fcompute(args)}).output(0);
}
-Array<Tensor> compute(Array<Expr> shape,
+Array<Tensor> compute(Array<PrimExpr> shape,
FBatchCompute fcompute,
std::string name,
std::string tag,
std::string tag,
Map<std::string, ObjectRef> attrs,
Array<IterVar> axis,
- Array<Expr> body) {
+ Array<PrimExpr> body) {
if (!attrs.defined()) {
attrs = Map<std::string, ObjectRef>();
}
const std::unordered_map<Tensor, Tensor>& rmap) const {
CHECK_EQ(self.operator->(), this);
VerifyComputeOp(this);
- Array<Expr> arr;
+ Array<PrimExpr> arr;
if (this->body[0]->IsInstance<ir::ReduceNode>()) {
// Specially handle reduce so the replaced op
// still share all the components
- Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
+ PrimExpr new_reduce = op::ReplaceTensor(this->body[0], rmap);
if (!new_reduce.same_as(this->body[0])) {
const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
for (size_t k = 0; k < this->body.size(); ++k) {
auto n = make_object<ir::ReduceNode>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
- arr.push_back(Expr(n));
+ arr.push_back(PrimExpr(n));
}
} else {
arr = this->body;
}
} else {
- arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
+ arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
return op::ReplaceTensor(e, rmap);
});
}
IntSet arg_intset = EvalSet(call->args[i], dom_map);
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
- Expr shape_i_min_value = make_zero(t->shape[i].dtype());
- Expr shape_i_max_value = t->shape[i] - 1;
- Expr min_value = arg_interval->min_value;
- Expr max_value = arg_interval->max_value;
+ PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
+ PrimExpr shape_i_max_value = t->shape[i] - 1;
+ PrimExpr min_value = arg_interval->min_value;
+ PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
if (it != stage->iter_var_attrs.end()) {
IterVarAttr attr = (*it).second;
if (attr->dim_align_factor != 0) {
- Array<Expr> tuple = {static_cast<int>(i),
+ Array<PrimExpr> tuple = {static_cast<int>(i),
attr->dim_align_factor,
attr->dim_align_offset};
realize = ir::AttrStmtNode::make(
const Array<Tensor>& tensors,
Stmt* init,
Stmt* provide) {
- Array<Expr> args;
+ Array<PrimExpr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
CHECK(reduce);
const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
CHECK(combiner);
- Array<Expr> lhs;
+ Array<PrimExpr> lhs;
for (size_t i = 0; i < size; ++i) {
lhs.push_back(tensors[i](args));
}
- Array<Expr> init_value = combiner->identity_element;
- Array<Expr> update_value = (*combiner)(lhs, reduce->source);
+ Array<PrimExpr> init_value = combiner->identity_element;
+ Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(ProvideNode::make(
// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
const Tensor& t) {
- Array<Expr> args;
+ Array<PrimExpr> args;
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
/// Interface to perform compute verification
void Run() {
- for (const Expr e : compute_->body) {
+ for (const PrimExpr e : compute_->body) {
// Check for consistency of top level reductions
const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
CHECK((reduce && reduce_) || (!reduce && !reduce_))
protected:
/// Visitor implementation
//@{
- void VisitExpr(const Expr& n) final {
+ void VisitExpr(const PrimExpr& n) final {
++level_;
ExprVisitor::VisitExpr(n);
--level_;
const ComputeLoopNest& n,
Stmt body,
Stmt update) {
- Array<Expr> conds;
+ Array<PrimExpr> conds;
std::unordered_set<const VarNode*> banned;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
IterVar iv = stage->leaf_iter_vars[i];
banned.insert(iv->var.get());
}
}
- for (const Expr& pred : n.main_predicates) {
+ for (const PrimExpr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition "
<< pred << " has a conflict with the reset condition";
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// The common number of loops between init and main
size_t num_common_loop;
// predicates for the initialize loop
- std::vector<Expr> init_predicates;
+ std::vector<PrimExpr> init_predicates;
// Initialization nest involved.
std::vector<std::vector<Stmt> > init_nest;
// Value map for the init code
- std::unordered_map<IterVar, Expr> init_vmap;
+ std::unordered_map<IterVar, PrimExpr> init_vmap;
// Predicates for the main update loop
- std::vector<Expr> main_predicates;
+ std::vector<PrimExpr> main_predicates;
// The general loop nest
std::vector<std::vector<Stmt> > main_nest;
// Value map for the IterVar.
- std::unordered_map<IterVar, Expr> main_vmap;
+ std::unordered_map<IterVar, PrimExpr> main_vmap;
/*!
* \brief constructor to build ComputeOpNest
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) {
- Array<Expr> args;
+ Array<PrimExpr> args;
for (IterVar iv : self->axis) {
args.push_back(iv->var);
}
- std::unordered_map<IterVar, Expr> value_map;
+ std::unordered_map<IterVar, PrimExpr> value_map;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
auto conds = schedule::MakeBoundCheck(
CHECK(reduce);
reduces[i] = reduce;
}
- Expr cond = reduces[0]->condition;
- for (Expr v : conds) {
+ PrimExpr cond = reduces[0]->condition;
+ for (PrimExpr v : conds) {
cond = cond && v;
}
- Array<Expr> freduce_args;
+ Array<PrimExpr> freduce_args;
freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
for (size_t i = 0; i < size; ++i) {
freduce_args.push_back(reduces[0]->source[i]);
}
}
// Checks for the thread.
- std::vector<Expr> thread_head_check;
+ std::vector<PrimExpr> thread_head_check;
if (stage->store_predicate.defined()) {
thread_head_check.emplace_back(stage->store_predicate);
}
return output_placeholders[i]->dtype;
}
-Array<Expr> ExternOpNode::output_shape(size_t i) const {
+Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
return output_placeholders[i]->shape;
}
Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<ObjectRef> bind_spec;
- Array<Expr> tuple;
+ Array<PrimExpr> tuple;
bind_spec.push_back(buffer);
bind_spec.push_back(tensor);
for (size_t k = 0; k < buffer->shape.size(); ++k) {
return outputs[i]->dtype;
}
-Array<Expr> HybridOpNode::output_shape(size_t i) const {
+Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
return outputs[i]->shape;
}
Stmt ApplyLoopShapes(const Stage &stage,
const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
class LoopSpliter : public StmtExprMutator {
- Expr factor;
+ PrimExpr factor;
const VarNode *parent;
IterVar inner, outer;
Stmt VisitStmt_(const ForNode *op) final {
if (op->loop_var.get() == parent) {
- std::unordered_map<const VarNode *, Expr> rmap;
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
rmap[op->loop_var.get()] = inner + outer * factor;
Stmt ret = ir::Substitute(op->body, rmap);
- Expr cond = likely(outer * factor < (op->extent - inner));
+ PrimExpr cond = likely(outer * factor < (op->extent - inner));
ret = IfThenElseNode::make(cond, ret);
- ret = ForNode::make(inner->var, Expr(0), inner->dom->extent,
+ ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
IterVarTypeToForType(inner->iter_type), op->device_api, ret);
- ret = ForNode::make(outer->var, Expr(0), outer->dom->extent,
+ ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
IterVarTypeToForType(outer->iter_type), op->device_api, ret);
splitted = true;
return ret;
const VarNode *inner;
const VarNode *outer;
bool under_outer;
- Expr extent;
+ PrimExpr extent;
public:
bool fused;
Stmt VisitStmt_(const ForNode* op) final {
if (op->loop_var.get() == inner) {
CHECK(under_outer);
- std::unordered_map<const VarNode *, Expr> rmap;
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
rmap[op->loop_var.get()] = indexmod(parent, op->extent);
extent = op->extent;
fused = true;
} else if (op->loop_var.get() == outer) {
under_outer = true;
Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const VarNode *, Expr> rmap;
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = ir::Substitute(body, rmap);
under_outer = false;
- return ForNode::make(parent->var, Expr(0), extent * op->extent,
+ return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
op->for_type, op->device_api, body);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
- std::unordered_map<const VarNode *, Expr> rmap;
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
body = ir::Substitute(body, rmap);
extent = extent * op->extent;
CHECK(Equal(iter_var->dom->extent, op->extent))
<< "Thread extent and loop extent mismatch!\n";
}
- std::unordered_map<const VarNode *, Expr> rmap;
+ std::unordered_map<const VarNode *, PrimExpr> rmap;
rmap[op->loop_var.get()] = iter_var;
Stmt body = ir::Substitute(op->body, rmap);
return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
size_t begin_iter_pos,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
- std::unordered_map<IterVar, Expr>* p_value_map,
+ std::unordered_map<IterVar, PrimExpr>* p_value_map,
bool debug_keep_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars;
Stmt no_op = EvaluateNode::make(0);
// create the loop nest
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
- std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
+ std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
- Expr pvalue = it_attr->pragma_values[k];
+ PrimExpr pvalue = it_attr->pragma_values[k];
if (!pvalue.defined()) {
pvalue = make_const(DataType::Int(32), 1);
}
nest[i + 1].emplace_back(
ForNode::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
- Expr new_value = dom->min + idx;
+ PrimExpr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(
LetStmtNode::make(var, new_value, no_op));
return nest;
}
-std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
Stmt no_op = EvaluateNode::make(0);
std::vector<Stmt> nest;
- for (const Expr& cond : predicates) {
+ for (const PrimExpr& cond : predicates) {
nest.emplace_back(IfThenElseNode::make(cond, no_op));
}
return nest;
explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
: vmap_(vmap) {}
- Expr VisitExpr_(const ir::CallNode* op) final {
+ PrimExpr VisitExpr_(const ir::CallNode* op) final {
if (op->call_type == ir::CallNode::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
- Expr ret = ir::CallNode::make(
+ PrimExpr ret = ir::CallNode::make(
op->dtype, it->second->op->name, op->args,
op->call_type, it->second->op, it->second->value_index);
found = true;
Stmt ret = repl(stmt);
return repl.found ? ret : stmt;
}
-Expr ReplaceTensor(Expr expr,
+PrimExpr ReplaceTensor(PrimExpr expr,
const std::unordered_map<Tensor, Tensor>& replace) {
TensorReplacer repl(replace);
- Expr ret = repl(expr);
+ PrimExpr ret = repl(expr);
return repl.found ? ret : expr;
}
Stmt Substitute(Stmt s,
- const std::unordered_map<IterVar, Expr>& value_map) {
- std::unordered_map<const VarNode*, Expr> init;
+ const std::unordered_map<IterVar, PrimExpr>& value_map) {
+ std::unordered_map<const VarNode*, PrimExpr> init;
for (const auto& kv : value_map) {
init[kv.first->var.get()] = kv.second;
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
size_t begin_iter_pos,
bool new_loop_var,
const std::unordered_set<IterVar>& skip_iter,
- std::unordered_map<IterVar, Expr>* p_value_map,
+ std::unordered_map<IterVar, PrimExpr>* p_value_map,
bool debug_keep_trivial_loop);
/*!
* \param predicates The predicates to be checked.
* \return List of If nest that checks the predicates.
*/
-std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
-Expr ReplaceTensor(Expr expr,
+PrimExpr ReplaceTensor(PrimExpr expr,
const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \return Substituted result.
*/
Stmt Substitute(Stmt stmt,
- const std::unordered_map<IterVar, Expr>& value_map);
+ const std::unordered_map<IterVar, PrimExpr>& value_map);
/*!
* \brief Converts Halide ForType to its corresponding IterVarType
return dtype;
}
-Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
+Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
CHECK_EQ(i, 0U);
return shape;
}
Operation PlaceholderOpNode::make(std::string name,
- Array<Expr> shape,
+ Array<PrimExpr> shape,
DataType dtype) {
auto n = make_object<PlaceholderOpNode>();
n->name = name;
return Operation(n);
}
-Tensor placeholder(Array<Expr> shape, DataType dtype, std::string name) {
+Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
});
TVM_REGISTER_NODE_TYPE(ScanOpNode);
-inline bool prove_equal(Expr lhs, Expr rhs) {
+inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
return is_zero(ir::Simplify(lhs - rhs));
}
return update[i]->dtype;
}
-Array<Expr> ScanOpNode::output_shape(size_t i) const {
+Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
CHECK_LT(i, state_placeholder.size());
return state_placeholder[i]->shape;
}
Range r = arith::Union(time_dom).cover_range(sdom);
(*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
- Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
+ Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
begin_scan = i + 1;
}
}
- std::unordered_map<IterVar, Expr> vmap;
+ std::unordered_map<IterVar, PrimExpr> vmap;
std::unordered_set<IterVar> empty;
auto nest = op::MakeLoopNest(
stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
- Array<Expr> scalar_inputs) {
+ Array<PrimExpr> scalar_inputs) {
auto n = make_object<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
Buffer buffer = this->intrin->buffers[i];
Array<ObjectRef> bind_spec{buffer, tensor};
- Array<Expr> tuple;
+ Array<PrimExpr> tuple;
for (size_t i = 0; i < region.size(); ++i) {
tuple.push_back(region[i]->min);
tuple.push_back(region[i]->extent);
Buffer buffer = this->intrin->buffers[num_inputs + i];
Array<ObjectRef> bind_spec{buffer, tensor};
- Array<Expr> tuple;
+ Array<PrimExpr> tuple;
for (size_t i = 0; i < this->axis.size(); ++i) {
auto ivar = this->axis[i];
if (i < static_cast<size_t>(this->schedulable_ndim)) {
}
// Check variable remap
- std::unordered_map<const VarNode*, Expr> vmap;
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
ir::ArgBinder binder(&vmap);
// Map the expressions passed in the call to the TensorIntrin, to the placeholder
// variables
- Array<Expr> user_expr = this->scalar_inputs;
+ Array<PrimExpr> user_expr = this->scalar_inputs;
Array<Var> scalar_params = this->intrin->scalar_params;
- Array<Expr> sp_expr;
+ Array<PrimExpr> sp_expr;
for (auto sp : scalar_params) {
- Expr esp = sp;
+ PrimExpr esp = sp;
sp_expr.push_back(esp);
}
CHECK_EQ(sp_expr.size(), user_expr.size());
}
}
}
- for (const Expr& pred : n.main_predicates) {
+ for (const PrimExpr& pred : n.main_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize failed, split condition "
<< pred << " relies on var defined inside tensorize scope";
}
}
- for (const Expr& pred : n.init_predicates) {
+ for (const PrimExpr& pred : n.init_predicates) {
if (ir::ExprUseVar(pred, banned)) {
LOG(FATAL) << "Tensorize failed, split condition "
<< pred << " relies on var defined inside tensorize scope";
// Remap the tensor placeholder, index and inline things.
class TensorIntrinMatcher final : public StmtExprMutator {
public:
- Expr VisitExpr_(const CallNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const CallNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
if (op->call_type == CallNode::Halide) {
Tensor t = Downcast<Operation>(op->func).output(op->value_index);
if (it != in_remap_.end()) {
const InputEntry& e = it->second;
CHECK_EQ(op->args.size(), e.region.size());
- Array<Expr> args;
+ Array<PrimExpr> args;
for (size_t i = e.start; i < e.region.size(); ++i) {
args.push_back(op->args[i] - e.region[i]->min);
}
return expr;
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
- Expr VisitExpr_(const ReduceNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const ReduceNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<ReduceNode>();
Array<IterVar> axis;
for (size_t i = 0; i < op->axis.size(); ++i) {
// input data remap
std::unordered_map<Tensor, InputEntry> in_remap_;
// variable remap.
- std::unordered_map<const VarNode*, Expr> var_remap_;
+ std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// IterVar remap.
std::unordered_map<IterVar, IterVar> axis_remap_;
};
// Try to match tensor dataflow of the stage with the intrinsic
-Array<Expr> MatchTensorizeBody(
+Array<PrimExpr> MatchTensorizeBody(
const ComputeOpNode* self,
const Stage& stage,
const std::unordered_map<IterVar, Range>& dom_map,
Map<Var, Range>* compute_intrin_iter_space) {
TensorIntrinMatcher matcher;
matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
- Array<Expr> ret;
- for (Expr expr : self->body) {
+ Array<PrimExpr> ret;
+ for (PrimExpr expr : self->body) {
ret.push_back(matcher(expr));
}
return ret;
const std::unordered_map<Tensor, Array<Range> >& in_region,
const TensorIntrin& intrin) {
Map<Var, Range> compute_intrin_iter_space;
- Array<Expr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
+ Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
&compute_intrin_iter_space);
const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(body.size(), intrin_compute->body.size())
<< "Tensorize failed: body size mismatch";
for (size_t i = 0; i < body.size(); ++i) {
- Expr lhs = Simplify(body[i], compute_intrin_iter_space);
+ PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
- Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
+ PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
if (lhs.dtype() != rhs.dtype()) {
LOG(FATAL)
auto it = in_region.find(tensor);
CHECK(it != in_region.end());
const Array<Range>& region = it->second;
- Array<Expr> tuple;
+ Array<PrimExpr> tuple;
for (const Range r : region) {
tuple.push_back(r->min);
tuple.push_back(r->extent);
CHECK(intrin_compute) << "Only support compute intrinsic for now";
CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
CHECK_EQ(intrin_compute->body.size(), self->body.size());
- Array<Expr> tuple;
+ Array<PrimExpr> tuple;
for (IterVar iv : self->axis) {
auto it = out_dom.find(iv);
CHECK(it != out_dom.end());
tuple, CallNode::Intrinsic), nop));
}
// Check variable remap
- std::unordered_map<const VarNode*, Expr> vmap;
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
ir::ArgBinder binder(&vmap);
CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
<< "Tensorization fail: reduction axis size do not match";
namespace tvm {
namespace ir {
-void BinderAddAssert(Expr cond,
+void BinderAddAssert(PrimExpr cond,
const std::string& arg_name,
std::vector<Stmt>* asserts) {
- Expr scond = Simplify(cond);
+ PrimExpr scond = Simplify(cond);
if (is_zero(scond)) {
LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name;
}
}
-bool ArgBinder::Bind_(const Expr& arg,
- const Expr& value,
+bool ArgBinder::Bind_(const PrimExpr& arg,
+ const PrimExpr& value,
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.dtype(), value.dtype());
return false;
}
-void ArgBinder::Bind(const Expr& arg,
- const Expr& value,
+void ArgBinder::Bind(const PrimExpr& arg,
+ const PrimExpr& value,
const std::string& arg_name,
bool with_let) {
Bind_(arg, value, arg_name, with_let);
}
-void ArgBinder::BindArray(const Array<Expr>& arg,
- const Array<Expr>& value,
+void ArgBinder::BindArray(const Array<PrimExpr>& arg,
+ const Array<PrimExpr>& value,
const std::string& arg_name) {
CHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
- Expr offset = value->elem_offset;
- Expr factor = make_const(offset.dtype(), arg->offset_factor);
- Expr zero = make_zero(offset.dtype());
+ PrimExpr offset = value->elem_offset;
+ PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
+ PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
}
-inline Expr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
+inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
void ArgBinder::BindDLTensor(const Buffer& buffer,
- const Expr& device_type,
- const Expr& device_id,
+ const PrimExpr& device_type,
+ const PrimExpr& device_id,
const Var& handle,
const std::string& arg_name) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = EvaluateNode::make(0);
// dimension checks
- Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
- Expr a_ndim = make_const(tvm_ndim_type,
+ PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
+ PrimExpr a_ndim = make_const(tvm_ndim_type,
static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
- Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
+ PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
init_nest_.emplace_back(LetStmtNode::make(
v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
nop));
- Expr is_null = CallNode::make(
+ PrimExpr is_null = CallNode::make(
DataType::Bool(1), intrinsic::tvm_handle_is_null,
{v_strides}, CallNode::PureIntrinsic);
if (buffer->strides.size() == 0) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
- Expr expect_stride = make_const(stype, 1);
- Array<Expr> conds;
+ PrimExpr expect_stride = make_const(stype, 1);
+ Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
- Expr svalue = cast(
+ PrimExpr svalue = cast(
stype,
LoadNode::make(tvm_shape_type, v_strides,
IntImmNode::make(DataType::Int(32), k), const_true(1)));
<< " expected to be compact array";
if (conds.size() != 0) {
Stmt check =
- AssertStmtNode::make(arith::ComputeReduce<ir::AndNode>(conds, Expr()),
+ AssertStmtNode::make(arith::ComputeReduce<ir::AndNode>(conds, PrimExpr()),
stride_err_msg.str(), EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
DataType stype = buffer->DefaultIndexType();
- Expr stride = make_const(stype, 1);
+ PrimExpr stride = make_const(stype, 1);
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
- Expr value = cast(buffer->shape[k].dtype(),
+ PrimExpr value = cast(buffer->shape[k].dtype(),
LoadNode::make(tvm_shape_type, v_strides,
IntImmNode::make(DataType::Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
make_const(DataType::UInt(64), data_bytes))),
arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
- Expr offset = buffer->elem_offset;
- Expr factor = make_const(offset.dtype(), buffer->offset_factor);
- Expr zero = make_zero(offset.dtype());
+ PrimExpr offset = buffer->elem_offset;
+ PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
+ PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
}
}
* ArgBinder will update this def_map when adding new definitions.
*/
explicit ArgBinder(
- std::unordered_map<const VarNode*, Expr>* def_map)
+ std::unordered_map<const VarNode*, PrimExpr>* def_map)
: def_map_(def_map) {
}
/*!
* \param arg_name argument name.
* \param with_let Whether add lets during bind
*/
- void Bind(const Expr& arg,
- const Expr& value,
+ void Bind(const PrimExpr& arg,
+ const PrimExpr& value,
const std::string& arg_name,
bool with_let = false);
/*!
* \param value The target expression value
* \param arg_name argument name.
*/
- void BindArray(const Array<Expr>& arg,
- const Array<Expr>& value,
+ void BindArray(const Array<PrimExpr>& arg,
+ const Array<PrimExpr>& value,
const std::string& arg_name);
/*!
* \brief Bind symbolic buffer to another symbolic buffer
* \param arg_name argument name.
*/
void BindDLTensor(const Buffer& buffer,
- const Expr& device_type,
- const Expr& device_id,
+ const PrimExpr& device_type,
+ const PrimExpr& device_id,
const Var& handle,
const std::string& arg_name);
return init_nest_;
}
/*! \return Handle data type of the data */
- const Map<Var, Expr>& def_handle_dtype() const {
+ const Map<Var, PrimExpr>& def_handle_dtype() const {
return def_handle_dtype_;
}
private:
// Internal bind function
- bool Bind_(const Expr& arg,
- const Expr& value,
+ bool Bind_(const PrimExpr& arg,
+ const PrimExpr& value,
const std::string& arg_name,
bool with_lets);
/*! \brief The definition map, can be uses to substitute */
- std::unordered_map<const VarNode*, Expr>* def_map_;
+ std::unordered_map<const VarNode*, PrimExpr>* def_map_;
/*! \brief defs generated in the current binder */
std::vector<Var> defs_;
/*! \brief Initialize nest */
std::vector<Stmt> init_nest_;
/*! \brief handle data type in the defintiions */
- Map<Var, Expr> def_handle_dtype_;
+ Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
};
StmtVisitor::VisitStmt_(op);
}
// Hashtable which maps buffer_var to shape.
- std::unordered_map<const VarNode *, Expr> mem_to_shape;
+ std::unordered_map<const VarNode *, PrimExpr> mem_to_shape;
};
class BoundChecker : public StmtExprMutator {
public:
explicit BoundChecker(
- const std::unordered_map<const VarNode *, Expr> &mem_to_shape)
+ const std::unordered_map<const VarNode *, PrimExpr> &mem_to_shape)
: mem_to_shape_(mem_to_shape) {}
Stmt VisitStmt_(const AllocateNode* op) final {
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
unsafe_rewritten_ = true;
}
}
// The collector should has at least one item.
if (store_scope_bound_collector_.size()) {
- Expr condition = MakeCondition();
+ PrimExpr condition = MakeCondition();
if (!condition.as<StringImmNode>()) {
Stmt nop = EvaluateNode::make(1);
Stmt then_case =
return GetRef<Stmt>(op);
}
- Expr VisitExpr_(const LoadNode* op) final {
+ PrimExpr VisitExpr_(const LoadNode* op) final {
if (CanInstrument(op->index, op->buffer_var)) {
Collect(op->index, op->buffer_var);
}
}
private:
- bool UpdateIsNeeded(const VarExpr& buffer_var) const {
+ bool UpdateIsNeeded(const Var& buffer_var) const {
return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
}
- void Update(const VarExpr& buffer_var,
- const Array<Expr>& new_shape,
+ void Update(const Var& buffer_var,
+ const Array<PrimExpr>& new_shape,
const DataType& type) {
// Sanity check at first.
if (!new_shape.size()) {
}
// Scalarize the shape.
- Expr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
+ PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
CastNode::make(DataType::UInt(64), new_shape[0]));
for (size_t i = 1; i < new_shape.size(); ++i) {
// Cast to unsigned to avoid integer overlow at frist.
mem_to_shape_[buffer_var.get()] = shape;
}
- bool IndexIsValid(const Expr& index) const {
+ bool IndexIsValid(const PrimExpr& index) const {
if (!index.defined()) {
return false;
}
return true;
}
- bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const {
+ bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const {
return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
IndexIsValid(index) && !unsafe_rewritten_;
}
- void Collect(Expr index, VarExpr buffer_var) {
+ void Collect(PrimExpr index, Var buffer_var) {
store_scope_bound_collector_.push_back(
std::make_pair(index, mem_to_shape_[buffer_var.get()]));
}
- Expr MakeCondition() {
- Expr condition;
+ PrimExpr MakeCondition() {
+ PrimExpr condition;
for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
- std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i];
- Expr index = buffer_to_mem.first;
- Expr upper_bound = buffer_to_mem.second;
+ std::pair<PrimExpr, PrimExpr> buffer_to_mem = store_scope_bound_collector_[i];
+ PrimExpr index = buffer_to_mem.first;
+ PrimExpr upper_bound = buffer_to_mem.second;
if (const RampNode *ramp_index = index.as<RampNode>()) {
// In case index is base + stride * i.
upper_bound = CastNode::make(DataType::Int(64), upper_bound);
// Looks like a lower bound should always be zero after normalization.
- Expr lower_bound = make_zero(DataType::Int(64));
+ PrimExpr lower_bound = make_zero(DataType::Int(64));
- Expr current_condition =
+ PrimExpr current_condition =
AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound));
condition =
!i ? current_condition : AndNode::make(condition, current_condition);
// Whether we face tvm_if_then_else intrinsic.
bool unsafe_rewritten_{false};
// Pool which collects the pair of index and shape for specific store/load.
- std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_;
+ std::vector<std::pair<PrimExpr, PrimExpr>> store_scope_bound_collector_;
// Error message.
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
- std::unordered_map<const VarNode *, Expr> mem_to_shape_;
+ std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
class ContextCallCombiner final : public StmtExprMutator {
public:
struct CompareExpr {
- bool operator()(const Expr& lhs, const Expr& rhs) const {
+ bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
return Compare(lhs, rhs) < 0;
}
};
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
CHECK_EQ(op->args.size(), 1U);
- Expr ctx = op->args[0];
+ PrimExpr ctx = op->args[0];
auto it = ctx_map_.find(ctx);
if (it != ctx_map_.end()) {
return it->second;
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::coproc_uop_scope) {
// Map of comparison expression to variable
- std::map<Expr, Var, CompareExpr> temp;
+ std::map<PrimExpr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Parallel) {
// Map of comparison expression to variable
- std::map<Expr, Var, CompareExpr> temp;
+ std::map<PrimExpr, Var, CompareExpr> temp;
std::swap(temp, ctx_map_);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
std::swap(temp, ctx_map_);
}
private:
- static Stmt BuildContext(const std::map<Expr, Var, CompareExpr>& cmap,
+ static Stmt BuildContext(const std::map<PrimExpr, Var, CompareExpr>& cmap,
Stmt body) {
for (const auto& kv : cmap) {
body = LetStmtNode::make(kv.second, kv.first, body);
return body;
}
// Map of comparison expression to variable
- std::map<Expr, Var, CompareExpr> ctx_map_;
+ std::map<PrimExpr, Var, CompareExpr> ctx_map_;
};
LoweredFunc CombineContextCall(LoweredFunc f) {
Range r = arith::Union(wset).cover_range(none);
CHECK(r.defined())
<< "Cannot deduce write range of " << wvec[0].buffer;
- Expr min = r->min;
- Expr extent = r->extent;
+ PrimExpr min = r->min;
+ PrimExpr extent = r->extent;
return EvaluateNode::make(CallNode::make(
DataType::Int(32), func,
{wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic));
});
return IRTransform(parent_for_stmt, nullptr, replace_target_for,
- {Expr("For")});
+ {PrimExpr("For")});
}
// Remove IfThenElse node from a For node.
});
then_for = IRTransform(for_stmt, nullptr, replace_then_case,
- {Expr("IfThenElse")});
+ {PrimExpr("IfThenElse")});
if (if_stmt.as<IfThenElseNode>()->else_case) {
else_for = IRTransform(for_stmt, nullptr, replace_else_case,
- {Expr("IfThenElse")});
+ {PrimExpr("IfThenElse")});
}
return std::make_pair(then_for, else_for);
*ret = new_for;
}
});
- return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")});
+ return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
}
Stmt HoistIfThenElse(Stmt stmt) {
std::string shape = std::to_string(info.m) + ", " +
std::to_string(info.n) + ", " +
std::to_string(info.k);
- Expr shape_expr = StringImmNode::make(shape);
+ PrimExpr shape_expr = StringImmNode::make(shape);
Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
if (store == nullptr) return false;
// Expr sel_cond, sel_true_value, sel_false_value;
// match select or if
- PVar<Expr> sel_cond, sel_true_value, sel_false_value;
+ PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value;
bool has_cond =
if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
for (const ForNode* op : loops) {
loop_vars.push_back(op->loop_var);
}
- Array<Expr> store_strides =
+ Array<PrimExpr> store_strides =
arith::DetectLinearEquation(store->index, loop_vars);
- Array<Expr> load_strides =
+ Array<PrimExpr> load_strides =
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
- Array<Expr> dst_shape;
+ Array<PrimExpr> dst_shape;
const size_t loop_var_size = loop_vars.size();
if (loop_var_size == 0) {
dst_shape.push_back(make_const(DataType::Int(32), 1));
dst_shape.push_back(op->extent);
}
}
- Array<Expr> src_shape = dst_shape;
- Array<Expr> pad_before, pad_after;
- Expr pad_value;
- Expr src_elem_offset = load_strides[loop_var_size];
+ Array<PrimExpr> src_shape = dst_shape;
+ Array<PrimExpr> pad_before, pad_after;
+ PrimExpr pad_value;
+ PrimExpr src_elem_offset = load_strides[loop_var_size];
if (has_cond) {
- Array<Expr> clip_bound =
+ Array<PrimExpr> clip_bound =
arith::DetectClipBound(sel_cond.Eval(), loop_vars);
pad_value = sel_false_value.Eval();
if (clip_bound.size() == 0) return false;
CHECK_EQ(src_shape.size(), loop_vars.size());
CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
for (size_t i = 0; i < src_shape.size(); ++i) {
- Expr min_value = clip_bound[2 * i];
- Expr max_value = clip_bound[2 * i + 1];
+ PrimExpr min_value = clip_bound[2 * i];
+ PrimExpr max_value = clip_bound[2 * i + 1];
DataType t = loop_vars[i].dtype();
- Expr svalue = src_shape[i];
+ PrimExpr svalue = src_shape[i];
if (min_value.defined()) {
- Expr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
+ PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
src_elem_offset = src_elem_offset + pbefore * load_strides[i];
svalue = svalue - pbefore;
pad_before.push_back(pbefore);
pad_before.push_back(make_zero(t));
}
if (max_value.defined()) {
- Expr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
+ PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
make_zero(t)));
svalue = svalue - pafter;
pad_after.push_back(pafter);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_var_size + 1);
- Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
- Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
+ Array<PrimExpr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
+ Array<PrimExpr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
if (loop_var_size == 0) {
src_strides.push_back(make_const(DataType::Int(32), 1));
dst_strides.push_back(make_const(DataType::Int(32), 1));
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
it->second.stride = arith::ComputeReduce<MulNode>(
- op->extents, Expr()) * op->dtype.lanes();
+ op->extents, PrimExpr()) * op->dtype.lanes();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
- Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
- for (Expr e : op->extents) {
+ Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
+ for (PrimExpr e : op->extents) {
new_extents.push_back(e);
}
CHECK(it->second.loop != nullptr);
CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
<< "It is better to split with multiple of 2";
CHECK(is_zero(old_loop->min));
- Expr zero = old_loop->min;
- Expr new_ext =
+ PrimExpr zero = old_loop->min;
+ PrimExpr new_ext =
old_loop->extent - make_const(old_loop->loop_var.dtype(), 1);
- Expr factor = make_const(new_ext.dtype(), split_loop_);
- Expr outer_ext = new_ext / factor;
- Expr tail_base = outer_ext * factor;
+ PrimExpr factor = make_const(new_ext.dtype(), split_loop_);
+ PrimExpr outer_ext = new_ext / factor;
+ PrimExpr tail_base = outer_ext * factor;
Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype());
- std::unordered_map<const VarNode*, Expr> vmap;
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
std::vector<Stmt> loop_seq;
for (int32_t i = 0; i < split_loop_; ++i) {
vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
- Expr idx = tail_base + make_const(tail_base.dtype(), i);
+ PrimExpr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
tail_seq.emplace_back(
IfThenElseNode::make(idx < old_loop->extent,
}
}
- Expr VisitExpr_(const LoadNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
}
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
CHECK(!dbuffer_info_.count(op));
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
private:
Stmt MakeProducer(const AttrStmtNode* op) {
- const VarExpr buffer = Downcast<VarExpr>(op->node);
+ const Var buffer = Downcast<Var>(op->node);
CHECK_NE(loop_nest_.size(), 0U)
<< "Double buffer scope must be inside a loop";
auto it = dbuffer_info_.find(buffer.get());
}
StorageEntry& e = it->second;
e.loop = loop_nest_.back();
- Expr zero = make_const(e.loop->loop_var.dtype(), 0);
- Expr one = make_const(e.loop->loop_var.dtype(), 1);
- Expr two = make_const(e.loop->loop_var.dtype(), 2);
- Expr loop_shift = e.loop->loop_var + one;
+ PrimExpr zero = make_const(e.loop->loop_var.dtype(), 0);
+ PrimExpr one = make_const(e.loop->loop_var.dtype(), 1);
+ PrimExpr two = make_const(e.loop->loop_var.dtype(), 2);
+ PrimExpr loop_shift = e.loop->loop_var + one;
e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
e.loop->loop_var.dtype());
e.switch_read_var = indexmod(e.loop->loop_var, two);
in_double_buffer_scope_ = true;
Stmt body = this->VisitStmt(op->body);
in_double_buffer_scope_ = false;
- std::unordered_map<const VarNode*, Expr> vmap;
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
vmap[e.switch_write_var.get()] = zero;
vmap[e.loop->loop_var.get()] = zero;
loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
// Storage entry for those who need double buffering.
struct StorageEntry {
// The size of the buffer
- Expr stride;
+ PrimExpr stride;
// The loop we need
const ForNode* loop{nullptr};
// The switch variable.
- VarExpr switch_write_var;
+ Var switch_write_var;
// The switch variable for reading.
- Expr switch_read_var;
+ PrimExpr switch_read_var;
// The storage scope.
std::string scope;
};
}
private:
- std::vector<VarExpr> loop_nest_;
+ std::vector<Var> loop_nest_;
std::unordered_map<const VarNode *, IntSet> vectorized_;
static const Range none;
};
bool check_write)
: touched_var_(touched), check_write_(check_write) {}
- void VisitExpr(const Expr& n) final {
+ void VisitExpr(const PrimExpr& n) final {
// early stopping
if (expr_touched_ && !check_write_) return;
StmtExprVisitor::VisitExpr(n);
return stmt;
}
// Variable
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
CHECK(!alloc_remap_.count(op))
<< "Buffer address may get rewritten in virtual thread";
if (touched_var_.count(op)) {
visit_touched_var_ = true;
}
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
- Expr RewriteIndex(Expr index, Expr alloc_extent) const {
+ PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
return index + var_ * alloc_extent;
}
// Load
- Expr VisitExpr_(const LoadNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
if (touched_var_.count(op->buffer_var.get())) {
visit_touched_var_ = true;
}
}
// Expression.
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
auto it = alloc_remap_.find(buffer);
if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
visit_touched_var_ = true;
- Expr offset = this->VisitExpr(op->args[2]);
- Expr extent = this->VisitExpr(op->args[3]);
- Expr stride =
+ PrimExpr offset = this->VisitExpr(op->args[2]);
+ PrimExpr extent = this->VisitExpr(op->args[3]);
+ PrimExpr stride =
it->second / make_const(offset.dtype(), dtype.lanes());
offset = stride * var_ + offset;
return CallNode::make(
{op->args[0], op->args[1], offset, extent, op->args[4]},
op->call_type);
} else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
- return allow_share_ ? GetRef<Expr>(op) : var_;
+ return allow_share_ ? GetRef<PrimExpr>(op) : var_;
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
// Attribute
Stmt VisitStmt_(const AttrStmtNode* op) final {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
} else if (!allow_share_ && !vt_loop_injected_ &&
}
// LetStmt
Stmt VisitStmt_(const LetStmtNode* op) final {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
// For
Stmt VisitStmt_(const ForNode* op) final {
CHECK(is_zero(op->min));
- Expr extent = this->VisitExpr(op->extent);
+ PrimExpr extent = this->VisitExpr(op->extent);
if (visit_touched_var_ && !vt_loop_injected_) {
Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
++max_loop_depth_;
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode* op) final {
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
if (op->new_expr.defined() && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
bool changed = false;
- Array<Expr> extents;
+ Array<PrimExpr> extents;
for (size_t i = 0; i < op->extents.size(); i++) {
- Expr new_ext = this->VisitExpr(op->extents[i]);
+ PrimExpr new_ext = this->VisitExpr(op->extents[i]);
if (visit_touched_var_ && !vt_loop_injected_) {
return InjectVTLoop(GetRef<Stmt>(op), true);
}
// always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
- Expr stride = arith::ComputeReduce<MulNode>(
- op->extents, Expr()) * op->dtype.lanes();
- Array<Expr> other;
+ PrimExpr stride = arith::ComputeReduce<MulNode>(
+ op->extents, PrimExpr()) * op->dtype.lanes();
+ Array<PrimExpr> other;
other.push_back(make_const(op->extents[0].dtype(), num_threads_));
- for (Expr e : extents) {
+ for (PrimExpr e : extents) {
other.push_back(e);
}
extents = other;
} else {
// insert a for loop
Var idx(var_->name_hint + ".s", var_->dtype);
- Map<Var, Expr> values{{var_, idx}};
+ Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
return ForNode::make(idx, make_zero(idx.dtype()),
make_const(idx.dtype(), num_threads_),
// Whether allow shareding.
bool allow_share_;
// The allocations that get touched -> extent
- std::unordered_map<const VarNode*, Expr> alloc_remap_;
+ std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
};
// ConvertSSA need to be applied after this pass
class IRInline final : public StmtExprMutator {
public:
- IRInline(FunctionRef f, Array<Var> args, Expr body)
+ IRInline(FunctionRef f, Array<Var> args, PrimExpr body)
: f_(f), args_(args), body_(body) {}
- Expr VisitExpr_(const CallNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const CallNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
if (op->func == f_) {
expr = LetNode::make(args_[i], op->args[i], expr);
}
} else {
- Map<Var, Expr> vmap;
+ Map<Var, PrimExpr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
private:
FunctionRef f_;
Array<Var> args_;
- Expr body_;
+ PrimExpr body_;
};
Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args,
- Expr body) {
+ PrimExpr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
Stmt ret = IRInline(f, args, body)(std::move(stmt));
namespace tvm {
namespace ir {
-using ExprComparator = ExprFunctor<void(const Expr& n, const Expr &other)>;
+using ExprComparator = ExprFunctor<void(const PrimExpr& n, const PrimExpr &other)>;
using StmtComparator = StmtFunctor<void(const Stmt& n, const Stmt &other)>;
#define DEFINE_BIOP_EXPR_CMP_(OP) \
- void VisitExpr_(const OP* op, const Expr& other) final { \
+ void VisitExpr_(const OP* op, const PrimExpr& other) final { \
const OP* rhs = other.as<OP>(); \
- if (CompareExpr(op->a, rhs->a) != 0) return; \
- if (CompareExpr(op->b, rhs->b) != 0) return; \
+ if (CompareExpr(op->a, rhs->a) != 0) return; \
+ if (CompareExpr(op->b, rhs->b) != 0) return; \
}
// Deep comparison to check if two IR graph are equivalent
return order_ == 0;
}
- bool Equal(const Expr& lhs, const Expr& rhs) {
+ bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) {
tie_def_ = true;
VisitExpr(lhs, rhs);
return order_ == 0;
}
- int Compare(const Expr& lhs, const Expr& rhs) {
+ int Compare(const PrimExpr& lhs, const PrimExpr& rhs) {
tie_def_ = false;
VisitExpr(lhs, rhs);
return order_;
}
- void VisitExpr(const Expr& n, const Expr& other) override {
+ void VisitExpr(const PrimExpr& n, const PrimExpr& other) override {
if (order_ != 0) return;
if (n.same_as(other)) return;
if (CompareValue(n->type_index(), other->type_index()) != 0) return;
}
// Exprs
- void VisitExpr_(const VarNode* op, const Expr& other) final {
+ void VisitExpr_(const VarNode* op, const PrimExpr& other) final {
const VarNode* rhs = other.as<VarNode>();
auto it = vmap_.find(op);
if (it != vmap_.end()) op = it->second;
order_ = +1;
}
}
- void VisitExpr_(const LoadNode* op, const Expr& other) final {
+ void VisitExpr_(const LoadNode* op, const PrimExpr& other) final {
const LoadNode* rhs = other.as<LoadNode>();
if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
if (CompareExpr(op->index, rhs->index) != 0) return;
if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
}
- void VisitExpr_(const LetNode* op, const Expr& other) final {
+ void VisitExpr_(const LetNode* op, const PrimExpr& other) final {
const LetNode* rhs = other.as<LetNode>();
if (tie_def_) {
vmap_[op->var.get()] = rhs->var.get();
if (CompareExpr(op->body, rhs->body) != 0) return;
}
- void VisitExpr_(const CallNode* op, const Expr& other) final {
+ void VisitExpr_(const CallNode* op, const PrimExpr& other) final {
const CallNode* rhs = other.as<CallNode>();
if (CompareString(op->name, rhs->name)) return;
if (CompareArray(op->args, rhs->args)) return;
if (CompareValue(op->value_index, rhs->value_index) != 0) return;
}
- void VisitExpr_(const ReduceNode *op, const Expr& other) final {
+ void VisitExpr_(const ReduceNode *op, const PrimExpr& other) final {
const ReduceNode* rhs = other.as<ReduceNode>();
if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
if (CompareArray(op->source, rhs->source) != 0) return;
}
- void VisitExpr_(const IntImmNode *op, const Expr& other) final {
+ void VisitExpr_(const IntImmNode *op, const PrimExpr& other) final {
CompareValue(op->value, other.as<IntImmNode>()->value);
}
- void VisitExpr_(const UIntImmNode *op, const Expr& other) final {
+ void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final {
CompareValue(op->value, other.as<UIntImmNode>()->value);
}
- void VisitExpr_(const FloatImmNode *op, const Expr& other) final {
+ void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final {
CompareValue(op->value, other.as<FloatImmNode>()->value);
}
- void VisitExpr_(const StringImmNode *op, const Expr& other) final {
+ void VisitExpr_(const StringImmNode *op, const PrimExpr& other) final {
CompareString(op->value, other.as<StringImmNode>()->value);
}
- void VisitExpr_(const CastNode *op, const Expr& other) final {
+ void VisitExpr_(const CastNode *op, const PrimExpr& other) final {
CompareExpr(op->value, other.as<CastNode>()->value);
}
- void VisitExpr_(const NotNode *op, const Expr& other) final {
+ void VisitExpr_(const NotNode *op, const PrimExpr& other) final {
CompareExpr(op->a, other.as<NotNode>()->a);
}
- void VisitExpr_(const SelectNode *op, const Expr& other) final {
+ void VisitExpr_(const SelectNode *op, const PrimExpr& other) final {
const SelectNode* rhs = other.as<SelectNode>();
if (CompareExpr(op->condition, rhs->condition) != 0) return;
if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
}
- void VisitExpr_(const RampNode *op, const Expr& other) final {
+ void VisitExpr_(const RampNode *op, const PrimExpr& other) final {
const RampNode* rhs = other.as<RampNode>();
if (CompareExpr(op->base, rhs->base) != 0) return;
if (CompareExpr(op->stride, rhs->stride) != 0) return;
if (CompareValue(op->lanes, rhs->lanes) != 0) return;
}
- void VisitExpr_(const BroadcastNode *op, const Expr& other) final {
+ void VisitExpr_(const BroadcastNode *op, const PrimExpr& other) final {
const BroadcastNode* rhs = other.as<BroadcastNode>();
if (CompareExpr(op->value, rhs->value) != 0) return;
if (CompareValue(op->lanes, rhs->lanes) != 0) return;
}
- void VisitExpr_(const ShuffleNode *op, const Expr& other) final {
+ void VisitExpr_(const ShuffleNode *op, const PrimExpr& other) final {
const ShuffleNode* rhs = other.as<ShuffleNode>();
if (CompareArray(op->vectors, rhs->vectors) != 0) return;
if (CompareArray(op->indices, rhs->indices) != 0) return;
DEFINE_BIOP_EXPR_CMP_(OrNode)
private:
- int CompareExpr(const Expr& lhs, const Expr& rhs) {
+ int CompareExpr(const PrimExpr& lhs, const PrimExpr& rhs) {
if (order_ != 0) return order_;
if (!lhs.defined() && rhs.defined()) {
order_ = -1; return order_;
return order_;
}
- int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) {
+ int CompareArray(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
if (order_ != 0) return order_;
if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
for (size_t i = 0; i < lhs.size(); ++i) {
return IRDeepCompare().Equal(lhs, rhs);
}
-bool Equal(const Expr& lhs, const Expr& rhs) {
+bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) {
// quick pass for constant expressions.
if (const int64_t *a = as_const_int(lhs)) {
if (const int64_t *b = as_const_int(rhs)) {
return IRDeepCompare().Equal(lhs, rhs);
}
-int Compare(const Expr& lhs, const Expr& rhs) {
+int Compare(const PrimExpr& lhs, const PrimExpr& rhs) {
return IRDeepCompare().Compare(lhs, rhs);
}
public:
explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
- void VisitExpr(const Expr& node) final {
+ void VisitExpr(const PrimExpr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
ExprVisitor::VisitExpr(node);
visitor(Downcast<Stmt>(node));
} else {
IRApplyVisit visitor(fvisit);
- visitor(Downcast<Expr>(node));
+ visitor(Downcast<PrimExpr>(node));
}
}
return this->BaseVisitStmt(s);
});
}
- Expr VisitExpr(const Expr& expr) final {
- return MutateInternal<Expr>(expr, [this](const Expr& e) {
+ PrimExpr VisitExpr(const PrimExpr& expr) final {
+ return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
return this->BaseVisitExpr(e);
});
}
Stmt BaseVisitStmt(const Stmt& s) {
return StmtMutator::VisitStmt(s);
}
- Expr BaseVisitExpr(const Expr& e) {
+ PrimExpr BaseVisitExpr(const PrimExpr& e) {
return ExprMutator::VisitExpr(e);
}
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
- const Array<Expr>& only_enable) {
+ const Array<PrimExpr>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
- for (Expr s : only_enable) {
+ for (PrimExpr s : only_enable) {
only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
}
void StmtVisitor::VisitStmt_(const AllocateNode* op) {
- VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); });
+ VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitStmt(op->body);
this->VisitExpr(op->condition);
if (op->new_expr.defined()) {
}
void StmtVisitor::VisitStmt_(const ProvideNode* op) {
- VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
+ VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitExpr(op->value);
}
}
void ExprVisitor::VisitExpr_(const CallNode* op) {
- VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
+ VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
#define DEFINE_BINOP_VISIT_(OP) \
this->VisitExpr(r->dom->min);
this->VisitExpr(r->dom->extent);
});
- VisitArray(op->source, [this](const Expr& e) { this->VisitExpr(e); });
+ VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
this->VisitExpr(op->condition);
}
}
void ExprVisitor::VisitExpr_(const ShuffleNode* op) {
- VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); });
- VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); });
+ VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
+ VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void ExprVisitor::VisitExpr_(const BroadcastNode* op) {
class StmtMutator::Internal {
public:
- static Array<Expr> Mutate(StmtMutator* self, const Array<Expr>& arr) {
- auto fmutate = [self](const Expr& e) { return self->VisitExpr(e); };
+ static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
+ auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
return MutateArray(arr, fmutate, self->allow_copy_on_write_);
}
static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
auto fmutate = [self](const Range& r) {
- Expr min = self->VisitExpr(r->min);
- Expr extent = self->VisitExpr(r->extent);
+ PrimExpr min = self->VisitExpr(r->min);
+ PrimExpr extent = self->VisitExpr(r->extent);
if (min.same_as(r->min) && extent.same_as(r->extent)) {
return r;
} else {
};
Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
}
Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
}
Stmt StmtMutator::VisitStmt_(const ForNode* op) {
- Expr min = this->VisitExpr(op->min);
- Expr extent = this->VisitExpr(op->extent);
+ PrimExpr min = this->VisitExpr(op->min);
+ PrimExpr extent = this->VisitExpr(op->extent);
Stmt body = this->VisitStmt(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
}
Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
- Array<Expr> extents = Internal::Mutate(this, op->extents);
+ Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
Stmt body = this->VisitStmt(op->body);
- Expr condition = this->VisitExpr(op->condition);
- Expr new_expr;
+ PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr new_expr;
if (op->new_expr.defined()) {
new_expr = this->VisitExpr(op->new_expr);
}
}
Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
}
Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
- Expr value = this->VisitExpr(op->value);
- Expr index = this->VisitExpr(op->index);
- Expr predicate = this->VisitExpr(op->predicate);
+ PrimExpr value = this->VisitExpr(op->value);
+ PrimExpr index = this->VisitExpr(op->index);
+ PrimExpr predicate = this->VisitExpr(op->predicate);
if (value.same_as(op->value) &&
index.same_as(op->index) &&
predicate.same_as(op->predicate)) {
}
Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
- Array<Expr> args = Internal::Mutate(this, op->args);
- Expr value = this->VisitExpr(op->value);
+ Array<PrimExpr> args = Internal::Mutate(this, op->args);
+ PrimExpr value = this->VisitExpr(op->value);
if (args.same_as(op->args) &&
value.same_as(op->value)) {
return GetRef<Stmt>(op);
Stmt StmtMutator::VisitStmt_(const RealizeNode* op) {
Region bounds = Internal::Mutate(this, op->bounds);
Stmt body = this->VisitStmt(op->body);
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
if (bounds.same_as(op->bounds) &&
body.same_as(op->body) &&
condition.same_as(op->condition)) {
}
Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
- Expr condition = this->VisitExpr(op->condition);
- Expr message = this->VisitExpr(op->message);
+ PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr message = this->VisitExpr(op->message);
Stmt body = this->VisitStmt(op->body);
if (condition.same_as(op->condition) &&
}
Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
}
-Expr ExprMutator::VisitExpr_(const VarNode* op) {
- return GetRef<Expr>(op);
+PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
+ return GetRef<PrimExpr>(op);
}
-Expr ExprMutator::VisitExpr_(const LoadNode* op) {
- Expr index = this->VisitExpr(op->index);
- Expr predicate = this->VisitExpr(op->predicate);
+PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
+ PrimExpr index = this->VisitExpr(op->index);
+ PrimExpr predicate = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return LoadNode::make(op->dtype, op->buffer_var, index, predicate);
}
}
-Expr ExprMutator::VisitExpr_(const LetNode* op) {
- Expr value = this->VisitExpr(op->value);
- Expr body = this->VisitExpr(op->body);
+PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
+ PrimExpr value = this->VisitExpr(op->value);
+ PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
}
-Expr ExprMutator::VisitExpr_(const CallNode* op) {
- auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); };
- Array<Expr> args = MutateArray(op->args, fmutate);
+PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
+ auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+ Array<PrimExpr> args = MutateArray(op->args, fmutate);
if (args.same_as(op->args)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return CallNode::make(op->dtype,
op->name,
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
- Expr ExprMutator::VisitExpr_(const OP *op) { \
- return GetRef<Expr>(op); \
+ PrimExpr ExprMutator::VisitExpr_(const OP *op) { \
+ return GetRef<PrimExpr>(op); \
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
#define DEFINE_BIOP_EXPR_MUTATE_(OP) \
- Expr ExprMutator::VisitExpr_(const OP* op) { \
- Expr a = this->VisitExpr(op->a); \
- Expr b = this->VisitExpr(op->b); \
+ PrimExpr ExprMutator::VisitExpr_(const OP* op) { \
+ PrimExpr a = this->VisitExpr(op->a); \
+ PrimExpr b = this->VisitExpr(op->b); \
if (a.same_as(op->a) && \
b.same_as(op->b)) { \
- return GetRef<Expr>(op); \
+ return GetRef<PrimExpr>(op); \
} else { \
return OP::make(a, b); \
} \
DEFINE_BIOP_EXPR_MUTATE_(AndNode);
DEFINE_BIOP_EXPR_MUTATE_(OrNode);
-Expr ExprMutator::VisitExpr_(const ReduceNode* op) {
+PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
auto fitervar = [this](const IterVar& v) {
Range r = v->dom;
- Expr min = this->VisitExpr(r->min);
- Expr extent = this->VisitExpr(r->extent);
+ PrimExpr min = this->VisitExpr(r->min);
+ PrimExpr extent = this->VisitExpr(r->extent);
if (min.same_as(r->min) &&
extent.same_as(r->extent)) {
return v;
};
Array<IterVar> axis = MutateArray(op->axis, fitervar);
- auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
- Array<Expr> source = MutateArray(op->source, fexpr);
+ auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+ Array<PrimExpr> source = MutateArray(op->source, fexpr);
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
if (axis.same_as(op->axis) &&
source.same_as(op->source) &&
condition.same_as(op->condition)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return ReduceNode::make(
op->combiner, source, axis, condition, op->value_index);
}
}
-Expr ExprMutator::VisitExpr_(const CastNode* op) {
- Expr value = this->VisitExpr(op->value);
+PrimExpr ExprMutator::VisitExpr_(const CastNode* op) {
+ PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return CastNode::make(op->dtype, value);
}
}
-Expr ExprMutator::VisitExpr_(const NotNode* op) {
- Expr a = this->VisitExpr(op->a);
+PrimExpr ExprMutator::VisitExpr_(const NotNode* op) {
+ PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return NotNode::make(a);
}
}
-Expr ExprMutator::VisitExpr_(const SelectNode* op) {
- Expr condition = this->VisitExpr(op->condition);
- Expr true_value = this->VisitExpr(op->true_value);
- Expr false_value = this->VisitExpr(op->false_value);
+PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) {
+ PrimExpr condition = this->VisitExpr(op->condition);
+ PrimExpr true_value = this->VisitExpr(op->true_value);
+ PrimExpr false_value = this->VisitExpr(op->false_value);
if (condition.same_as(op->condition) &&
true_value.same_as(op->true_value) &&
false_value.same_as(op->false_value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return SelectNode::make(condition, true_value, false_value);
}
}
-Expr ExprMutator::VisitExpr_(const RampNode* op) {
- Expr base = this->VisitExpr(op->base);
- Expr stride = this->VisitExpr(op->stride);
+PrimExpr ExprMutator::VisitExpr_(const RampNode* op) {
+ PrimExpr base = this->VisitExpr(op->base);
+ PrimExpr stride = this->VisitExpr(op->stride);
if (base.same_as(op->base) &&
stride.same_as(op->stride)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return RampNode::make(base, stride, op->lanes);
}
}
-Expr ExprMutator::VisitExpr_(const BroadcastNode* op) {
- Expr value = this->VisitExpr(op->value);
+PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
+ PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return BroadcastNode::make(value, op->lanes);
}
}
-Expr ExprMutator::VisitExpr_(const ShuffleNode* op) {
- auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
+PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
+ auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
auto vectors = MutateArray(op->vectors, fexpr);
if (vectors.same_as(op->vectors)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return ShuffleNode::make(vectors, op->indices);
}
* \param kind The data kind.
* \return the get expression.
*/
-inline Expr TVMStructGet(
+inline PrimExpr TVMStructGet(
DataType dtype, Var handle, int index,
intrinsic::TVMStructFieldKind kind) {
- Array<Expr> args ={
+ Array<PrimExpr> args ={
handle,
make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind))};
* \param dtype The data type.
* \param offset the offset index.
*/
-inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
+inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
return CallNode::make(
DataType::Handle(), intrinsic::tvm_address_of,
{LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
* \param dtype The data type.
* \param offset the offset index.
*/
-inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) {
+inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) {
if (dtype.lanes() != 1) {
offset = offset * make_const(offset.dtype(), dtype.lanes());
offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
*/
inline Stmt TVMStructSet(
Var handle, int index,
- intrinsic::TVMStructFieldKind kind, Expr value) {
- Array<Expr> args ={
+ intrinsic::TVMStructFieldKind kind, PrimExpr value) {
+ Array<PrimExpr> args ={
handle,
make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind)),
* \param base The result base.
* \return true if pattern match success and store the base to base.
*/
-inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
+inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
const RampNode* r = index.as<RampNode>();
if (!r) return false;
if (!is_one(r->stride)) return false;
attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = ObjectRef();
- attr_value_ = Expr();
+ attr_value_ = PrimExpr();
return AllocateNode::make(
op->buffer_var, op->dtype,
op->extents, op->condition, body,
Stmt VisitStmt_(const SeqStmtNode* op) final {
// remember the decorations.
std::vector<ObjectRef> attr_node;
- std::vector<Expr> attr_value;
+ std::vector<PrimExpr> attr_value;
auto fmutate = [&](const Stmt& s) {
attr_node_ = ObjectRef();
- attr_value_ = Expr();
+ attr_value_ = PrimExpr();
Stmt ret = this->VisitStmt(s);
attr_node.push_back(attr_node_);
attr_value.push_back(attr_value_);
begin = end;
}
attr_node_ = ObjectRef();
- attr_value_ = Expr();
+ attr_value_ = PrimExpr();
return SeqStmt::Flatten(reorg);
}
}
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
- Expr first_value;
+ PrimExpr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->VisitStmt(op->else_case);
attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = ObjectRef();
- attr_value_ = Expr();
+ attr_value_ = PrimExpr();
}
if (then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
private:
// value comparison that also compares content of int constant
- static bool ValueSame(const Expr& a, const Expr& b) {
+ static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
if (a.same_as(b)) return true;
if (!a.defined() || !b.defined()) return false;
if (a->type_index() != b->type_index()) return false;
std::string attr_key_;
ObjectRef attr_node_;
- Expr attr_value_;
+ PrimExpr attr_value_;
};
Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
// condition cond is proven to have value cond_value (true or false) in interval.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
-bool ExprUseVars(Expr expr, const std::unordered_set<const VarNode*>& vars) {
+bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
bool success = false;
PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
if (const VarNode* v = node.as<VarNode>()) {
// (currently, "likely" conditions) has fixed true or false value
class PartitionFinder : public StmtExprVisitor {
public:
- explicit PartitionFinder(VarExpr current_var,
+ explicit PartitionFinder(Var current_var,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map)
: current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) {
void VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(CallNode::likely)) {
- Expr cond = op->args[0];
+ PrimExpr cond = op->args[0];
if (ExprUseVars(cond,
std::unordered_set<const VarNode*>({current_var_.get()}))) {
// For cond, find out the interval, if exists, in which we can prove that cond is
// cond is true within interval
partitions[{cond.get(), true}] = interval;
}
- Expr inverse_cond = InverseCond(cond);
+ PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval =
DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
Partition partitions;
private:
- Expr InverseCond(const Expr& cond) {
- Expr inverse_cond;
+ PrimExpr InverseCond(const PrimExpr& cond) {
+ PrimExpr inverse_cond;
if (const LTNode* op = cond.as<LTNode>()) {
// a < b -> a >= b
inverse_cond = GENode::make(op->a, op->b);
return inverse_cond;
}
- VarExpr current_var_;
+ Var current_var_;
std::unordered_set<const VarNode*> out_vars_;
std::unordered_map<const VarNode*, IntSet> hint_map_;
std::unordered_map<const VarNode*, IntSet> relax_map_;
explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
: ps_(ps), cond_value_(cond_value) {}
- Expr VisitExpr(const Expr& e) final {
+ PrimExpr VisitExpr(const PrimExpr& e) final {
if (ps_.find(e.get()) != ps_.end()) {
return VisitExpr(cond_value_ ? const_true() : const_false());
}
class ThreadPartitionInserter : public StmtMutator {
public:
explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps,
- Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
+ PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
if (innermost_thread_scope_) {
Stmt simplified_body = ConditionEliminator(ps_)(op->body);
Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body);
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
stmt = AttrStmtNode::make(op->node, op->attr_key, value, body);
}
innermost_thread_scope_ = false;
private:
const std::unordered_set<const Object*>& ps_;
- Expr cond_;
+ PrimExpr cond_;
bool innermost_thread_scope_;
};
}
private:
- Stmt TryPartition(const Object* op, const Stmt& stmt, VarExpr var,
- Expr min, Expr max, Stmt body, bool partition_thread_scope);
+ Stmt TryPartition(const Object* op, const Stmt& stmt, Var var,
+ PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope);
std::pair<IntSet, std::unordered_set<const Object*>>
GetIntervalAndCondset(const Partition &partitions,
const arith::IntervalSet &for_interval,
bool cond_value);
- inline Stmt MakeFor(const Object* op, Expr extent, Stmt body);
+ inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
/* Candidate IRs that may be partitioned potentially */
std::unordered_map<const VarNode*, IntSet> hint_map_;
*/
Stmt LoopPartitioner::TryPartition(const Object* node,
const Stmt& stmt,
- VarExpr var,
- Expr min,
- Expr max,
+ Var var,
+ PrimExpr min,
+ PrimExpr max,
Stmt body,
bool partition_thread_scope) {
using namespace arith;
// Calculating pre-subrange and generating code for it.
// pre-subrange = [min, body_begin)
- Expr body_begin;
+ PrimExpr body_begin;
Stmt pre_stmt;
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
- Expr cond = (body_begin - min >= 0);
+ PrimExpr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the pre doubt loop";
// Calculating post-subrange and generating code for it.
// post-subrange = [post_doubt_begin, max+1)
- Expr post_doubt_begin;
+ PrimExpr post_doubt_begin;
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
- Expr cond = (max - post_doubt_begin + 1 >= 0);
+ PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
if (!analyzer_.CanProve(cond)) {
LOG(WARNING) << "Cannot prove: " << cond
<< ", when generating the post doubt loop";
}
s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
} else {
- Expr cond = const_true();
+ PrimExpr cond = const_true();
if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
s = ThreadPartitionInserter(cond_set, cond)(stmt);
return s;
}
-inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) {
+inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) {
const ForNode *for_node = static_cast<const ForNode*>(node);
CHECK(for_node);
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
class RemoveLikelyTags : public StmtExprMutator {
public:
- Expr VisitExpr_(const CallNode *op) final {
+ PrimExpr VisitExpr_(const CallNode *op) final {
if (op->is_intrinsic(CallNode::likely)) {
CHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
public:
explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
- inline Expr VisitExpr_(const CastNode* op) final {
+ inline PrimExpr VisitExpr_(const CastNode* op) final {
auto type_code = op->dtype.code();
auto src_type_code = op->value.dtype().code();
// If either datatype is a registered custom datatype, we must lower.
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
datatype::Registry::Global()->GetTypeRegistered(src_type_code);
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CastNode>();
if (toBeLowered) {
auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
return expr;
}
- inline Expr VisitExpr_(const FloatImmNode* imm) final {
+ inline PrimExpr VisitExpr_(const FloatImmNode* imm) final {
auto type_code = imm->dtype.code();
- auto e = GetRef<Expr>(imm);
+ auto e = GetRef<PrimExpr>(imm);
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
return stmt;
}
- inline Expr VisitExpr_(const LoadNode* load) final {
+ inline PrimExpr VisitExpr_(const LoadNode* load) final {
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
- Expr expr = StmtExprMutator::VisitExpr_(load);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(load);
load = expr.as<LoadNode>();
if (toBeLowered) {
auto new_load_type = DataType::UInt(load->dtype.bits());
}
#define DEFINE_MUTATE__(OP, NodeName) \
- inline Expr VisitExpr_(const NodeName* op) final { \
+ inline PrimExpr VisitExpr_(const NodeName* op) final { \
auto type_code = op->dtype.code(); \
bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
- Expr expr = StmtExprMutator::VisitExpr_(op); \
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op); \
op = expr.as<NodeName>(); \
if (toBeLowered) { \
auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \
}
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->call_type == CallNode::Intrinsic ||
op->call_type == CallNode::PureIntrinsic) {
- Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
+ PrimExpr r = ApplyPattern(op->name, GetRef<PrimExpr>(op));
if (r.defined()) return r;
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr VisitExpr_(const AddNode* op) final {
+ PrimExpr VisitExpr_(const AddNode* op) final {
if (const MulNode* mb = op->b.as<MulNode>()) {
return MakeFMA(mb->a, mb->b, op->a, op);
} else if (const MulNode* ma = op->a.as<MulNode>()) {
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
- Expr VisitExpr_(const FloorDivNode* op) final {
- auto e = GetRef<Expr>(op);
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr VisitExpr_(const FloorDivNode* op) final {
+ auto e = GetRef<PrimExpr>(op);
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
if (op == nullptr) return ret;
int shift;
return truncdiv(op->a, op->b);
} else {
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
- Expr rdiv = truncdiv(op->a, op->b);
- Expr rmod = truncmod(op->a, op->b);
+ PrimExpr rdiv = truncdiv(op->a, op->b);
+ PrimExpr rmod = truncmod(op->a, op->b);
// condition on b >= 0.
// truncmod(a, b) < 0 will implies ceildiv,
// So we need to correct these cases.
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
- Expr rdiv = truncdiv(op->a, op->b);
- Expr rmod = truncmod(op->a, op->b);
+ PrimExpr rdiv = truncdiv(op->a, op->b);
+ PrimExpr rmod = truncmod(op->a, op->b);
return ir::SelectNode::make(
(op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
rdiv, rdiv - make_const(dtype, 1));
}
}
- Expr VisitExpr_(const FloorModNode* op) final {
- Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+ PrimExpr VisitExpr_(const FloorModNode* op) final {
+ PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorModNode>();
if (op == nullptr) return ret;
// Lower floordiv to native truncdiv.
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
// So we need to correct these cases.
- Expr rmod = truncmod(op->a, op->b);
+ PrimExpr rmod = truncmod(op->a, op->b);
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
// (rmod >> shift) & b
// -> (rmod >= 0 ? 0: -1) & b
} else {
// uncommon case
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
- Expr rmod = truncmod(op->a, op->b);
+ PrimExpr rmod = truncmod(op->a, op->b);
// b > 0 && rmod >= 0 -> rmod
// b > 0 && rmod < 0 -> rmod + b
// b < 0 && rmod < 0 -> rmod
}
}
- Expr VisitExpr_(const MaxNode* op) final {
+ PrimExpr VisitExpr_(const MaxNode* op) final {
using namespace arith;
- PVar<Expr> x, y;
+ PVar<PrimExpr> x, y;
PVar<Integer> c;
- auto e = GetRef<Expr>(op);
+ auto e = GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) &&
c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr VisitExpr_(const EQNode* op) final {
+ PrimExpr VisitExpr_(const EQNode* op) final {
using namespace arith;
- PVar<Expr> x, y;
- auto e = GetRef<Expr>(op);
+ PVar<PrimExpr> x, y;
+ auto e = GetRef<PrimExpr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return VisitExpr((truncmod(x, y) == 0).Eval());
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr VisitExpr_(const NENode* op) final {
+ PrimExpr VisitExpr_(const NENode* op) final {
using namespace arith;
- PVar<Expr> x, y;
- auto e = GetRef<Expr>(op);
+ PVar<PrimExpr> x, y;
+ auto e = GetRef<PrimExpr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return VisitExpr((truncmod(x, y) != 0).Eval());
}
}
private:
- Expr SwapBroadcastCast(const Expr& e) {
+ PrimExpr SwapBroadcastCast(const PrimExpr& e) {
// Try to change broadcast(cast(x)) to cast(broadcast(x))
// For some targets, LLVM will generate more efficient FMA
// instruction with the latter. For example, vmla vs. vmlal
};
if (should_swap()) {
- Expr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
+ PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
return CastNode::make(bcast->dtype, new_bcast);
}
}
return e;
}
- Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
+ PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c,
const AddNode* op) {
// emit fma instruction: a * b + c
- Expr lhs = SwapBroadcastCast(a);
- Expr rhs = SwapBroadcastCast(b);
+ PrimExpr lhs = SwapBroadcastCast(a);
+ PrimExpr rhs = SwapBroadcastCast(b);
if (fma_ != nullptr && op->dtype.is_float()) {
- Expr r = (*fma_)(CallNode::make(
+ PrimExpr r = (*fma_)(CallNode::make(
op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
if (r.defined()) return this->VisitExpr(r);
} else {
if (!lhs.same_as(a) || !rhs.same_as(b)) {
- Expr mul = this->VisitExpr(MulNode::make(lhs, rhs));
+ PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs));
return AddNode::make(mul, this->VisitExpr(c));
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
- Expr ApplyPattern(const std::string& name, const Expr& e) {
+ PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) {
for (size_t i = 0; i < patterns_.size(); ++i) {
std::string& p = patterns_[i];
size_t psize = p.length();
p.resize(psize);
// if pattern exists.
if (f != nullptr) {
- Expr r = (*f)(e);
+ PrimExpr r = (*f)(e);
CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
return this->VisitExpr(r);
}
}
}
- return Expr();
+ return PrimExpr();
}
// patterns
return stmt;
}
}
- Expr VisitExpr_(const LoadNode* op) final {
+ PrimExpr VisitExpr_(const LoadNode* op) final {
auto it = load_remap_.find(op->buffer_var.get());
if (it != load_remap_.end()) {
CHECK(is_zero(op->index));
const UIntImmNode *size_of_args = call->args[0].as<UIntImmNode>();
CHECK(size_of_args) << call->args[0]->GetTypeKey();
CHECK_EQ(size, size_of_args->value);
- Array<Expr> inits = combiner->identity_element;
- std::vector<Expr> values(size);
+ Array<PrimExpr> inits = combiner->identity_element;
+ std::vector<PrimExpr> values(size);
std::vector<DataType> types(size);
- Expr cond = call->args[size+1];
+ PrimExpr cond = call->args[size+1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1+idx];
if (!is_one(cond)) {
// the size of each index.
int reduce_extent, group_extent;
int threadx_extent = 1;
- Expr reduce_index = FlattenThread(vred, &reduce_extent);
- Expr group_index = FlattenThread(vpar, &group_extent);
+ PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
+ PrimExpr group_index = FlattenThread(vpar, &group_extent);
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
- Expr pred = const_true(types[i].lanes());
+ PrimExpr pred = const_true(types[i].lanes());
Var buffer_var = Downcast<Var>(call->args[2+size+i]);
stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
}
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
- Expr pred = const_true(types[idx].lanes());
+ PrimExpr pred = const_true(types[idx].lanes());
seq.emplace_back(StoreNode::make(
shared_bufs[idx], values[idx],
BufIndex(reduce_index, group_index, reduce_extent), pred));
reduce_index, group_index, reduce_extent, threadx_extent));
for (size_t idx = 0; idx < size; ++idx) {
CHECK(!load_remap_.count(buffers[idx]));
- Expr pred = const_true(types[idx].lanes());
+ PrimExpr pred = const_true(types[idx].lanes());
load_remap_[buffers[idx]] = LoadNode::make(
types[idx], shared_bufs[idx],
BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
alloc_remap_[buffers[idx]] = AllocateNode::make(
shared_bufs[idx], types[idx],
- {Expr(group_extent), Expr(reduce_extent)},
+ {PrimExpr(group_extent), PrimExpr(reduce_extent)},
pred, EvaluateNode::make(0));
}
return SeqStmt::Flatten(seq);
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
const std::vector<DataType>& types,
const Array<Var>& shared_bufs,
- Expr reduce_index,
- Expr group_index,
+ PrimExpr reduce_index,
+ PrimExpr group_index,
int reduce_extent,
int threadx_extent) {
// Get next power of two
std::vector<Stmt> seq;
size_t size = shared_bufs.size();
- Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
+ PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto freduce = [&](int offset) {
- Array<Expr> a, b;
+ Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
b.push_back(LoadNode::make(types[i], shared_bufs[i],
BufIndex(reduce_index + offset, group_index, reduce_extent),
const_true()));
a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true()));
}
- Array<Expr> ret = (*combiner)(a, b);
+ Array<PrimExpr> ret = (*combiner)(a, b);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true());
if (reduce_align > reduce_extent) {
// reduction with the boundary condition
reduce_align = reduce_align >> 1;
- Expr cond = reduce_index < (reduce_extent - reduce_align);
+ PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
while (reduce_align > threadx_extent ||
reduce_align > warp_size_) {
reduce_align = reduce_align >> 1;
- Expr cond = reduce_index < reduce_align;
+ PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
// in warp synchronization.
std::vector<Stmt> in_warp_seq;
- Expr in_warp_cond = reduce_index < (reduce_align >> 1);
+ PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
in_warp_seq.emplace_back(freduce(reduce_align));
}
// Flatten the thread index.
// Also return a warp number,
- Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
+ PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec,
int* out_total_extent) {
int& total_extent = *out_total_extent;
total_extent = 1;
return make_zero(DataType::Int(32));
}
- Expr ret;
+ PrimExpr ret;
for (const ThreadEntry& e : tvec) {
if (ret.defined()) {
ret = ret + e.iv->var * total_extent;
CallNode::Intrinsic));
}
// The local buffer index.
- static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
+ static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
if (!is_zero(group_index)) {
return ir::Simplify(group_index * reduce_extent + reduce_index);
} else {
std::vector<const AttrStmtNode*> thread_extents_;
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
- std::unordered_map<const VarNode *, Expr> load_remap_;
+ std::unordered_map<const VarNode *, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode *, Stmt> alloc_remap_;
};
namespace tvm {
namespace ir {
-inline Expr ConstInt32(size_t index) {
+inline PrimExpr ConstInt32(size_t index) {
CHECK_LE(index, std::numeric_limits<int>::max());
return make_const(DataType::Int(32), static_cast<int>(index));
}
-inline Expr StackAlloca(std::string type, size_t num) {
- Array<Expr> args = {StringImmNode::make(type), ConstInt32(num)};
+inline PrimExpr StackAlloca(std::string type, size_t num) {
+ Array<PrimExpr> args = {StringImmNode::make(type), ConstInt32(num)};
return CallNode::make(
DataType::Handle(),
intrinsic::tvm_stack_alloca,
}
}
}
- Expr total_bytes = make_const(op->extents[0].dtype(), nbytes);
+ PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
total_bytes = total_bytes * op->extents[i];
}
CallNode::Extern),
body);
- Expr free_op = CallNode::make(DataType::Int(32),
+ PrimExpr free_op = CallNode::make(DataType::Int(32),
"TVMBackendFreeWorkspace",
{cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_),
return StmtExprMutator::VisitStmt_(op);
}
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
return MakeCallPacked(op);
} else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
}
}
// call shape
- Expr MakeShape(const CallNode* op) {
+ PrimExpr MakeShape(const CallNode* op) {
size_t stack_begin = run_shape_stack_;
run_shape_stack_ += op->args.size();
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
// make array
- Expr MakeArray(const CallNode* op) {
+ PrimExpr MakeArray(const CallNode* op) {
size_t idx = run_array_stack_;
run_array_stack_ += 1;
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
prep_seq_.emplace_back(
TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
- Expr strides = op->args[2];
+ PrimExpr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(DataType::Handle());
}
make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
- Expr byte_offset = op->args[5];
+ PrimExpr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
}
// call packed.
- Expr MakeCallPacked(const CallNode* op) {
+ PrimExpr MakeCallPacked(const CallNode* op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
// Specially handle the buffer packed intrinsic
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
- Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
- Expr arg = op->args[i];
+ PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
+ PrimExpr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
run_shape_stack_ = restore_shape_stack;
run_array_stack_ = restore_array_stack;
run_arg_stack_ = arg_stack_begin;
- Array<Expr> packed_args = {
+ Array<PrimExpr> packed_args = {
op->args[0],
stack_value_,
stack_tcode_,
packed_args, CallNode::Intrinsic);
}
- Expr MakeCallTracePacked(const CallNode *op) {
+ PrimExpr MakeCallTracePacked(const CallNode *op) {
size_t restore_shape_stack = run_shape_stack_;
size_t restore_array_stack = run_array_stack_;
size_t arg_stack_begin = run_arg_stack_;
run_arg_stack_ += op->args.size();
size_t args_size = op->args.size();
CHECK_GT(args_size, 0);
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 1; i < op->args.size(); ++i) {
- Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
- Expr arg = op->args[i];
+ PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
+ PrimExpr arg = op->args[i];
DataType t = arg.dtype();
DataType api_type = APIType(t);
if (t != api_type) {
// Update the top of the stack, so we can use more than one
// packed function's arguments with the one stack.
run_arg_stack_ = arg_stack_begin + args_size - 1;
- Array<Expr> packed_args = {
+ Array<PrimExpr> packed_args = {
op->args[0],
stack_value_,
stack_tcode_,
}
private:
- bool IsArrayHandle(const Expr& arg) {
+ bool IsArrayHandle(const PrimExpr& arg) {
// specially set array handle.
if (const CallNode* buf = arg.as<CallNode>()) {
if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
// The prepration sequence to be emitted.
std::vector<Stmt> prep_seq_;
- Expr device_type_;
- Expr device_id_;
+ PrimExpr device_type_;
+ PrimExpr device_id_;
// Var handle for each stack.
Var stack_shape_;
Var stack_array_;
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
} else {
- Expr base;
+ PrimExpr base;
CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
<< "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store";
}
}
- void UpdatePattern(const Expr& index) {
- Array<Expr> m =
+ void UpdatePattern(const PrimExpr& index) {
+ Array<PrimExpr> m =
arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
int coeff = 0;
- Expr mcoeff = analyzer_->canonical_simplify(m[0]);
+ PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
<< "LowerWarpMemory failed due to store index=" << index
}
protected:
- Expr Mutate_(const VarNode* op) {
+ PrimExpr Mutate_(const VarNode* op) {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
Stmt VisitStmt_(const StoreNode* op) {
if (op->buffer_var.get() == buffer_) {
- Expr local_index, group;
+ PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate);
} else {
}
}
- Expr Mutate_(const LoadNode* op) {
+ PrimExpr Mutate_(const LoadNode* op) {
if (op->buffer_var.get() == buffer_) {
- Expr local_index, group;
+ PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
// invariance: local index must do not contain warp id
CHECK(!ExprUseVar(local_index, {warp_index_.get()}))
<< "LowerWarpMemory failed to rewrite load to shuffle for index "
<< op->index << " local_index=" << local_index;
- Expr load_value = LoadNode::make(
+ PrimExpr load_value = LoadNode::make(
op->dtype, op->buffer_var, local_index, op->predicate);
return CallNode::make(load_value.dtype(),
intrinsic::tvm_warp_shuffle,
// local index is the index in the local
// source index is the corresponding source index
// in this access pattern.
- std::pair<Expr, Expr> SplitIndexByGroup(const Expr& index) {
+ std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
if (index.dtype().lanes() != 1) {
- Expr base, local_index, group;
+ PrimExpr base, local_index, group;
CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base);
local_index =
RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
return std::make_pair(local_index, group);
}
- Expr m = make_const(index.dtype(), warp_coeff_);
+ PrimExpr m = make_const(index.dtype(), warp_coeff_);
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
- Expr x = analyzer_->canonical_simplify(indexmod(index, m));
- Expr z = analyzer_->canonical_simplify(indexdiv(index, m));
+ PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
+ PrimExpr z = analyzer_->canonical_simplify(indexdiv(index, m));
return std::make_pair(x, z);
} else {
- Expr x = analyzer_->canonical_simplify(indexmod(index, m));
- Expr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
+ PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
+ PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
y = y * m + x;
- Expr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
+ PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
m);
return std::make_pair(analyzer_->canonical_simplify(y),
analyzer_->canonical_simplify(z));
namespace tvm {
namespace ir {
-inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
+inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
}
// seq_init gives sequence of initialization
// seq_check gives sequence of later checks after init
std::vector<Stmt> seq_init, seq_check;
- std::unordered_map<const VarNode*, Expr> vmap;
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
ArgBinder binder(&vmap);
// ---------------------------
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
- Array<Expr> call_args{v_packed_args,
+ Array<PrimExpr> call_args{v_packed_args,
IntImmNode::make(DataType::Int(32), i),
IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
- Expr res = CallNode::make(
+ PrimExpr res = CallNode::make(
api_type, intrinsic::tvm_struct_get, call_args,
CallNode::PureIntrinsic);
// cast to the target version.
StringImmNode::make(name + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
- Expr node = StringImmNode::make("default");
+ PrimExpr node = StringImmNode::make("default");
CHECK(vmap.count(device_type.get()));
seq_check.push_back(AttrStmtNode::make(
node, attr::device_context_id, device_id, nop));
if (op->attr_key == attr::device_context_type) {
if (const VarNode* var = op->value.as<VarNode>()) {
var_ = var;
- Expr value = make_const(op->value.dtype(), device_type_);
+ PrimExpr value = make_const(op->value.dtype(), device_type_);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
return res;
}
- Expr VisitExpr_(const NENode* op) final {
+ PrimExpr VisitExpr_(const NENode* op) final {
// eager check NE for device check
- Expr res = StmtExprMutator::VisitExpr_(op);
+ PrimExpr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NENode>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->dtype, false);
return res;
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = vmap_.find(op);
if (it != vmap_.end()) return it->second;
return StmtExprMutator::VisitExpr_(op);
};
LoweredFunc
-RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
+RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
const StringImmNode* str = kv.first.as<StringImmNode>();
}
private:
- Stmt MakeEvaluate(Expr value) {
+ Stmt MakeEvaluate(PrimExpr value) {
if (HasSideEffect(value)) {
return EvaluateNode::make(value);
} else {
return EvaluateNode::make(0);
}
}
- Stmt MakeEvaluate(const Array<Expr>& values) {
+ Stmt MakeEvaluate(const Array<PrimExpr>& values) {
Stmt stmt;
- for (Expr e : values) {
+ for (PrimExpr e : values) {
if (HasSideEffect(e)) {
if (stmt.defined()) {
stmt = SeqStmt({stmt, EvaluateNode::make(e)});
// For now, rewrite unsafe select expression to if_then_else
// TODO(tqchen) pattern matching to support masked load
-class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
+class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
public:
// select itself is always considered safe if condition is safe
// Because we will issue guard to make sure it is.
const LoadNode* l = op->args[0].as<LoadNode>();
return this->VisitExpr(l->index);
} else if (op->is_pure()) {
- for (Expr e : op->args) {
+ for (PrimExpr e : op->args) {
if (VisitExpr(e)) return true;
}
return false;
return VisitExpr(op->base) && VisitExpr(op->stride);
}
bool VisitExpr_(const ShuffleNode* op) final {
- for (Expr e : op->vectors) {
+ for (PrimExpr e : op->vectors) {
if (VisitExpr(e)) return true;
}
return false;
class UnsafeSelectRewriter : public StmtExprMutator {
public:
- Expr VisitExpr_(const SelectNode* op) {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const SelectNode* op) {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<SelectNode>();
UnsafeExprDetector unsafe;
bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
class IRSideEffect : public ExprVisitor {
public:
- void VisitExpr(const Expr& e) final {
+ void VisitExpr(const PrimExpr& e) final {
if (has_side_effect_) return;
ExprVisitor::VisitExpr(e);
}
bool has_side_effect_{false};
};
-bool HasSideEffect(const Expr& e) {
+bool HasSideEffect(const PrimExpr& e) {
IRSideEffect v;
v(e);
return v.has_side_effect_;
class IRSubstitue : public StmtExprMutator {
public:
explicit IRSubstitue(
- const std::unordered_map<const VarNode*, Expr>& smap)
+ const std::unordered_map<const VarNode*, PrimExpr>& smap)
: smap_(smap) {
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = smap_.find(op);
if (it != smap_.end()) {
return it->second;
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
private:
- const std::unordered_map<const VarNode*, Expr>& smap_;
+ const std::unordered_map<const VarNode*, PrimExpr>& smap_;
};
Stmt Substitute(Stmt stmt,
- const std::unordered_map<const VarNode*, Expr>& value_map) {
+ const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
if (value_map.size() == 0) return stmt;
return IRSubstitue(value_map)(std::move(stmt));
}
-Expr Substitute(Expr expr,
- const std::unordered_map<const VarNode*, Expr>& value_map) {
+PrimExpr Substitute(PrimExpr expr,
+ const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
if (value_map.size() == 0) return expr;
return IRSubstitue(value_map)(std::move(expr));
}
-Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
- std::unordered_map<const VarNode*, Expr> vmap;
+Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map) {
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
return Substitute(stmt, vmap);
}
-Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
- std::unordered_map<const VarNode*, Expr> vmap;
+PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map) {
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
class VarTouchVisitor : public ExprVisitor {
public:
- void VisitExpr(const Expr& e) final {
+ void VisitExpr(const PrimExpr& e) final {
if (use_var_) return;
ExprVisitor::VisitExpr(e);
}
const std::unordered_set<const VarNode*>& vset_;
};
-bool ExprUseVar(const Expr& e, const Var& v) {
+bool ExprUseVar(const PrimExpr& e, const Var& v) {
ExprUseVarVisitor visitor(v.get());
visitor(e);
return visitor.use_var_;
}
-bool ExprUseVar(const Expr& e,
+bool ExprUseVar(const PrimExpr& e,
const std::unordered_set<const VarNode*>& vset) {
ExprUseVSetVisitor visitor(vset);
visitor(e);
thread_extent_.push_back(op->value);
}
- Expr value = op->value;
+ PrimExpr value = op->value;
if (visit_thread_extent_) {
value = this->VisitExpr(value);
}
!HasSideEffect(op->value)) {
return body;
} else {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return GetRef<Stmt>(op);
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const LetNode* op) final {
+ PrimExpr VisitExpr_(const LetNode* op) final {
this->HandleDef(op->var.get());
- Expr body = this->VisitExpr(op->body);
+ PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
}
}
- Expr VisitExpr_(const VarNode* op) final {
- this->HandleUse(GetRef<Expr>(op));
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ this->HandleUse(GetRef<PrimExpr>(op));
return StmtExprMutator::VisitExpr_(op);
}
- Expr VisitExpr_(const LoadNode* op) final {
+ PrimExpr VisitExpr_(const LoadNode* op) final {
this->HandleUse(op->buffer_var);
return StmtExprMutator::VisitExpr_(op);
}
def_count_[v] = 1;
}
- void HandleUse(const Expr& v) {
+ void HandleUse(const PrimExpr& v) {
CHECK(v.as<VarNode>());
Var var = Downcast<Var>(v);
auto it = use_count_.find(var.get());
bool visit_thread_extent_{true};
Array<Var> undefined_;
Array<IterVar> thread_axis_;
- Array<Expr> thread_extent_;
+ Array<PrimExpr> thread_extent_;
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;
};
}
}
LoweredFunc f_device(n);
- Array<Expr> call_args;
+ Array<PrimExpr> call_args;
call_args.push_back(StringImmNode::make(f_device->name));
for (Var arg : n->args) {
call_args.push_back(arg);
}
- for (Expr ext : m.thread_extent_) {
+ for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
device_funcs_.emplace_back(f_device);
std::string name_;
// the device functions
std::vector<LoweredFunc> device_funcs_;
- std::unordered_map<const VarNode*, Expr> handle_data_type_;
+ std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
public:
bool is_ssa{true};
- void VisitExpr(const Expr& n) final {
+ void VisitExpr(const PrimExpr& n) final {
if (!is_ssa) return;
StmtExprVisitor::VisitExpr(n);
}
class IRConvertSSA final : public StmtExprMutator {
public:
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
if (scope_.count(op)) {
return scope_[op].back();
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
- Expr VisitExpr_(const LetNode* op) final {
- const VarExpr& v = op->var;
+ PrimExpr VisitExpr_(const LetNode* op) final {
+ const Var& v = op->var;
if (defined_.count(v.get())) {
- Expr value = this->VisitExpr(op->value);
- VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+ PrimExpr value = this->VisitExpr(op->value);
+ Var new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
- Expr body = this->VisitExpr(op->body);
+ PrimExpr body = this->VisitExpr(op->body);
scope_[v.get()].pop_back();
return LetNode::make(new_var, value, body);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
- Expr VisitExpr_(const LoadNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
if (scope_.count(op->buffer_var.get())) {
return LoadNode::make(
}
}
Stmt VisitStmt_(const LetStmtNode* op) final {
- const VarExpr& v = op->var;
+ const Var& v = op->var;
if (defined_.count(v.get())) {
- Expr value = this->VisitExpr(op->value);
- VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+ PrimExpr value = this->VisitExpr(op->value);
+ Var new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back();
}
}
Stmt VisitStmt_(const ForNode* op) final {
- const VarExpr& v = op->loop_var;
+ const Var& v = op->loop_var;
if (defined_.count(v.get())) {
- VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+ Var new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
}
}
Stmt VisitStmt_(const AllocateNode* op) final {
- const VarExpr& v = op->buffer_var;
+ const Var& v = op->buffer_var;
if (defined_.count(v.get())) {
- VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+ Var new_var = VarNode::make(v.dtype(), v->name_hint);
scope_[v.get()].push_back(new_var);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
}
private:
- std::unordered_map<const VarNode*, std::vector<VarExpr> > scope_;
+ std::unordered_map<const VarNode*, std::vector<Var> > scope_;
std::unordered_set<const VarNode*> defined_;
};
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
- Expr offset = op->args[2];
- Expr extent = op->args[3];
+ PrimExpr offset = op->args[2];
+ PrimExpr extent = op->args[3];
const IntImmNode* flag = op->args[4].as<IntImmNode>();
StorageScope scope = GetScope(buffer);
// The buffer scope.
AccessEntry e;
e.threads = env_threads();
e.dtype = dtype;
- e.buffer = Downcast<VarExpr>(op->args[1]);
+ e.buffer = Downcast<Var>(op->args[1]);
e.touched = arith::IntSet::range(
Range::make_by_min_extent(offset, extent));
e.scope = scope;
}
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op);
} else {
private:
// tvm_access_ptr
- Expr MakeAccessPtr(const CallNode* op) {
+ PrimExpr MakeAccessPtr(const CallNode* op) {
// Specially handle the buffer packed intrinsic
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
Var buffer_var = Downcast<Var>(op->args[1]);
- Expr offset = op->args[2];
+ PrimExpr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
return MakeTaggedAccessPtr(
return AddressOffset(buffer_var, dtype, offset);
}
- Expr MakeTaggedAccessPtr(DataType ptr_type,
+ PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
Var buffer_var,
DataType dtype,
- Expr offset,
+ PrimExpr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
CHECK(info->head_address.defined())
if (it != var_remap_.end() &&
!it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<VarNode>());
- VarExpr buf_var = Downcast<VarExpr>(it->second);
+ Var buf_var = Downcast<Var>(it->second);
return StoreNode::make(buf_var, op->value, op->index, op->predicate);
} else {
return stmt;
// create a buffer entry
BufferEntry e;
e.bounds = op->bounds;
- Array<Expr> shape;
+ Array<PrimExpr> shape;
for (auto r : e.bounds) {
shape.push_back(r->extent);
}
<< "Allocation exceed bound of memory tag " << skey.to_string();
}
}
- Array<Expr> strides;
+ Array<PrimExpr> strides;
if (dim_align_.count(key) != 0 && shape.size() != 0) {
- std::vector<Expr> rstrides;
+ std::vector<PrimExpr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key];
int first_dim = 0;
- Expr stride = make_const(shape[first_dim].dtype(), 1);
+ PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) {
- Expr factor = make_const(stride.dtype(), avec[dim].align_factor);
- Expr offset = make_const(stride.dtype(), avec[dim].align_offset);
+ PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+ PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
stride = ir::Simplify(stride);
}
rstrides.push_back(stride);
stride = stride * shape[dim];
}
- strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+ strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}
e.buffer = BufferNode::make(
Var(key.GetName(), DataType::Handle()),
- op->dtype, shape, strides, Expr(),
+ op->dtype, shape, strides, PrimExpr(),
key.GetName(), skey.to_string(),
align, 0, kDefault);
}
}
- Expr VisitExpr_(const LoadNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
auto it = var_remap_.find(op->buffer_var.get());
if (it != var_remap_.end() &&
!it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<VarNode>());
- VarExpr buf_var = Downcast<VarExpr>(it->second);
+ Var buf_var = Downcast<Var>(it->second);
return LoadNode::make(op->dtype, buf_var, op->index, op->predicate);
} else {
return expr;
}
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = var_remap_.find(op);
if (it != var_remap_.end()) {
return it->second;
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
- Expr VisitExpr_(const CallNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const CallNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
if (op != nullptr && op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index};
block_size *= shape;
starts--;
}
- Expr stride(elem_cnt / block_size);
+ PrimExpr stride(elem_cnt / block_size);
- Array<Expr> args;
- std::vector<VarExpr> vars;
+ Array<PrimExpr> args;
+ std::vector<Var> vars;
for (int i = op->bounds.size() - 1; i > starts; --i) {
args.push_back(op->bounds[i]->min);
}
auto &func_name = op->func->func_name();
- vars.push_back(VarExpr(
+ vars.push_back(Var(
"prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
args.push_back(op->bounds[starts]->min + stride * vars.back());
for (int i = starts - 1; i >= 0; --i) {
- vars.push_back(VarExpr(
+ vars.push_back(Var(
"prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
args.push_back(vars.back() + op->bounds[i]->min);
}
stmt = ForNode::make(
vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
- Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
- Expr address = CallNode::make(
+ PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
+ PrimExpr address = CallNode::make(
DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
- Expr prefetch = CallNode::make(
+ PrimExpr prefetch = CallNode::make(
op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
stmt = EvaluateNode::make(prefetch);
- Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
+ PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
}
const BufferEntry& be = buf_map_.at(key);
CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
- Array<Expr> begins, extents;
+ Array<PrimExpr> begins, extents;
if (be.bounds.size() != 0) {
CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
// Whether we are out of allocation bounds and buffer get released.
bool released{false};
// relative index
- inline Array<Expr> RelIndex(Array<Expr> args) const {
+ inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
if (bounds.size() != 0) {
- Array<Expr> index;
+ Array<PrimExpr> index;
CHECK_EQ(bounds.size(), args.size());
for (size_t i = 0; i < bounds.size(); ++i) {
index.push_back(args[i] - bounds[i]->min);
}
};
- bool ShapeIsValid(const Array<Expr> &shape) {
+ bool ShapeIsValid(const Array<PrimExpr> &shape) {
// Zero-dimensional tensor does not need boundary check.
if (!shape.size())
return false;
return true;
}
- Expr MakeBound(const DataType &type, const Array<Expr> &shape) {
+ PrimExpr MakeBound(const DataType &type, const Array<PrimExpr> &shape) {
// We have already checked the shape size to be greater then 0.
- Expr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
+ PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
for (size_t i = 1; i < shape.size(); ++i) {
bound = MulNode::make(
bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
// The buffer assignment map
// Variable remap
- std::unordered_map<const VarNode*, Expr> var_remap_;
+ std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// Buffer map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
// Dimension alignment
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// Collects shapes.
- std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
+ std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
// bounds populator. We really need the analyzer from it.
// However
IRVisitorWithAnalyzer* bounded_analyzer_;
if (!result_) return;
StmtExprVisitor::VisitStmt(n);
}
- void VisitExpr(const Expr& n) final {
+ void VisitExpr(const PrimExpr& n) final {
if (!result_) return;
StmtExprVisitor::VisitExpr(n);
}
RemapIndex(op->value.dtype(), op->index, it->second),
op->predicate);
}
- Expr VisitExpr_(const LoadNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return expr;
RemapIndex(op->dtype, op->index, it->second),
op->predicate);
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
if (it->second->bits_offset != 0) {
}
return it->second->alloc_var;
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
return StmtExprMutator::VisitExpr_(op);
}
const StorageEntry* se = it->second;
- Expr offset = this->VisitExpr(op->args[2]);
- Expr extent = this->VisitExpr(op->args[3]);
+ PrimExpr offset = this->VisitExpr(op->args[2]);
+ PrimExpr extent = this->VisitExpr(op->args[3]);
uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(se->bits_offset % elem_bits, 0U);
if (se->bits_offset != 0) {
// The replacement allocation, if any.
Stmt new_alloc;
// The var expr of new allocation.
- VarExpr alloc_var;
+ Var alloc_var;
// The allocation element type.
DataType elem_type;
// This is non-zero if this allocate is folded into another one
return MergeNest(nest, body);
}
// Remap the index
- Expr RemapIndex(DataType dtype, Expr index, StorageEntry* e) {
+ PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
if (e->bits_offset == 0) return index;
uint64_t elem_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(e->bits_offset % elem_bits, 0U);
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
- Expr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
+ PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
make_const(DataType::Int(32), 1));
e->new_alloc = AllocateNode::make(
e->alloc_var, alloc_type, {sz},
}
} else {
// Build a merged allocation
- Expr combo_size;
+ PrimExpr combo_size;
for (const AllocateNode* op : e->allocs) {
- Expr sz = arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
+ PrimExpr sz = arith::ComputeReduce<MulNode>(
+ op->extents, make_const(DataType::Int(32), 1));
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
}
}
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
- Expr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
+ PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
(total_bits + type_bits - 1) / type_bits);
e->new_alloc = AllocateNode::make(
e->alloc_var, e->elem_type, {alloc_size}, const_true(),
// if all its access is the same vector type.
class VectorAllocRewriter : public StmtExprMutator {
public:
- Expr VisitExpr_(const LoadNode* op) final {
+ PrimExpr VisitExpr_(const LoadNode* op) final {
UpdateTypeMap(op->buffer_var.get(), op->dtype);
return StmtExprMutator::VisitExpr_(op);
}
UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
return StmtExprMutator::VisitStmt_(op);
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
tvec[0].lanes() % op->dtype.lanes() == 0 &&
tvec[0].lanes() != op->dtype.lanes()) {
int factor = tvec[0].lanes() / op->dtype.lanes();
- Array<Expr> extents = op->extents;
+ Array<PrimExpr> extents = op->extents;
arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
if (me->base % factor == 0 && me->coeff % factor == 0) {
extents.Set(extents.size() - 1,
if (arg.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
if (tvec.size() == 1) {
- Expr dtype = make_const(tvec[0], 0);
+ PrimExpr dtype = make_const(tvec[0], 0);
n->handle_data_type.Set(arg, dtype);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
- Expr dtype = make_const(tvec[0].with_lanes(1), 0);
+ PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
n->handle_data_type.Set(arg, dtype);
}
}
return StmtExprMutator::VisitStmt(stmt);
}
}
- Expr VisitExpr_(const LoadNode* op) final {
+ PrimExpr VisitExpr_(const LoadNode* op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer_var].read_count;
// first thread scope.
if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
- num_blocks_ = Expr();
- is_lead_ = Expr();
+ num_blocks_ = PrimExpr();
+ is_lead_ = PrimExpr();
}
return ret;
} else if (op->attr_key == attr::storage_scope) {
}
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
const VarNode* buffer_var = op->args[1].as<VarNode>();
// private functions.
Stmt InitGlobalBarrier(const AttrStmtNode* op) {
CHECK(op != nullptr);
- Array<Expr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
+ Array<PrimExpr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
Stmt prep = EvaluateNode::make(
CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
Stmt body = op->body;
num_blocks_ = (num_blocks_.defined() ?
attr->value * num_blocks_ : attr->value);
} else if (s.rank == 1) {
- Expr cond = iv->var == make_zero(iv->var.dtype());
+ PrimExpr cond = iv->var == make_zero(iv->var.dtype());
is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
}
}
// The storage scope of each buffer
std::unordered_map<const VarNode*, StorageScope> storage_scope_;
// The read write statistics of storage
- std::unordered_map<VarExpr, Entry, ObjectHash, ObjectEqual> rw_stats_;
+ std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> rw_stats_;
// The statistics for global barrier
bool in_thread_env_{false};
// memorized results
std::vector<const AttrStmtNode*> thread_extents_;
size_t num_work_dim_{0};
- Expr num_blocks_;
- Expr is_lead_;
+ PrimExpr num_blocks_;
+ PrimExpr is_lead_;
};
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
}
}
-Expr unpack_type_cast(const Expr &input, const DataType &target_type) {
+PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) {
auto cast = input.as<CastNode>();
if (cast == nullptr) {
return input;
} else if (cast->dtype == target_type) {
return cast->value;
}
- return Expr();
+ return PrimExpr();
}
// MMAMatcher matches C = Cast(A)*Cast(B)+C,
buf_name_.insert(std::make_pair(load_a, buffer_a.name));
buf_name_.insert(std::make_pair(load_b, buffer_b.name));
mma_sync_.insert(std::make_pair(op,
- Array<Expr>{load_a_expr, load_b_expr, add->a}));
+ Array<PrimExpr>{load_a_expr, load_b_expr, add->a}));
return true;
}
std::unordered_map<TensorKey, BufferInfo> buf_map_;
std::unordered_map<const Object*, std::string> storage_scope_;
- std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
+ std::unordered_map<const ProvideNode*, Array<PrimExpr>> mma_sync_;
std::unordered_map<const Object*, std::string> buf_name_;
std::unordered_set<std::string> frag_reg_;
bool matched_{false};
if (comm_add == nullptr || op->combiner->result.size() > 1) {
return;
}
- for (Expr source : op->source) {
+ for (PrimExpr source : op->source) {
auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<MulNode>();
auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<MulNode>();
if (mul_0 == nullptr && mul_1 == nullptr) {
friend class ScheduleAnalyser;
private:
- std::unordered_map<std::string, Array<Expr>> args_;
+ std::unordered_map<std::string, Array<PrimExpr>> args_;
bool tensorcore_candidate_{false};
};
reduce_axis_var = reduce_axis[0]->var.as<VarNode>();
BodyVisitor body_visitor;
- for (Expr expr : compute->body) {
+ for (PrimExpr expr : compute->body) {
body_visitor(expr);
}
if (!body_visitor.tensorcore_candidate_) {
if (it0->second == "matrix_a" && it1->second == "matrix_b") {
return true;
} else if (it0->second == "matrix_b" && it1->second == "matrix_a") {
- mma_sync.second = Array<Expr>{operands[1], operands[0], operands[2]};
+ mma_sync.second = Array<PrimExpr>{operands[1], operands[0], operands[2]};
} else {
return false;
}
private:
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
- std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
+ std::unordered_map<const ProvideNode*, Array<PrimExpr>> mma_sync_;
std::unordered_map<const Object*, std::string> buf_name_;
};
}
}
- Array<Expr> strides;
+ Array<PrimExpr> strides;
if (bi.strides.size() > 0) {
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
- Expr stride = IntImmNode::make(DataType::Int(32), 1);
+ PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = MulNode::make(stride, bi.shape[j]);
}
strides_.insert(std::make_pair(key.GetName(), strides));
if (frag_reg_.count(bi.name)) {
- Expr dst = CallNode::make(bi.dtype,
+ PrimExpr dst = CallNode::make(bi.dtype,
bi.name,
op->args,
CallNode::Halide,
const CallNode* value = op->value.as<CallNode>();
if (value != nullptr && frag_reg_.count(value->name)) {
- Expr dst = CallNode::make(bi.dtype,
+ PrimExpr dst = CallNode::make(bi.dtype,
bi.name,
op->args,
CallNode::Halide,
}
}
- Array<Expr> strides;
+ Array<PrimExpr> strides;
if (bi.strides.size() > 0) {
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
- Expr stride = IntImmNode::make(DataType::Int(32), 1);
+ PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = MulNode::make(stride, bi.shape[j]);
}
BufferInfo bi;
bi.bounds = op->bounds;
- Array<Expr> shape;
+ Array<PrimExpr> shape;
for (auto r : bi.bounds) {
shape.push_back(r->extent);
}
- Array<Expr> strides;
+ Array<PrimExpr> strides;
if (dim_align_.count(key) != 0 && shape.size() != 0) {
- std::vector<Expr> rstrides;
+ std::vector<PrimExpr> rstrides;
const std::vector<DimAlignInfo>& avec = dim_align_[key];
int first_dim = 0;
- Expr stride = make_const(shape[first_dim].dtype(), 1);
+ PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
for (size_t i = shape.size(); i != 0; --i) {
size_t dim = i - 1;
if (dim < avec.size() && avec[dim].align_factor != 0) {
- Expr factor = make_const(stride.dtype(), avec[dim].align_factor);
- Expr offset = make_const(stride.dtype(), avec[dim].align_offset);
+ PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+ PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
stride = stride + \
indexmod(factor + offset - indexmod(stride, factor), factor);
stride = ir::Simplify(stride);
rstrides.push_back(stride);
stride = stride * shape[dim];
}
- strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+ strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}
bi.name = key.GetName();
struct BufferInfo {
std::string name;
DataType dtype;
- Array<Expr> strides;
- Array<Expr> shape;
+ Array<PrimExpr> strides;
+ Array<PrimExpr> shape;
Region bounds;
bool external{false};
bool released{false};
- inline Array<Expr> RelIndex(Array<Expr> args) const {
+ inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
if (bounds.size() != 0) {
- Array<Expr> index;
+ Array<PrimExpr> index;
CHECK_EQ(bounds.size(), args.size());
for (size_t i = 0; i < bounds.size(); ++i) {
index.push_back(args[i] - bounds[i]->min);
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
std::unordered_set<std::string> frag_reg_;
- std::unordered_map<std::string, Array<Expr>> strides_;
- std::unordered_map<const ProvideNode*, Expr> frag_load_;
- std::unordered_map<const ProvideNode*, Expr> frag_store_;
+ std::unordered_map<std::string, Array<PrimExpr>> strides_;
+ std::unordered_map<const ProvideNode*, PrimExpr> frag_load_;
+ std::unordered_map<const ProvideNode*, PrimExpr> frag_store_;
std::unordered_map<std::string, int> thread_extent_;
IndexVisitor index_visitor;
Tile warp_tile_;
// ThreadIdxMutator does the thread index unification inside a warp
class ThreadIdxMutator : public StmtExprMutator {
public:
- explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
+ explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {}
- Expr VisitExpr_(const VarNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<VarNode>();
if (op != nullptr) {
if (op->name_hint == "threadIdx.x") {
- Expr zero = IntImmNode::make(DataType::Int(32), 0);
+ PrimExpr zero = IntImmNode::make(DataType::Int(32), 0);
return zero;
}
if (op->name_hint == "threadIdx.y") {
- Expr div = DivNode::make(expr, warp_y_);
- Expr mul = MulNode::make(div, warp_y_);
+ PrimExpr div = DivNode::make(expr, warp_y_);
+ PrimExpr mul = MulNode::make(div, warp_y_);
return mul;
}
}
}
private:
- Expr warp_y_;
+ PrimExpr warp_y_;
};
// TensorCoreIRMutator mutates the AST for TensorCore CodeGen
auto it = mma_sync_.find(op);
if (it != mma_sync_.end()) {
const auto &operands = it->second;
- Expr a = operands[0];
+ PrimExpr a = operands[0];
auto ca = a.as<CallNode>();
- Expr b = operands[1];
+ PrimExpr b = operands[1];
auto cb = b.as<CallNode>();
- Expr c = operands[2];
+ PrimExpr c = operands[2];
auto cc = c.as<CallNode>();
ObjectPtr<BufferNode> buffer_node_a = make_object<BufferNode>();
auto it2 = frag_load_.find(op);
if (it2 != frag_load_.end()) {
- Expr dst = it2->second;
+ PrimExpr dst = it2->second;
if (op->value.as<FloatImmNode>() != nullptr ||
op->value.as<IntImmNode>() != nullptr) {
auto call = dst.as<CallNode>();
<< "Cannot find stride for " << value->name;
auto strides = it->second;
CHECK_GE(strides.size(), 2);
- Expr stride = strides[strides.size()-2];
+ PrimExpr stride = strides[strides.size()-2];
// thread index unification inside a warp
- Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+ PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
- Expr mutated_value = thread_idx_mutator(op->value);
- Expr src = CallNode::make(value->dtype,
+ PrimExpr mutated_value = thread_idx_mutator(op->value);
+ PrimExpr src = CallNode::make(value->dtype,
"&",
{mutated_value},
CallNode::Extern);
auto call = dst.as<CallNode>();
- Expr matrix_major;
+ PrimExpr matrix_major;
auto iter2 = matrix_major_.find(simplify_name(call->name));
CHECK(iter2 != matrix_major_.end())
<< "Can not determine matrix major for " << call->name;
<< "Cannot find stride for " << key.GetName();
auto strides = it->second;
CHECK_GE(strides.size(), 2);
- Expr stride = strides[strides.size()-2];
+ PrimExpr stride = strides[strides.size()-2];
- Expr dst = it3->second;
+ PrimExpr dst = it3->second;
// thread index unification inside a warp
- Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+ PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
dst = CallNode::make(DataType::Handle(),
int ori_extent_value = ori_extent->value;
scaled_extent_value = ori_extent_value / scale_factor;
}
- Expr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
+ PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type,
op->device_api, op->body);
}
}
private:
- Array<Expr> get_tile_size_(const std::string &name) {
+ Array<PrimExpr> get_tile_size_(const std::string &name) {
auto it = matrix_abc_.find(name);
auto it2 = matrix_major_.find(name);
CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
<< "Cannot find matrix info for " << name;
- Expr size0 = make_const(DataType::Int(32), 16);
- Expr size1 = make_const(DataType::Int(32), 16);
+ PrimExpr size0 = make_const(DataType::Int(32), 16);
+ PrimExpr size1 = make_const(DataType::Int(32), 16);
if (it->second == "matrix_a" && it2->second == "col_major") {
size0 = make_const(DataType::Int(32), warp_tile_.k);
size1 = make_const(DataType::Int(32), warp_tile_.m);
size0 = make_const(DataType::Int(32), warp_tile_.n);
size1 = make_const(DataType::Int(32), warp_tile_.m);
}
- Array<Expr> tile_size = {size0, size1};
+ Array<PrimExpr> tile_size = {size0, size1};
return tile_size;
}
DataType datatype) {
auto it = bounds_.find(key);
CHECK(it != bounds_.end());
- Array<Expr> min_bound;
+ Array<PrimExpr> min_bound;
for (auto i : it->second) {
min_bound.push_back(i->min);
}
CHECK_GE(it->second.size(), 2);
- Array<Expr> shape;
+ Array<PrimExpr> shape;
for (size_t i = 0; i < it->second.size() - 2; ++i) {
shape.push_back(it->second[i]->extent);
}
shape.push_back(tile_size[0]);
shape.push_back(tile_size[1]);
- Array<Expr> strides;
+ Array<PrimExpr> strides;
for (size_t i = 1; i < shape.size(); ++i) {
- Expr stride = IntImmNode::make(DataType::Int(32), 1);
+ PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
for (size_t j = shape.size() - 1; j >= i; --j) {
stride = MulNode::make(stride, shape[j]);
}
}
strides.push_back(make_const(DataType::Int(32), 1));
- Expr elem_offset = IntImmNode::make(DataType::Int(32), 0);
+ PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0);
CHECK_EQ(call->args.size(), min_bound.size());
for (size_t i = 0; i < min_bound.size(); i++) {
elem_offset = AddNode::make(
tensor_node->dtype = datatype;
Tensor tensor(tensor_node);
- Array<Expr> args;
+ Array<PrimExpr> args;
for (size_t i = 0; i < call->args.size(); ++i) {
args.push_back(call->args[i]);
args.push_back(shape[i]);
std::unordered_map<std::string, std::string> matrix_abc_;
std::unordered_map<std::string, std::string> matrix_major_;
- std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
- std::unordered_map<std::string, Array<Expr>> strides_;
+ std::unordered_map<const ProvideNode*, Array<PrimExpr>> mma_sync_;
+ std::unordered_map<std::string, Array<PrimExpr>> strides_;
std::unordered_set<std::string> frag_reg_;
std::unordered_map<const VarNode*, unsigned> loop_scaling_;
- std::unordered_map<const ProvideNode*, Expr> frag_load_;
- std::unordered_map<const ProvideNode*, Expr> frag_store_;
+ std::unordered_map<const ProvideNode*, PrimExpr> frag_load_;
+ std::unordered_map<const ProvideNode*, PrimExpr> frag_store_;
std::unordered_map<TensorKey, Region> bounds_;
Tile warp_tile_;
int warp_threads_y_{-1};
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
if (value == 0) return EvaluateNode::make(0);
Stmt body = op->body;
- Map<Var, Expr> vmap;
+ Map<Var, PrimExpr> vmap;
Array<Stmt> unrolled;
for (int i = 0; i < value; ++i) {
vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
// returns the extent of the loop if it's a constant integer, otherwise return -1
int GetExtent(const ForNode* op) {
// constant folding.
- Expr extent = ir::Simplify(op->extent);
+ PrimExpr extent = ir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
const UIntImmNode *v2 = extent.as<UIntImmNode>();
int value = -1;
namespace tvm {
namespace ir {
-inline Expr BroadcastTo(Expr e, int lanes) {
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
if (e.dtype().lanes() == lanes) return e;
if (const BroadcastNode* op = e.as<BroadcastNode>()) {
if (lanes % op->lanes == 0) {
VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
// Load
- Expr VisitExpr_(const LoadNode* op) final {
- Expr expr = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<LoadNode>();
if (op->buffer_var.get() == buf_) {
return LoadNode::make(op->dtype, op->buffer_var,
}
}
- Expr VisitExpr_(const AddNode* op) final {
+ PrimExpr VisitExpr_(const AddNode* op) final {
return AddSubVec(op);
}
- Expr VisitExpr_(const SubNode* op) final {
+ PrimExpr VisitExpr_(const SubNode* op) final {
return AddSubVec(op);
}
- Expr VisitExpr_(const MulNode* op) final {
- Expr a = this->VisitExpr(op->a);
- Expr b = this->VisitExpr(op->b);
+ PrimExpr VisitExpr_(const MulNode* op) final {
+ PrimExpr a = this->VisitExpr(op->a);
+ PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
}
return BinaryVec(op);
}
- Expr VisitExpr_(const DivNode* op) final {
+ PrimExpr VisitExpr_(const DivNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const ModNode* op) final {
+ PrimExpr VisitExpr_(const ModNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const FloorDivNode* op) final {
+ PrimExpr VisitExpr_(const FloorDivNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const FloorModNode* op) final {
+ PrimExpr VisitExpr_(const FloorModNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const MinNode* op) final {
+ PrimExpr VisitExpr_(const MinNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const MaxNode* op) final {
+ PrimExpr VisitExpr_(const MaxNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const EQNode* op) final {
+ PrimExpr VisitExpr_(const EQNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const NENode* op) final {
+ PrimExpr VisitExpr_(const NENode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const LTNode* op) final {
+ PrimExpr VisitExpr_(const LTNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const LENode* op) final {
+ PrimExpr VisitExpr_(const LENode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const GTNode* op) final {
+ PrimExpr VisitExpr_(const GTNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const GENode* op) final {
+ PrimExpr VisitExpr_(const GENode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const AndNode* op) final {
+ PrimExpr VisitExpr_(const AndNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const OrNode* op) final {
+ PrimExpr VisitExpr_(const OrNode* op) final {
return BinaryVec(op);
}
- Expr VisitExpr_(const RampNode* op) final {
- Expr base = this->VisitExpr(op->base);
- Expr stride = this->VisitExpr(op->stride);
+ PrimExpr VisitExpr_(const RampNode* op) final {
+ PrimExpr base = this->VisitExpr(op->base);
+ PrimExpr stride = this->VisitExpr(op->stride);
if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
const RampNode* base_ramp = base.as<RampNode>();
if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
base = BroadcastTo(base, lanes);
stride = BroadcastTo(stride, lanes);
- Array<Expr> elems;
+ Array<PrimExpr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(
RampNode::make(ShuffleNode::make_extract_element(base, i),
}
return ShuffleNode::make_concat(elems);
}
- Expr VisitExpr_(const SelectNode *op) final {
- Expr cond = this->VisitExpr(op->condition);
- Expr t = this->VisitExpr(op->true_value);
- Expr f = this->VisitExpr(op->false_value);
+ PrimExpr VisitExpr_(const SelectNode *op) final {
+ PrimExpr cond = this->VisitExpr(op->condition);
+ PrimExpr t = this->VisitExpr(op->true_value);
+ PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(std::max(
cond.dtype().lanes(),
return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
}
- Expr VisitExpr_(const CastNode *op) final {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr VisitExpr_(const CastNode *op) final {
+ PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
// Variable
- Expr VisitExpr_(const VarNode* v) final {
+ PrimExpr VisitExpr_(const VarNode* v) final {
if (v == var_.get()) {
return ramp_;
} else if (lets_.count(v)) {
return lets_[v];
} else {
- return GetRef<Expr>(v);
+ return GetRef<PrimExpr>(v);
}
}
// IfThenElse expr
- Expr MutateIfThenElseExpr_(const CallNode *op) {
- Expr cond = this->VisitExpr(op->args[0]);
+ PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
+ PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_vector()) {
need_scalarize_ = true;
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
- Expr t = this->VisitExpr(op->args[1]);
- Expr f = this->VisitExpr(op->args[2]);
+ PrimExpr t = this->VisitExpr(op->args[1]);
+ PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) &&
t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
t = BroadcastTo(t, lanes);
}
}
// Call
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->name == intrinsic::tvm_if_then_else) {
return MutateIfThenElseExpr_(op);
}
if (!op->is_vectorizable()) {
// Cannot vectorize this op
- Array<Expr> new_args;
+ Array<PrimExpr> new_args;
for (auto arg : op->args) {
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_vector()) {
need_scalarize_ = true;
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return CallNode::make(
op->dtype, op->name, new_args, op->call_type, op->func, op->value_index);
}
} else {
int lane = 0;
- Array<Expr> new_args = MutateArray(op->args, &lane);
+ Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return CallNode::make(
op->dtype.with_lanes(lane), op->name, new_args,
}
}
// Load
- Expr VisitExpr_(const LoadNode* op) final {
- Expr index = this->VisitExpr(op->index);
- Expr pred = this->VisitExpr(op->predicate);
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ PrimExpr index = this->VisitExpr(op->index);
+ PrimExpr pred = this->VisitExpr(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
return LoadNode::make(
}
}
// Let
- Expr VisitExpr_(const LetNode* op) final {
- Expr value = this->VisitExpr(op->value);
+ PrimExpr VisitExpr_(const LetNode* op) final {
+ PrimExpr value = this->VisitExpr(op->value);
CHECK(!lets_.count(op->var.get())) << "not SSA";
if (value.dtype().lanes() != op->value.dtype().lanes()) {
Var v(op->var->name_hint, value.dtype());
lets_[op->var.get()] = v;
return LetNode::make(v, value, this->VisitExpr(op->body));
} else {
- Expr body = this->VisitExpr(op->body);
+ PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
return LetNode::make(op->var, value, body);
}
}
// Provide
Stmt VisitStmt_(const ProvideNode* op) final {
- Expr new_value = this->VisitExpr(op->value);
+ PrimExpr new_value = this->VisitExpr(op->value);
int lane = new_value.dtype().lanes();
- Array<Expr> new_args = MutateArray(op->args, &lane);
+ Array<PrimExpr> new_args = MutateArray(op->args, &lane);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return GetRef<Stmt>(op);
} else {
}
// Store
Stmt VisitStmt_(const StoreNode* op) final {
- Expr value = this->VisitExpr(op->value);
- Expr index = this->VisitExpr(op->index);
- Expr pred = this->VisitExpr(op->predicate);
+ PrimExpr value = this->VisitExpr(op->value);
+ PrimExpr index = this->VisitExpr(op->index);
+ PrimExpr pred = this->VisitExpr(op->predicate);
if (value.same_as(op->value) && index.same_as(op->index)) {
return GetRef<Stmt>(op);
} else {
}
CHECK(is_zero(op->min));
CHECK(!op->extent.dtype().is_vector());
- Expr extent = this->VisitExpr(op->extent);
+ PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode* op) final {
CHECK(!op->condition.dtype().is_vector());
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
LOG(WARNING) << "Cannot vectorize with new expr";
return Scalarize(GetRef<Stmt>(op));
}
- Expr condition = this->VisitExpr(op->condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
return Scalarize(GetRef<Stmt>(op));
}
- Array<Expr> extents;
+ Array<PrimExpr> extents;
for (size_t i = 0; i < op->extents.size(); i++) {
- Expr new_ext = this->VisitExpr(op->extents[i]);
+ PrimExpr new_ext = this->VisitExpr(op->extents[i]);
if (new_ext.dtype().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
return Scalarize(GetRef<Stmt>(op));
// scalarize the statment
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
- Map<Var, Expr> values{{var_, idx}};
+ Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
// the lanes.
int var_lanes_;
// ramp representing the var.
- Expr ramp_;
+ PrimExpr ramp_;
// flag to mark requirment of scalarization.
bool need_scalarize_{false};
// The lets
- std::unordered_map<const VarNode*, Expr> lets_;
+ std::unordered_map<const VarNode*, PrimExpr> lets_;
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
- Array<Expr> MutateArray(Array<Expr> arr, int* p_lanes) {
+ Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
if (arr.size() == 0) return arr;
int& lanes = *p_lanes;
bool changed = false;
- std::vector<Expr> new_arr(arr.size());
+ std::vector<PrimExpr> new_arr(arr.size());
for (size_t i = 0; i < arr.size(); i++) {
- Expr old_elem = arr[i];
- Expr new_elem = this->VisitExpr(old_elem);
+ PrimExpr old_elem = arr[i];
+ PrimExpr new_elem = this->VisitExpr(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
lanes = std::max(lanes, new_elem.dtype().lanes());
}
}
if (!changed) return arr;
- return Array<Expr>(new_arr);
+ return Array<PrimExpr>(new_arr);
}
template<typename T>
- Expr BinaryVec(const T* op) {
- Expr a = this->VisitExpr(op->a);
- Expr b = this->VisitExpr(op->b);
+ PrimExpr BinaryVec(const T* op) {
+ PrimExpr a = this->VisitExpr(op->a);
+ PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
template<typename T>
- Expr AddSubVec(const T* op) {
- Expr a = this->VisitExpr(op->a);
- Expr b = this->VisitExpr(op->b);
+ PrimExpr AddSubVec(const T* op) {
+ PrimExpr a = this->VisitExpr(op->a);
+ PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
} else {
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
visited_shared_buffers_.insert(op->node.as<tvm::VarNode>());
}
} else if (op->attr_key == attr::thread_extent) {
- VarExpr var = op->node.as<tvm::IterVarNode>()->var;
+ Var var = op->node.as<tvm::IterVarNode>()->var;
const auto *extent = op->value.as<IntImmNode>();
CHECK(extent);
};
bool VerifyGPUCode(Stmt stmt,
- Map<std::string, Expr> constraints) {
+ Map<std::string, PrimExpr> constraints) {
GPUCodeVerifier verifier;
int64_t max_local_memory_per_block = INT64_MAX;
protected:
/// Visitor implementation
//@{
- void VisitExpr(const Expr &n) final {
+ void VisitExpr(const PrimExpr &n) final {
if (Failed()) return;
StmtExprVisitor::VisitExpr(n);
}
}
/// Handle memory access to a Variable
- void HandleLoadStoreToVariable(const VarExpr &var) {
+ void HandleLoadStoreToVariable(const Var &var) {
// We skip the access within thread env.
if (InThreadEnv()) return;
//@}
LoweredFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
- std::unordered_map<const VarNode *, Expr> defs_; ///< Variable definitions
+ std::unordered_map<const VarNode *, PrimExpr> defs_; ///< Variable definitions
};
} // namespace
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
- auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
+ auto names = CallFunc<Array<tvm::PrimExpr> >("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<ir::StringImmNode>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
*
* \return Array<StringImm> names of params
*/
- Array<tvm::Expr> ListParamNames() {
- Array<tvm::Expr> ret;
+ Array<tvm::PrimExpr> ListParamNames() {
+ Array<tvm::PrimExpr> ret;
for (const auto& kv : params_) {
ret.push_back(ir::StringImmNode::make(kv.first));
}
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
- return tvm::Expr();
+ return tvm::PrimExpr();
}
}, "compile_engine_const", topi::kBroadcast);
scalars_.push_back(value->op);
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
- return tvm::Expr();
+ return tvm::PrimExpr();
}
}, "data_const", topi::kBroadcast);
scalars_.push_back(value);
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- Array<tvm::Expr> ret;
+ Array<tvm::PrimExpr> ret;
for (const auto &kv : this->output_.params) {
- tvm::Expr name = ir::StringImmNode::make(kv.first);
+ tvm::PrimExpr name = ir::StringImmNode::make(kv.first);
ret.push_back(name);
}
*rv = ret;
Pass LambdaLift();
Pass InlinePrimitives();
-Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);
+Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
Pass ManifestAlloc(Target target_host) {
auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
- Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
+ Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
* \return The module with dead functions removed.
*/
Module RemoveUnusedFunctions(const Module& module,
- Array<tvm::Expr> entry_funcs) {
+ Array<tvm::PrimExpr> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
auto* str_name = entry.as<ir::StringImmNode>();
namespace transform {
-Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions) {
+Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
[=](Module m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
TensorType ConstantNode::tensor_type() const {
auto dtype = DataType(data->dtype);
- Array<tvm::Expr> shape;
+ Array<tvm::PrimExpr> shape;
for (int i = 0; i < data->ndim; i++) {
CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
using AttrsHashHandler::VisitAttr_;
size_t VisitAttr_(const tvm::VarNode* var) final {
size_t hash = std::hash<std::string>()(VarNode::_type_key);
- auto it = hash_map_.find(GetRef<VarExpr>(var));
+ auto it = hash_map_.find(GetRef<tvm::Var>(var));
if (it != hash_map_.end()) {
return it->second;
}
// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed([]() {
- Array<tvm::Expr> ret;
+ Array<tvm::PrimExpr> ret;
for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) {
- ret.push_back(tvm::Expr(name));
+ ret.push_back(tvm::PrimExpr(name));
}
return ret;
});
CHECK(static_cast<int>(data->shape.size()) != 0);
CHECK(param->units.defined());
- Array<tvm::Expr> oshape = data->shape;
+ Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set((oshape.size() - 1), param->units);
DataType out_dtype = param->out_dtype;
}
reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);
- Array<tvm::Expr> oshape = buffer->shape;
+ Array<tvm::PrimExpr> oshape = buffer->shape;
reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
return true;
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;
- Array<tvm::Expr> oshape = x->shape;
+ Array<tvm::PrimExpr> oshape = x->shape;
oshape.Set(2, y->shape[1]);
// assign output type
CHECK(static_cast<int>(data->shape.size()) != 0);
- Array<tvm::Expr> oshape = data->shape;
+ Array<tvm::PrimExpr> oshape = data->shape;
if (param->units.defined()) {
- Array<tvm::Expr> dshape = data->shape;
+ Array<tvm::PrimExpr> dshape = data->shape;
// validate the weight shape is proper if defined
// Assign weight type
Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
oshape.Set((oshape.size() - 1), param->units);
} else {
if (weight == nullptr) return false;
- Array<tvm::Expr> wshape = weight->shape;
+ Array<tvm::PrimExpr> wshape = weight->shape;
oshape.Set((oshape.size() - 1), wshape[0]);
}
// split.
// 1) Create a map from axis to param_width using old layout.
- std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
+ std::map<std::string, tvm::Array<tvm::PrimExpr>> axis_pad_width;
int index_counter = 0;
CHECK_EQ(new_in_layouts.size(), 1);
CHECK_EQ(old_in_layouts.size(), 1);
}
// 2) Create new pad width by walking over the new layout and using the map.
- tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
+ tvm::Array<tvm::Array<tvm::PrimExpr>> new_pad_width;
for (auto iter_var : new_in_layouts[0]->axes) {
const auto& new_layout_axis = LayoutAxis::Get(iter_var);
auto axis_name = new_layout_axis.name();
tvm::DataType dtype,
std::string name = "tensor",
std::string tag = topi::kInjective) {
- tvm::Expr num_elem = tvm::Var("num_elem");
+ tvm::PrimExpr num_elem = tvm::Var("num_elem");
return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
return tvm::cast(dtype, start[0] + step[0] * indices[0]);
}, name, tag);
Tensor start = inputs[0];
Tensor stop = inputs[1];
Tensor step = inputs[2];
- Array<tvm::Expr> empty = {0};
+ Array<tvm::PrimExpr> empty = {0};
return { DynamicArange(start, stop, step, param->dtype) };
}
* \return The adjusted Layout.
*/
inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
- const Array<tvm::Expr>& old_shape) {
+ const Array<tvm::PrimExpr>& old_shape) {
// For each subordinate axis
// 1) Find the corresponding dual axis.
// 2) Find the Index of this dual axis in old_layout.
PassInfo PassInfoNode::make(int opt_level,
std::string name,
- tvm::Array<tvm::Expr> required) {
+ tvm::Array<tvm::PrimExpr> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
}
// linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
+inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
const std::string& pass_name) {
for (auto x : pass_array) {
auto* str_name = x.as<ir::StringImmNode>();
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::Expr>& required) {
+ const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return ModulePassNode::make(pass_func, pass_info);
}
const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::Expr>& required) {
+ const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return FunctionPassNode::make(pass_func, pass_info);
}
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
std::string name = args[2];
- tvm::Array<tvm::Expr> required = args[3];
+ tvm::Array<tvm::PrimExpr> required = args[3];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
*ret = Sequential(passes, pass_info);
});
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
- tvm::Array<tvm::Expr> required = args[2];
- tvm::Array<tvm::Expr> disabled = args[3];
+ tvm::Array<tvm::PrimExpr> required = args[2];
+ tvm::Array<tvm::PrimExpr> disabled = args[3];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
return ulhs;
}
- return tvm::Expr();
+ return tvm::PrimExpr();
}
Type VisitType_(const TensorTypeNode* op, const Type& tn) final {
if (!dim.defined()) {
// NB: We push an arbitrary dimension here so we can continue error propogation.
shape.push_back(tt1->shape[i]);
- tvm::Expr shape1 = tt1->shape[i];
- tvm::Expr shape2 = tt2->shape[i];
+ tvm::PrimExpr shape1 = tt1->shape[i];
+ tvm::PrimExpr shape2 = tt2->shape[i];
std::tuple<int, IndexExpr, IndexExpr> tuple = std::make_tuple(i, shape1, shape2);
mismatches.push_back(tuple);
} else {
CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale
CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point
- const Array<tvm::Expr> oshape = data->shape;
+ const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type, output will always be float 32.
reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
return true;
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale
AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point
- const Array<tvm::Expr> oshape = data->shape;
+ const Array<tvm::PrimExpr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(32))
CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale
CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point
- const Array<tvm::Expr> oshape = data->shape;
+ const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type
auto out_dtype = requantize_attrs->out_dtype;
CHECK(out_dtype == DataType::Int(8) ||
attrs.operator->(), input_shape, out_dtype);
}
-static inline int64_t get_const_int(const tvm::Expr& x) {
+static inline int64_t get_const_int(const tvm::PrimExpr& x) {
auto* value_ptr = as_const_int(x);
CHECK(value_ptr) << "Expr is not a constant int";
return value_ptr[0];
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
- void VisitExpr(const Expr& e) final {
+ void VisitExpr(const PrimExpr& e) final {
if (!is_elem_wise_) return;
ExprVisitor::VisitExpr(e);
}
void VisitExpr_(const CallNode* op) final {
- Array<Expr> axis = op->args;
+ Array<PrimExpr> axis = op->args;
if (axis_.size() != axis.size()) {
is_elem_wise_ = false;
return;
return GetSubGraph(scan->update, inputs, false);
}
-Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
const ScanOpNode* scan = scan_op.as<ScanOpNode>();
Array<Operation> body = ScanGetBody(scan_op);
}
}
ReachGraph reach;
- Map<IterVar, Expr> ret;
+ Map<IterVar, PrimExpr> ret;
std::unordered_set<TensorDimKey> place_holder_ref;
for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \param scan The scan node.
* \return Map of spatial_axis -> IntImm
*/
-Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
} // namespace schedule
} // namespace tvm
std::unordered_map<IterVar, Range>* p_state,
arith::Analyzer* actx,
bool allow_missing) {
- auto ceil_div = [actx](Expr a, Expr b) {
+ auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
if (actx->CanProve(indexmod(a, b) == 0)) {
return actx->Simplify(indexdiv(a, b));
}
void PassUpIndex(const Stage& stage,
const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, Expr>* p_state,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
bool allow_missing) {
auto& state = *p_state;
for (size_t i = stage->relations.size(); i != 0; --i) {
CHECK(allow_missing);
continue;
}
- Expr outer = state.at(s->outer);
- Expr inner = state.at(s->inner);
- Expr factor = dom_map.at(s->inner)->extent;
- Expr parent_min = dom_map.at(s->parent)->min;
+ PrimExpr outer = state.at(s->outer);
+ PrimExpr inner = state.at(s->inner);
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
state[s->parent] = inner + outer * factor;
// add min if they exist
if (!is_zero(parent_min)) {
CHECK(allow_missing);
continue;
}
- Expr value = state.at(s->fused);
- Expr factor = dom_map.at(s->inner)->extent;
- Expr outer_min = dom_map.at(s->outer)->min;
- Expr inner_min = dom_map.at(s->inner)->min;
+ PrimExpr value = state.at(s->fused);
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr outer_min = dom_map.at(s->outer)->min;
+ PrimExpr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = indexdiv(value, factor);
state[s->inner] = indexmod(value, factor);
// add min if they exist
CHECK(allow_missing);
continue;
}
- Expr value = state.at(s->rebased);
- Expr parent_min = dom_map.at(s->parent)->min;
+ PrimExpr value = state.at(s->rebased);
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = value + parent_min;
void PassDownIndex(const Stage& stage,
const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, Expr>* p_state,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
bool allow_missing) {
auto& state = *p_state;
for (IterVarRelation rel : stage->relations) {
}
Range r = dom_map.at(s->inner);
CHECK(is_zero(r->min));
- Expr parent = state.at(s->parent);
- Expr factor = r->extent;
+ PrimExpr parent = state.at(s->parent);
+ PrimExpr factor = r->extent;
state[s->outer] = indexdiv(parent, factor);
state[s->inner] = indexmod(parent, factor);
} else if (const FuseNode* s = rel.as<FuseNode>()) {
CHECK(allow_missing);
continue;
}
- Expr factor = dom_map.at(s->inner)->extent;
- Expr outer_min = dom_map.at(s->outer)->min;
- Expr inner_min = dom_map.at(s->inner)->min;
- Expr inner = state.at(s->inner);
- Expr outer = state.at(s->outer);
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr outer_min = dom_map.at(s->outer)->min;
+ PrimExpr inner_min = dom_map.at(s->inner)->min;
+ PrimExpr inner = state.at(s->inner);
+ PrimExpr outer = state.at(s->outer);
CHECK(is_zero(outer_min));
CHECK(is_zero(inner_min));
state[s->fused] = outer * factor + inner;
CHECK(allow_missing);
continue;
}
- Expr value = state.at(s->parent);
- Expr parent_min = dom_map.at(s->parent)->min;
+ PrimExpr value = state.at(s->parent);
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
CHECK(is_zero(parent_min));
state[s->rebased] = value;
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
- Expr factor = dom_map.at(s->inner)->extent;
- Expr parent_min = dom_map.at(s->parent)->min;
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
- Expr outer_min = dom_map.at(s->outer)->min;
- Expr inner_min = dom_map.at(s->inner)->min;
+ PrimExpr outer_min = dom_map.at(s->outer)->min;
+ PrimExpr inner_min = dom_map.at(s->inner)->min;
if (fused.is_single_point()) {
- Expr value = fused.point_value();
- Expr factor = dom_map.at(s->inner)->extent;
- Expr v_outer = indexdiv(value, factor);
- Expr v_inner = indexmod(value, factor);
+ PrimExpr value = fused.point_value();
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr v_outer = indexdiv(value, factor);
+ PrimExpr v_inner = indexmod(value, factor);
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
- Expr fused_extent = (fused.max() - fused.min() + 1);
- Expr inner_extent = dom_map.at(s->inner)->extent;
+ PrimExpr fused_extent = (fused.max() - fused.min() + 1);
+ PrimExpr inner_extent = dom_map.at(s->inner)->extent;
*outer = IntSet::interval(
outer_min + indexdiv(fused.min(), inner_extent),
outer_min + indexdiv(fused.max(), inner_extent));
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
- Expr parent_min = dom_map.at(s->parent)->min;
+ PrimExpr parent_min = dom_map.at(s->parent)->min;
*parent = arith::EvalSet(s->rebased->var + parent_min,
{{s->rebased, rebased}});
}
bool inner = state.at(s->inner);
if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
- Expr factor = dom_map.at(s->inner)->extent;
- Expr step = dom_map.at(s->outer)->extent;
+ PrimExpr factor = dom_map.at(s->inner)->extent;
+ PrimExpr step = dom_map.at(s->outer)->extent;
if (outer || inner) {
state[s->parent] = true;
} else {
}
}
-std::vector<Expr> MakeBoundCheck(
+std::vector<PrimExpr> MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Expr>& value_map,
+ const std::unordered_map<IterVar, PrimExpr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter) {
arith::Analyzer analyzer;
}
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
- std::vector<Expr> preds;
+ std::vector<PrimExpr> preds;
std::unordered_map<const VarNode*, IntSet> iset_dmap;
// setup domain map for set analysis
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
- Expr value = value_map.at(iv) - dom->min;
- Expr vmax = EvalSet(value, iset_dmap).max();
+ PrimExpr value = value_map.at(iv) - dom->min;
+ PrimExpr vmax = EvalSet(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
Range dom = dom_map.at(iv);
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
- Expr value = value_map.at(iv) - iv->dom->min;
+ PrimExpr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap);
- Expr vmin = s.min();
- Expr vmax = s.max();
+ PrimExpr vmin = s.min();
+ PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
preds.emplace_back(value >= 0);
*/
void PassUpIndex(const Stage& stage,
const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, Expr>* p_state,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
bool allow_missing = false);
/*!
*/
void PassDownIndex(const Stage& stage,
const Map<IterVar, Range>& dom_map,
- std::unordered_map<IterVar, Expr>* p_state,
+ std::unordered_map<IterVar, PrimExpr>* p_state,
bool allow_missing = false);
/*!
* \param skip_iter The set of variables to skip bound condition.
* \return List of predicates that we need to check.
*/
-std::vector<Expr>
+std::vector<PrimExpr>
MakeBoundCheck(
const Stage& stage,
const Map<IterVar, Range>& dom_map,
- const std::unordered_map<IterVar, Expr>& value_map,
+ const std::unordered_map<IterVar, PrimExpr>& value_map,
bool skip_ivar_domain,
const std::unordered_set<IterVar>& skip_iter);
class VarReplacer : public ir::StmtExprMutator {
public:
explicit VarReplacer(
- const std::unordered_map<const VarNode*, Expr>& vsub)
+ const std::unordered_map<const VarNode*, PrimExpr>& vsub)
: vsub_(vsub) {}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = vsub_.find(op);
if (it != vsub_.end()) return it->second;
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
// Replace free variables in combiner
- auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
+ auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
return this->VisitExpr(e);
});
- auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
+ auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
return this->VisitExpr(e);
});
}
}
- Expr VisitExpr_(const ir::ReduceNode* op) final {
- Expr new_e = StmtExprMutator::VisitExpr_(op);
+ PrimExpr VisitExpr_(const ir::ReduceNode* op) final {
+ PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
if (op->combiner.same_as(new_combiner)) {
}
private:
- const std::unordered_map<const VarNode*, Expr>& vsub_;
+ const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
};
-Expr InjectPredicate(const Array<Expr>& predicates,
- Expr body) {
+PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
+ PrimExpr body) {
using ir::ReduceNode;
using ir::SelectNode;
if (predicates.size() == 0) return body;
const ReduceNode* reduce = body.as<ReduceNode>();
if (reduce) {
auto n = make_object<ReduceNode>(*reduce);
- n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, Expr());
- return Expr(n);
+ n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr());
+ return PrimExpr(n);
}
- return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, Expr()),
+ return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()),
body,
make_zero(body.dtype()));
}
Stage s = operator[](tensor->op);
Tensor sugar_tensor = s->op.output(tensor->value_index);
Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
- return sugar_tensor(Array<Expr>(i.begin(), i.end()));
+ return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
}, os.str());
vsub[sugar_tensor] = cache;
std::unordered_set<IterVar>* p_red_axis,
Array<IterVar>* p_new_axis,
std::unordered_map<IterVar, Range>* p_dom_map,
- std::unordered_map<const VarNode*, Expr>* p_vsub,
- std::unordered_map<const VarNode*, Expr>* p_vsub2newvar,
- std::vector<Expr>* p_predicates) {
+ std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
+ std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
+ std::vector<PrimExpr>* p_predicates) {
auto& red_axis = *p_red_axis;
auto& new_axis = *p_new_axis;
auto& dom_map = *p_dom_map;
schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
{
// The source->cache
- std::unordered_map<IterVar, Expr> value_map;
+ std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : orig_stage->leaf_iter_vars) {
if (red_axis.count(iv)) continue;
CHECK_EQ(iv->iter_type, kDataPar)
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
- std::unordered_map<const VarNode*, Expr> vsub;
- std::unordered_map<const VarNode*, Expr> vsub2newvar;
- std::vector<Expr> predicates;
+ std::unordered_map<const VarNode*, PrimExpr> vsub;
+ std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+ std::vector<PrimExpr> predicates;
PrepareAxisMapping(orig_stage, compute,
&red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
- Expr body;
- Array<Expr> body_list;
+ PrimExpr body;
+ Array<PrimExpr> body_list;
const ir::ReduceNode* first_reduce = nullptr;
for (auto cbody : compute->body) {
body = VarReplacer(vsub)(cbody);
body_list.push_back(body);
}
// The reader args
- Array<Expr> args;
+ Array<PrimExpr> args;
{
// cache->compute
- std::unordered_map<IterVar, Expr> value_map;
+ std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute->axis) {
value_map[iv] = iv->var;
}
compute->name + "." + scope, compute->tag, compute->attrs,
new_axis, body_list);
- Array<Expr> cache_expr_list;
+ Array<PrimExpr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
Array<IterVar> new_axis;
std::unordered_map<IterVar, Range> dom_map;
- std::unordered_map<const VarNode*, Expr> vsub;
- std::unordered_map<const VarNode*, Expr> vsub2newvar;
- std::vector<Expr> predicates;
+ std::unordered_map<const VarNode*, PrimExpr> vsub;
+ std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+ std::vector<PrimExpr> predicates;
PrepareAxisMapping(orig_stage, tensor_op,
&red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
for (Region old_region : tensor_op->input_regions) {
Region region;
for (Range r : old_region) {
- Expr min = VarReplacer(vsub2newvar)(r->min);
- Expr extent = VarReplacer(vsub2newvar)(r->extent);
+ PrimExpr min = VarReplacer(vsub2newvar)(r->min);
+ PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
region.push_back(Range::make_by_min_extent(min, extent));
}
new_regions.push_back(region);
}
- Array<Expr> new_scalar_inputs;
- for (Expr old_input : tensor_op->scalar_inputs) {
+ Array<PrimExpr> new_scalar_inputs;
+ for (PrimExpr old_input : tensor_op->scalar_inputs) {
new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
}
}
// The reader args
- Array<Expr> args;
+ Array<PrimExpr> args;
{
// cache->compute
- std::unordered_map<IterVar, Expr> value_map;
+ std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute_axis) {
value_map[iv] = iv->var;
}
}
}
- Array<Expr> cache_expr_list;
+ Array<PrimExpr> cache_expr_list;
for (size_t i = 0; i < tensor_size; i++) {
Tensor cache_tensor = cache_op.output(i);
cache_expr_list.push_back(cache_tensor(args));
void InjectInline(ScheduleNode* sch) {
sch->InvalidateCache();
- std::vector<Array<Expr> > new_body(sch->stages.size());
+ std::vector<Array<PrimExpr> > new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false);
std::vector<Stmt> new_hybrid_body(sch->stages.size());
std::vector<bool> hybrid_changed(sch->stages.size(), false);
if (stage->attach_type == kInline) {
stage->attach_type = kInlinedAlready;
Array<Var> args;
- Expr body;
+ PrimExpr body;
{
// setup args
const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
- Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
+ PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
stage->op, args, body).as<ir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
auto n = make_object<ir::ReduceNode>(*r);
n->value_index = static_cast<int>(k);
n->dtype = r->source[k].dtype();
- new_body[j].Set(k, Expr(n));
+ new_body[j].Set(k, PrimExpr(n));
}
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
- Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
+ PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
stage->op, args, body).as<ir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
arith::Analyzer analyzer;
// Get the replace index
std::unordered_map<IterVar, Range> dom_map;
- std::unordered_map<IterVar, Expr> value_map;
+ std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute_op->reduce_axis) {
if (touch_map.count(iv)) {
dom_map[iv] = iv->dom;
}
}
schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
- std::vector<Expr> predicates = schedule::MakeBoundCheck(
+ std::vector<PrimExpr> predicates = schedule::MakeBoundCheck(
reduce_stage, dom_map, value_map, true, skip_bound_check);
// Get the factored op node.
const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition);
- Expr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, Expr()));
+ PrimExpr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()));
- std::unordered_map<const VarNode*, Expr> vsub;
+ std::unordered_map<const VarNode*, PrimExpr> vsub;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
n->reduce_axis.push_back(iv);
} else {
CHECK(value_map.count(iv));
- Expr index = value_map.at(iv);
+ PrimExpr index = value_map.at(iv);
vsub[iv->var.get()] = index;
}
}
}
}
VarReplacer replacer(vsub);
- Array<Expr> new_source = ir::UpdateArray(reduce->source,
- [&replacer] (const Expr& e) { return replacer(e); });
+ Array<PrimExpr> new_source = ir::UpdateArray(reduce->source,
+ [&replacer] (const PrimExpr& e) { return replacer(e); });
- Expr new_pred = replacer(predicate);
+ PrimExpr new_pred = replacer(predicate);
- std::vector<Expr> body;
+ std::vector<PrimExpr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
body.emplace_back(ReduceNode::make(reduce->combiner,
new_source,
new_pred,
idx));
}
- n->body = Array<Expr>(body);
+ n->body = Array<PrimExpr>(body);
// refresh relations, keep the un-touched relations.
Array<IterVarRelation> rels;
for (IterVarRelation rel : reduce_stage->relations) {
}
Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
[&](const Array<Var>& i) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
const int idx_size = static_cast<int>(i.size());
for (int idx = 0; idx < idx_size; ++idx) {
if (factor_axis_pos == idx) {
if (factor_axis_pos == idx_size) {
indices.push_back(repl_red_axis->var);
}
- Array<Expr> factor_exprs;
+ Array<PrimExpr> factor_exprs;
for (int idx = 0; idx < size; ++idx) {
factor_exprs.push_back(factor_tensors[idx](indices));
}
- Array<Expr> reductions;
+ Array<PrimExpr> reductions;
Array<IterVar> axis = {repl_red_axis};
- Expr cond = const_true();
+ PrimExpr cond = const_true();
for (int idx = 0; idx < size; ++idx) {
reductions.push_back(ReduceNode::make(reduce->combiner,
factor_exprs, axis, cond, idx));
void Split(StageNode* self,
IterVar parent,
- Expr factor,
- Expr nparts,
+ PrimExpr factor,
+ PrimExpr nparts,
IterVar* p_outer,
IterVar* p_inner) {
// Check if split is valid.
return *this;
}
-Stage& Stage::set_store_predicate(Expr predicate) {
+Stage& Stage::set_store_predicate(PrimExpr predicate) {
StageNode* self = operator->();
self->store_predicate = predicate;
return *this;
}
Stage& Stage::split(
- IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
- Split(operator->(), parent, factor, Expr(), p_outer, p_inner);
+ IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+ Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
return *this;
}
Stage& Stage::split_by_nparts(
- IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
- Split(operator->(), parent, Expr(), nparts, p_outer, p_inner);
+ IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+ Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
return *this;
}
}
Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
- Expr x_factor, Expr y_factor,
+ PrimExpr x_factor, PrimExpr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner) {
split(x_parent, x_factor, p_x_outer, p_x_inner);
Stage& Stage::pragma(IterVar var,
const std::string& pragma_type,
- const Expr& pragma_value) { // NOLINT(*)
+ const PrimExpr& pragma_value) { // NOLINT(*)
if (pragma_type == "unroll") {
this->unroll(var);
} else if (pragma_type == "vectorize") {
return *this;
}
-Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
+Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
StageNode *self = operator->();
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
IterVarRelation SplitNode::make(IterVar parent,
IterVar outer,
IterVar inner,
- Expr factor,
- Expr nparts) {
+ PrimExpr factor,
+ PrimExpr nparts) {
auto n = make_object<SplitNode>();
n->parent = parent;
n->outer = outer;
}
}
- Expr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const CallNode* op) final {
if (op->call_type == CallNode::Halide) {
TensorKey key{op->func, op->value_index};
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
- Expr ret = CallNode::make(
+ PrimExpr ret = CallNode::make(
op->dtype, dst->op->name, op->args,
op->call_type, dst->op, dst->value_index);
return this->VisitExpr(ret);
return StmtExprMutator::VisitExpr_(op);
}
- Expr VisitExpr_(const VarNode* op) final {
+ PrimExpr VisitExpr_(const VarNode* op) final {
auto it = var_value_.find(op);
if (it != var_value_.end()) {
return it->second;
} else {
- return GetRef<Expr>(op);
+ return GetRef<PrimExpr>(op);
}
}
replace_op_[src->op.get()] = repl_op;
}
// The thread extent scope.
- std::unordered_map<const Object*, Expr> thread_extent_scope_;
+ std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
// The scan value
- std::unordered_map<const VarNode*, Expr> var_value_;
+ std::unordered_map<const VarNode*, PrimExpr> var_value_;
// buffer replacement
std::unordered_map<TensorKey, Tensor> replace_buffer_;
// buffere realization to be replaced
struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
- Expr expr;
+ PrimExpr expr;
double learning_rate;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") {
LOG(FATAL) << "bad";
} catch (const tvm::AttrError& e) {
std::string what = e.what();
- CHECK(what.find("expr : Expr, default=1") != std::string::npos);
+ CHECK(what.find("expr : PrimExpr, default=1") != std::string::npos);
CHECK(what.find("axisx") != std::string::npos);
}
- n->InitBySeq("learning_rate", Expr(1), "expr", 128, "name", "xx");
+ n->InitBySeq("learning_rate", PrimExpr(1), "expr", 128, "name", "xx");
CHECK_EQ(n->learning_rate, 1.0);
n->InitBySeq("name", "xxx", "expr", 128);
std::ostringstream os;
n->PrintDocString(os);
LOG(INFO) << "docstring\n"<< os.str();
- CHECK(os.str().find("expr : Expr, default=1") != std::string::npos);
+ CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos);
}
TEST(BuildModule, Basic) {
using namespace tvm;
auto n = var("n");
- Array<Expr> shape;
+ Array<PrimExpr> shape;
shape.push_back(n);
auto A = placeholder(shape, DataType::Float(32), "A");
auto B = placeholder(shape, DataType::Float(32), "B");
- auto C = compute(A->shape, [&A, &B](Expr i) {
+ auto C = compute(A->shape, [&A, &B](PrimExpr i) {
return A[i] + B[i];
}, "C");
// The shape of input tensors.
const int n = 4;
- Array<Expr> shape{n};
+ Array<PrimExpr> shape{n};
auto A = placeholder(shape, DataType::Float(32), "A");
auto B = placeholder(shape, DataType::Float(32), "B");
auto C = placeholder(shape, DataType::Float(32), "C");
- auto elemwise_add = compute(A->shape, [&A, &B](Expr i) {
+ auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) {
return A[i] + B[i];
}, "elemwise_add");
auto copy = placeholder(shape, DataType::Float(32), "__copy");
- auto elemwise_sub = compute(C->shape, [©, &C](Expr i) {
+ auto elemwise_sub = compute(C->shape, [©, &C](PrimExpr i) {
return copy[i] - C[i];
}, "elemwise_sub");
ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), "");
}
-TEST(Array, Expr) {
+TEST(Array, PrimExpr) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
- Array<Expr> list{x, z, z};
+ Array<PrimExpr> list{x, z, z};
LOG(INFO) << list.size();
LOG(INFO) << list[0];
LOG(INFO) << list[1];
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
- Array<Expr> list{x, z, z};
+ Array<PrimExpr> list{x, z, z};
auto list2 = list;
list.Set(1, x);
CHECK(list[1].same_as(x));
TEST(Array, Iterator) {
using namespace tvm;
- Array<Expr> array{1, 2, 3};
- std::vector<Expr> vector(array.begin(), array.end());
+ Array<PrimExpr> array{1, 2, 3};
+ std::vector<PrimExpr> vector(array.begin(), array.end());
CHECK(vector[1].as<IntImmNode>()->value == 2);
}
Var x("x");
auto z = max(x + 1 + 2, 100);
auto zz = z + 1;
- Map<Expr, Expr> dict{{x, z}, {z, 2}};
+ Map<PrimExpr, PrimExpr> dict{{x, z}, {z, 2}};
CHECK(dict.size() == 2);
CHECK(dict[x].same_as(z));
CHECK(dict.count(z));
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
- Map<std::string, Expr> dict{{"x", z}, {"z", 2}};
+ Map<std::string, PrimExpr> dict{{"x", z}, {"z", 2}};
CHECK(dict.size() == 2);
CHECK(dict["x"].same_as(z));
}
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
- Map<Expr, Expr> dict{{x, z}, {z, 2}};
+ Map<PrimExpr, PrimExpr> dict{{x, z}, {z, 2}};
auto zz = z + 1;
CHECK(dict[x].same_as(z));
dict.Set(x, zz);
TEST(Map, Iterator) {
using namespace tvm;
- Expr a = 1, b = 2;
- Map<Expr, Expr> map1{{a, b}};
- std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>
+ PrimExpr a = 1, b = 2;
+ Map<PrimExpr, PrimExpr> map1{{a, b}};
+ std::unordered_map<PrimExpr, PrimExpr, ObjectHash, ObjectEqual>
map2(map1.begin(), map1.end());
CHECK(map2[a].as<IntImmNode>()->value == 2);
}
Var x("x");
auto z = max(x + 1 + 2, 100);
ObjectRef tmp = z;
- Expr zz = Downcast<Expr>(tmp);
+ PrimExpr zz = Downcast<PrimExpr>(tmp);
std::ostringstream os;
os << z;
CHECK(zz.same_as(z));
TEST(ExprNodeRef, Basic) {
using namespace tvm;
Var x("x");
- Expr z = max(x + 1 + 2, 100);
+ PrimExpr z = max(x + 1 + 2, 100);
const ir::MaxNode* op = z.as<ir::MaxNode>();
CHECK(GetRef<ObjectRef>(op).same_as(z));
}
auto z = x + 1;
class MyExprFunctor
- : public ir::ExprFunctor<int(const Expr&, int)> {
+ : public ir::ExprFunctor<int(const PrimExpr&, int)> {
public:
int VisitExpr_(const VarNode* op, int b) final {
return b;
auto z = x + 1;
class MyVisitor
- : public ir::ExprFunctor<void(const Expr&)>,
+ : public ir::ExprFunctor<void(const PrimExpr&)>,
public ir::StmtFunctor<void(const Stmt&)> {
public:
int count = 0;
protected:
// implementation
- Expr VisitExpr_(const AddNode* op) final {
+ PrimExpr VisitExpr_(const AddNode* op) final {
return op->a;
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
return StmtMutator::VisitSeqStmt_(op, true);
}
- Expr VisitExpr(const Expr& expr) final {
+ PrimExpr VisitExpr(const PrimExpr& expr) final {
return ExprMutator::VisitExpr(expr);
}
};
using namespace tvm;
using namespace tvm::ir;
Var x("x"), y;
- Expr let = LetNode::make(x, 1, x + 1);
+ PrimExpr let = LetNode::make(x, 1, x + 1);
auto z = EvaluateNode::make(let + let);
CHECK(!ir::VerifySSA(z));
using namespace tvm::runtime;
// automatic conversion of int to expr
PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
- Expr x = args[0];
+ PrimExpr x = args[0];
*rv = x.as<tvm::ir::IntImmNode>()->value + 1;
});
int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
// Check convert back
CHECK(rv.operator NDArray().same_as(x));
CHECK(rv.operator ObjectRef().same_as(x));
- CHECK(!rv.IsObjectRef<Expr>());
+ CHECK(!rv.IsObjectRef<PrimExpr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kNDArrayContainer);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
CHECK(args[1].operator Array<NDArray>().get() == nullptr);
- CHECK(!args[0].IsObjectRef<Expr>());
+ CHECK(!args[0].IsObjectRef<PrimExpr>());
});
pf1(x, ObjectRef());
pf1(ObjectRef(x), NDArray());
CHECK(args[1].operator ObjectRef().get() == nullptr);
CHECK(args[1].operator NDArray().get() == nullptr);
CHECK(args[1].operator Module().get() == nullptr);
- CHECK(!args[0].IsObjectRef<Expr>());
+ CHECK(!args[0].IsObjectRef<PrimExpr>());
});
pf2(m, ObjectRef());
pf2(ObjectRef(m), Module());
using namespace tvm;
using namespace tvm::arith;
Var x("x"), y("y"), z("z");
- arith::PVar<Expr> px, py, pz;
+ arith::PVar<PrimExpr> px, py, pz;
arith::PVar<DataType> pt;
arith::PVar<int> planes;
CHECK((px + min(py, px)).Match(z + min(y, z)));
CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
- CHECK((px - floormod(py, px * PConst<Expr>(2))).Match(x - floormod(2, x * 2)));
+ CHECK((px - floormod(py, px * PConst<PrimExpr>(2))).Match(x - floormod(2, x * 2)));
// logicals
CHECK((px == pz).Match(x == 1));
}
// ramp pattern
{
- CHECK(ramp(px, PConst<Expr>(1), planes).Match(
+ CHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(
ir::RampNode::make(x, 1, 10)));
CHECK(planes.Eval() == 10);
- CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
+ CHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(
ir::RampNode::make(x, 2, 10)));
}
// broadcast pattern
TEST(SimplePasses, HasSideEffect) {
using namespace tvm;
auto n = var("n");
- Array<Expr> shape;
+ Array<PrimExpr> shape;
shape.push_back(n);
auto A = placeholder(shape, DataType::Float(32), "A");
B1 = B[0]
B2 = B[0,0]
- assert isinstance(k + n, tvm.expr.Expr)
- assert isinstance(n + n, tvm.expr.Expr)
+ assert isinstance(k + n, tvm.expr.PrimExpr)
+ assert isinstance(n + n, tvm.expr.PrimExpr)
assert isinstance(k + A, tvm.tensor.Tensor)
assert isinstance(A + k, tvm.tensor.Tensor)
assert isinstance(n + A, tvm.tensor.Tensor)
assert (B + A).op.tag == topi.tag.BROADCAST
assert (B + B).op.tag == topi.tag.BROADCAST
- assert isinstance(k + B2, tvm.expr.Expr)
- assert isinstance(B2 + k, tvm.expr.Expr)
- assert isinstance(n + B2, tvm.expr.Expr)
- assert isinstance(B2 + n, tvm.expr.Expr)
- assert isinstance(B2 + B2, tvm.expr.Expr)
+ assert isinstance(k + B2, tvm.expr.PrimExpr)
+ assert isinstance(B2 + k, tvm.expr.PrimExpr)
+ assert isinstance(n + B2, tvm.expr.PrimExpr)
+ assert isinstance(B2 + n, tvm.expr.PrimExpr)
+ assert isinstance(B2 + B2, tvm.expr.PrimExpr)
assert isinstance(B2 + A, tvm.tensor.Tensor)
assert isinstance(A + B2, tvm.tensor.Tensor)
assert isinstance(B2 + B, tvm.tensor.Tensor)
def lower_intrin(stmt):
"""wrapper to call transformation in stmt"""
- lower_expr = isinstance(stmt, tvm.expr.Expr)
+ lower_expr = isinstance(stmt, tvm.expr.PrimExpr)
stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm")
* \return A Tensor whose op member is a broadcast operation
*/
inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
- const tvm::Array<tvm::Expr>& output_shape,
+ const tvm::Array<tvm::PrimExpr>& output_shape,
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
CHECK_GE(output_shape.size(), t->shape.size())
return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
};
return tvm::compute(
- tvm::Array<tvm::Expr>(bh.common_shape.begin(), bh.common_shape.end()),
+ tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
l,
name,
tag);
}
#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
- inline tvm::Expr Name(const tvm::Expr& a, \
- const tvm::Expr& b) { \
+ inline tvm::PrimExpr Name(const tvm::PrimExpr& a, \
+ const tvm::PrimExpr& b) { \
ComputeRule; \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
const tvm::Tensor& B, \
std::string name = "T_" #Name, \
std::string tag = kBroadcast) { \
- auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
+ auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return detail::WithBroadcast(l, A, B, name, tag); \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
- const tvm::Expr& B, \
+ const tvm::PrimExpr& B, \
std::string name = "T_" #Name, \
std::string tag = kElementWise) { \
- auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
+ auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
return l(A(i), B); \
}, name, tag); \
} \
- inline tvm::Tensor Name(const tvm::Expr& A, \
+ inline tvm::Tensor Name(const tvm::PrimExpr& A, \
const tvm::Tensor& B, \
std::string name = "T_" #Name, \
std::string tag = kElementWise) { \
- auto l = [&](tvm::Expr a, tvm::Expr b) { ComputeRule; }; \
+ auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
return compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
return l(A, B(i)); \
}, name, tag); \
const tvm::Tensor& B) { \
return topi::OpName(A, B); \
} \
- inline tvm::Tensor Name(const tvm::Expr& A, \
+ inline tvm::Tensor Name(const tvm::PrimExpr& A, \
const tvm::Tensor& B) { \
return topi::OpName(A, B); \
} \
inline tvm::Tensor Name(const tvm::Tensor& A, \
- const tvm::Expr& B) { \
+ const tvm::PrimExpr& B) { \
return topi::OpName(A, B); \
}
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- Expr("tvm.contrib.cublas.matmul"),
+ PrimExpr("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
}
/*!
-* \brief Create an op that multiplies batch matrices
+* \brief Create an op that multiplies batch matrices
* lhs and rhs with cuBLAS
*
* \param lhs The left matrix operand
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- Expr("tvm.contrib.cublas.batch_matmul"),
+ PrimExpr("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- Expr("tvm.contrib.rocblas.matmul"),
+ PrimExpr("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
auto thread_x = tvm::thread_axis(Range(), "threadIdx.x");
s[dense].bind(tx, thread_x);
s[dense_f].compute_at(s[dense], tx);
- s[dense].set_store_predicate(static_cast<Expr>(thread_x) == 0);
- s[out].set_store_predicate(static_cast<Expr>(thread_x) == 0);
+ s[dense].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
+ s[out].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
};
std::function<void(Operation)> traverse;
}
}
- stage_real.set_store_predicate(static_cast<Expr>(thread_x) == 0);
+ stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
return sch;
}
namespace detail {
struct BroadcastHelper {
- std::deque<tvm::Expr> common_shape;
+ std::deque<tvm::PrimExpr> common_shape;
std::deque<tvm::Var> all_vars;
std::deque<tvm::Var> vars1;
std::deque<tvm::Var> vars2;
};
-inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
- const tvm::Array<tvm::Expr>& shape2) {
+inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
+ const tvm::Array<tvm::PrimExpr>& shape2) {
BroadcastHelper bh;
int s1_size = shape1.size();
int s2_size = shape2.size();
- tvm::Expr one(1);
+ tvm::PrimExpr one(1);
int i;
for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
// TODO(@icemelon9): Need to revisit this part
} else {
CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
<< " and " << shape2[s2_size - i] << " in: "
- << tvm::Array<tvm::Expr>(shape1.begin(), shape1.end())
+ << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end())
<< " and "
- << tvm::Array<tvm::Expr>(shape2.begin(), shape2.end());
+ << tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
}
}
// Remaining dimensions whether on shape1 or shape2 can always be completed
return bh;
}
-inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
+inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
const tvm::Array<tvm::Var>& ovars,
const tvm::Tensor& T,
const std::deque<tvm::Var>& my_vars,
const std::deque<tvm::Var>& all_vars) {
- tvm::Array<tvm::Expr> ivars;
+ tvm::Array<tvm::PrimExpr> ivars;
CHECK_EQ(ovars.size(), all_vars.size());
// N^2, could use a map but NBD.
size_t expected_dims = T->shape.size();
B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
};
return tvm::compute(
- tvm::Array<tvm::Expr>(bh.common_shape.begin(), bh.common_shape.end()),
+ tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
l,
name,
tag);
*
* \return true if the given expr is a constant int or uint, false otherwise.
*/
-inline bool IsConstInt(Expr expr) {
+inline bool IsConstInt(PrimExpr expr) {
return
expr->IsInstance<tvm::ir::IntImmNode>() ||
expr->IsInstance<tvm::ir::UIntImmNode>();
*
* \return The integer value.
*/
-inline int64_t GetConstInt(Expr expr) {
+inline int64_t GetConstInt(PrimExpr expr) {
if (expr->IsInstance<tvm::ir::IntImmNode>()) {
return expr.as<tvm::ir::IntImmNode>()->value;
}
*
* \return A vector of the integer values
*/
-inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& var_name) {
+inline std::vector<int> GetConstIntValues(
+ Array<PrimExpr> exprs, const std::string& var_name) {
std::vector<int> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
- CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
+ CHECK(IsConstInt(expr)) << "All elements of "
+ << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
}
return result;
*
* \return A vector of the int64_t values
*/
-inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::string& var_name) {
+inline std::vector<int64_t> GetConstInt64Values(
+ Array<PrimExpr> exprs, const std::string& var_name) {
std::vector<int64_t> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
*
* \return result True if both expressions are equal, else false
*/
-inline bool EqualCheck(Expr lhs, Expr rhs) {
+inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
bool result = tvm::ir::Equal(lhs, rhs);
if (!result) {
- Expr zero(0);
+ PrimExpr zero(0);
result = tvm::ir::Equal(tvm::ir::CanonicalSimplify(lhs-rhs), zero);
}
return result;
*
* \return The Buffer object
*/
-inline Buffer DeclExternBuffer(Array<Expr> shape,
+inline Buffer DeclExternBuffer(Array<PrimExpr> shape,
DataType dtype,
std::string name) {
auto data = var(name, DataType::Handle());
- auto elem_offset = Expr();
- return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
+ auto elem_offset = PrimExpr();
+ return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
-1, 0, kDefault);
}
* function. The function expects two arguments: an array of Buffers holding the input
* tensor values, and a pre-allocated array of Buffers to be filled with the outputs.
*/
-using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
+using FExtern = std::function<PrimExpr(Array<Buffer>, Array<Buffer>)>;
/*!
* \brief Create tensors representing the result of invoking an external function.
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
* element of out_types.
*/
-inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
+inline Array<Tensor> make_extern(const Array< Array<PrimExpr> >& out_shapes,
const std::vector<DataType>& out_types,
const Array<Tensor>& inputs,
FExtern fextern,
*
* \return An expression representing the pack operation
*/
-inline Expr pack_buffer(Buffer buf) {
+inline PrimExpr pack_buffer(Buffer buf) {
CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
auto shape = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
- Expr strides;
+ PrimExpr strides;
if (buf->strides.size() > 0) {
strides = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
} else {
strides = 0;
}
- Array<Expr> pack_args{
+ Array<PrimExpr> pack_args{
buf->data,
shape,
strides,
*
* \return An expression representing the invocation
*/
-inline Expr call_packed(Array<Expr> args) {
+inline PrimExpr call_packed(Array<PrimExpr> args) {
return tvm::ir::CallNode::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
args, tvm::ir::CallNode::CallType::Intrinsic);
}
* \return An array of 4 elements, representing padding sizes for
* each individual side. The array is in the order { top, left, bottom, right }
*/
-inline Array<Expr> GetPadTuple(Expr pad_h, Expr pad_w) {
+inline Array<PrimExpr> GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) {
pad_h *= 2;
pad_w *= 2;
*
* \return The index after flattening
*/
-inline Expr RavelIndex(Array<Expr> indices, Array<Expr> shape) {
+inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
CHECK_GT(indices.size(), 0) << "indices must not be empty";
- Expr idx;
+ PrimExpr idx;
for (size_t i = 0; i < indices.size(); ++i) {
if (i == 0) {
idx = indices[i];
*
* \return The coordinate corresponding to the 1D index
*/
-inline Array<Expr> UnravelIndex(Expr idx, Array<Expr> shape) {
- std::vector<Expr> indices;
+inline Array<PrimExpr> UnravelIndex(PrimExpr idx, Array<PrimExpr> shape) {
+ std::vector<PrimExpr> indices;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
indices.push_back(indexmod(idx, shape[i]));
*
* \return True if the input shape is empty.
*/
-inline bool is_empty_shape(const Array<Expr>& x) {
+inline bool is_empty_shape(const Array<PrimExpr>& x) {
bool is_empty = false;
for (const auto& dim : x) {
if (auto int_dim = dim.as<IntImmNode>()) {
std::string name = "T_sign",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
- Expr zero = make_zero(x->dtype);
- Expr one = make_const(x->dtype, 1);
- Expr minus_one = make_const(x->dtype, -1);
+ PrimExpr zero = make_zero(x->dtype);
+ PrimExpr one = make_const(x->dtype, 1);
+ PrimExpr minus_one = make_const(x->dtype, -1);
auto s1 = tvm::ir::SelectNode::make((x(i) < zero), minus_one, zero);
auto s2 = tvm::ir::SelectNode::make((x(i) > zero), one, s1);
return s2;
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
- Expr one = make_const(x->dtype, 1);
+ PrimExpr one = make_const(x->dtype, 1);
return one/tvm::sqrt(x(i));
}, name, tag);
}
* \return A Tensor whose op member is the clip operation
*/
inline Tensor clip(const Tensor& x,
- const Expr& a_min,
- const Expr& a_max,
+ const PrimExpr& a_min,
+ const PrimExpr& a_max,
std::string name = "T_clip",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
*
* \return A Tensor whose op member is the full operation
*/
-inline Tensor full(const Array<Expr>& shape,
+inline Tensor full(const Array<PrimExpr>& shape,
DataType dtype,
- const Expr fill_value,
+ const PrimExpr fill_value,
std::string name = "T_full",
std::string tag = kElementWise) {
- Expr ev = cast(dtype, fill_value);
+ PrimExpr ev = cast(dtype, fill_value);
if (!ev.defined()) {
LOG(ERROR) << "Can't cast fill_value to " << dtype;
}
* \return A Tensor whose op memeber is the full_like operation
*/
inline Tensor full_like(const Tensor& x,
- const Expr fill_value,
+ const PrimExpr fill_value,
std::string name = "T_full_like",
std::string tag = kElementWise) {
- Expr ev = cast(x->dtype, fill_value);
+ PrimExpr ev = cast(x->dtype, fill_value);
return compute(x->shape, [&](const Array<Var>& i) {
return ev;
}, name, tag);
*
* \return The interpolated value in the given index.
*/
-inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices,
- const Expr max_y, const Expr max_x) {
+inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
+ const PrimExpr max_y, const PrimExpr max_x) {
auto in_y = indices[2];
auto yf = tvm::floor(in_y);
auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(cast(DataType::Int(32), shape[0]));
out_shape.push_back(cast(DataType::Int(32), shape[1]));
return compute(
out_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1] * input->shape[1] / shape[0]);
idx.push_back(indices[2] * input->shape[2] / shape[1]);
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(cast(DataType::Int(32), shape[0]));
return compute(
out_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] * input->shape[2] / shape[0]);
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(cast(DataType::Int(32), shape[0]));
return compute(
out_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] * input->shape[2] / shape[0]);
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
std::string layout = "NCHW",
bool align_corners = false,
std::string name = "tensor",
* \return A Tensor resized to given shape
*/
inline Tensor resize_bilinear_nhwc(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(cast(DataType::Int(32), shape[0]));
out_shape.push_back(cast(DataType::Int(32), shape[1]));
out_shape.push_back(input->shape[3]);
- Expr cone = make_const(DataType::Int(32), 1);
+ PrimExpr cone = make_const(DataType::Int(32), 1);
auto in_height = as_const_int(input->shape[1]);
auto in_width = as_const_int(input->shape[2]);
auto out_height = as_const_int(shape[0]);
auto out_width = as_const_int(shape[1]);
- Expr y_ratio;
- Expr x_ratio;
+ PrimExpr y_ratio;
+ PrimExpr x_ratio;
if (!align_corners) {
y_ratio = make_const(DataType::Float(32), (static_cast<float>(*in_height) /
static_cast<float>(*out_width - 1)));
}
- Expr other_y = tvm::ir::Simplify(input->shape[1] - cone);
- Expr other_x = tvm::ir::Simplify(input->shape[2] - cone);
+ PrimExpr other_y = tvm::ir::Simplify(input->shape[1] - cone);
+ PrimExpr other_x = tvm::ir::Simplify(input->shape[2] - cone);
return compute(
out_shape, [&](const Array<Var>& indices) {
* \return A Tensor resized to given shape
*/
inline Tensor resize_bilinear_nchw(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(cast(DataType::Int(32), shape[0]));
out_shape.push_back(cast(DataType::Int(32), shape[1]));
- Expr cone = make_const(DataType::Int(32), 1);
+ PrimExpr cone = make_const(DataType::Int(32), 1);
auto in_height = as_const_int(input->shape[2]);
auto in_width = as_const_int(input->shape[3]);
auto out_height = as_const_int(shape[0]);
auto out_width = as_const_int(shape[1]);
- Expr y_ratio;
- Expr x_ratio;
+ PrimExpr y_ratio;
+ PrimExpr x_ratio;
if (!align_corners) {
y_ratio = make_const(DataType::Float(32), (static_cast<float>(*in_height) /
static_cast<float>(*out_width - 1)));
}
- Expr other_y = tvm::ir::Simplify(input->shape[2] - cone);
- Expr other_x = tvm::ir::Simplify(input->shape[3] - cone);
+ PrimExpr other_y = tvm::ir::Simplify(input->shape[2] - cone);
+ PrimExpr other_x = tvm::ir::Simplify(input->shape[3] - cone);
return compute(
out_shape, [&](const Array<Var>& indices) {
* \return A Tensor resized to given shape
*/
inline Tensor resize_bilinear(const Tensor& input,
- const Array<tvm::Expr>& shape,
+ const Array<tvm::PrimExpr>& shape,
std::string layout = "NCHW",
bool align_corners = false,
std::string name = "tensor",
* \return A Tensor resized to given shape
*/
inline Tensor resize(const Tensor& input,
- const Array<Expr>& shape,
+ const Array<PrimExpr>& shape,
std::string layout = "NCHW",
bool align_corners = false,
std::string mode = "BILINEAR",
namespace detail {
template <typename T>
-tvm::Expr Map(const tvm::Array<tvm::Expr>& exprs, T op) {
+tvm::PrimExpr Map(const tvm::Array<tvm::PrimExpr>& exprs, T op) {
CHECK_GE(exprs.size(), 1);
- tvm::Expr res = exprs[0];
+ tvm::PrimExpr res = exprs[0];
for (size_t i = 1; i < exprs.size(); ++i) {
res = op(res, exprs[i]);
}
*
*/
inline tvm::Tensor pad(const tvm::Tensor& t,
- const tvm::Array<tvm::Expr>& pad_before,
- tvm::Array<tvm::Expr> pad_after = tvm::Array<tvm::Expr>(),
- Expr pad_value = Expr(),
+ const tvm::Array<tvm::PrimExpr>& pad_before,
+ tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
+ PrimExpr pad_value = PrimExpr(),
std::string name = "T_pad",
std::string tag = kElementWise,
std::string pad_mode = "constant") {
}
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
- tvm::Array<tvm::Expr> output_shape;
- tvm::Array<tvm::Expr> pad_before_int32;
- tvm::Array<tvm::Expr> pad_after_int32;
+ tvm::Array<tvm::PrimExpr> output_shape;
+ tvm::Array<tvm::PrimExpr> pad_before_int32;
+ tvm::Array<tvm::PrimExpr> pad_after_int32;
for (const auto &ele : pad_before) {
pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
}
pad_value = tvm::make_const(t->dtype, 0);
}
auto l = [&](tvm::Array<tvm::Var> ovars) {
- tvm::Array<tvm::Expr> indices;
- tvm::Array<tvm::Expr> sel;
- tvm::Array<tvm::Expr> pad_idx;
+ tvm::Array<tvm::PrimExpr> indices;
+ tvm::Array<tvm::PrimExpr> sel;
+ tvm::Array<tvm::PrimExpr> pad_idx;
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before_int32.size()) {
indices.push_back(ovars[i]);
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
auto pW = I->shape[3];
- tvm::Array<tvm::Expr> output_shape{
+ tvm::Array<tvm::PrimExpr> output_shape{
I->shape[0], // B
W->shape[0], // O
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
auto T = (pad_h == 0 && pad_w == 0)
? I
- : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
+ : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
return tvm::sum(
T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw),
CHECK_EQ(4, W->shape.size());
auto pH = I->shape[2];
auto pW = I->shape[3];
- tvm::Array<tvm::Expr> output_shape{
+ tvm::Array<tvm::PrimExpr> output_shape{
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
I->shape[2], // B
auto pH = I->shape[2];
auto pW = I->shape[3];
auto pCM = W->shape[1]; // channel_multiplier
- tvm::Array<tvm::Expr> output_shape{
+ tvm::Array<tvm::PrimExpr> output_shape{
I->shape[0], // B
W->shape[1], // O
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
auto T = (pad_h == 0 && pad_w == 0)
? I
- : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
+ : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
auto pH = I->shape[1];
auto pW = I->shape[2];
auto pCM = W->shape[1]; // channel_multiplier
- tvm::Array<tvm::Expr> output_shape{
+ tvm::Array<tvm::PrimExpr> output_shape{
I->shape[0], // B
indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H
indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W
auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
auto T = (pad_h == 0 && pad_w == 0)
? I
- : pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
+ : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
CHECK_EQ(5, W->shape.size());
auto pH = I->shape[2];
auto pW = I->shape[3];
- tvm::Array<tvm::Expr> output_shape{
+ tvm::Array<tvm::PrimExpr> output_shape{
I->shape[0], // B
I->shape[1], // G
W->shape[2], // O
auto T = (pad_h == 0 && pad_w == 0)
? I
- : pad(I, {tvm::Expr(0), tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
+ : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
auto l = [&](tvm::Array<tvm::Var> args) {
tvm::Var b = args[0];
tvm::Var g = args[1];
<< "binarize_pack: axis size must be a multiple of 32";
auto n = ishape.size();
- Array<Expr> oshape;
+ Array<PrimExpr> oshape;
for (size_t i = 0; i < n; ++i) {
oshape.push_back(i == static_cast<size_t>(axis) ?
tvm::ir::Simplify(indexdiv(ishape[i], 32)) :
return tvm::compute(
oshape,
[&](const Array<Var>& indices) {
- Array<Expr> start_idx;
+ Array<PrimExpr> start_idx;
for (size_t i = 0; i < n; ++i) {
start_idx.push_back(i == static_cast<size_t>(axis) ?
indices[i] * 32 :
- static_cast<Expr>(indices[i]));
+ static_cast<PrimExpr>(indices[i]));
}
auto packed = make_const(DataType::UInt(32), 0);
for (size_t j = 0; j < 32; ++j) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
for (size_t i = 0; i < n; ++i) {
idx.push_back(i == static_cast<size_t>(axis) ?
start_idx[i] + static_cast<int>(j) :
*
* \return The logical conjunction expression
*/
-Expr all(Array<Expr> args) {
+PrimExpr all(Array<PrimExpr> args) {
CHECK_GT(args.size(), 0) << "all requires at least one argument";
- Expr ret = args[0];
+ PrimExpr ret = args[0];
for (size_t i = 1; i < args.size(); ++i) {
ret = ret && args[i];
}
* \return The output tensor.
*/
inline Tensor dilate(const Tensor& x,
- Array<Expr> strides,
+ Array<PrimExpr> strides,
std::string name = "tensor",
std::string tag = kInjective) {
auto n = x->shape.size();
<< "strides size (" << strides.size()
<< ") must match dimension of x (" << n << ")";
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::ir::Simplify(
(x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
return tvm::compute(
out_shape,
[&](const Array<Var>& indices) {
- Array<Expr> not_zero;
- Array<Expr> index_tuple;
+ Array<PrimExpr> not_zero;
+ Array<PrimExpr> index_tuple;
for (size_t i = 0; i < n; ++i) {
if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
index_tuple.push_back(indices[i]);
std::string name = "tensor",
std::string tag = kInjective) {
auto ishape = x->shape;
- Expr dim = 1;
+ PrimExpr dim = 1;
for (size_t i = 1; i < ishape.size(); ++i) {
dim = dim * ishape[i];
}
- Array<Expr> oshape({ ishape[0], dim });
+ Array<PrimExpr> oshape({ ishape[0], dim });
- std::vector<Expr> extra_shape;
+ std::vector<PrimExpr> extra_shape;
for (size_t i = 1; i < ishape.size(); ++i) {
extra_shape.push_back(ishape[i]);
}
return tvm::compute(
oshape, [&](Var i, Var j) {
- Expr idx = j;
- std::vector<Expr> index;
+ PrimExpr idx = j;
+ std::vector<PrimExpr> index;
for (auto s : extra_shape) {
index.push_back(indexmod(idx, s));
idx = indexdiv(idx, s);
CHECK_EQ(size % 2, 1) << "size should be odd number";
CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
auto input_shape = data->shape;
- Array<Expr> pad_before{ 0, 0, 0, 0};
- Array<Expr> pad_after{ 0, 0, 0, 0};
- pad_before.Set(axis, static_cast<Expr>(size/2));
- pad_after.Set(axis, static_cast<Expr>(size/2));
+ Array<PrimExpr> pad_before{ 0, 0, 0, 0};
+ Array<PrimExpr> pad_after{ 0, 0, 0, 0};
+ pad_before.Set(axis, static_cast<PrimExpr>(size/2));
+ pad_after.Set(axis, static_cast<PrimExpr>(size/2));
auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
Tensor sqr_sum;
* \return The output tensor in same layout order
*/
inline Tensor pool_impl(const Tensor& x,
- const Array<Expr>& kernel_size,
- const Array<Expr>& stride_size,
- const Array<Expr>& padding_size,
+ const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const size_t height_axis,
pad_right += stride_width - 1;
}
- Array<Expr> pad_before(std::vector<Expr>(x->shape.size(), 0));
+ Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
pad_before.Set(height_axis, pad_top);
pad_before.Set(width_axis, pad_left);
- Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
+ Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);
auto dheight = tvm::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
- Array<Expr> out_shape = x->shape;
+ Array<PrimExpr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
auto temp = do_pad ? pad(
x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
return tvm::compute(out_shape, [&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
indices.Set(height_axis, output[height_axis] * stride_height + dheight);
indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
// TVM compute for summing the pooling window.
auto pool_sum = tvm::compute(out_shape,
[&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
indices.Set(height_axis, output[height_axis] * stride_height + dheight);
indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
// TVM compute for dividing the reduced window sum by kernel size.
return tvm::compute(out_shape,
[&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
if (count_include_pad) {
return div(pool_sum(indices), (kernel_height * kernel_width));
} else {
- Expr h_start = output[height_axis] * stride_height - pad_top;
- Expr w_start = output[width_axis] * stride_width - pad_left;
- Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
- Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+ PrimExpr h_start = output[height_axis] * stride_height - pad_top;
+ PrimExpr w_start = output[width_axis] * stride_width - pad_left;
+ PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height);
+ PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width);
h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
- Expr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+ PrimExpr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
make_const(DataType::DataType::Int(32), 1));
return div(pool_sum(indices), divide_factor);
}
}
}
-inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
- const Array<Expr>& kernel_size, const Array<Expr>& stride_size,
- const Array<Expr>& padding_size, PoolType pool_type, bool ceil_mode,
+inline Tensor pool_grad_impl(const Tensor& out_grad,
+ const Tensor& x,
+ const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size,
+ PoolType pool_type, bool ceil_mode,
const size_t height_axis, const size_t width_axis,
bool count_include_pad) {
CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
pad_right += stride_width - 1;
}
- Array<Expr> pad_before(std::vector<Expr>(x->shape.size(), 0));
+ Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
pad_before.Set(height_axis, pad_top);
pad_before.Set(width_axis, pad_left);
- Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
+ Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);
auto dheight = tvm::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
- Array<Expr> out_shape = x->shape;
+ Array<PrimExpr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
if (pool_type == kMaxPool) {
- Array<Expr> ravel_shape{x->shape.begin(), x->shape.end()};
+ Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
auto mp_argmax =
tvm::compute(out_shape,
[&](const Array<Var>& inds) {
- Array<Expr> window_inds{inds.begin(), inds.end()};
+ Array<PrimExpr> window_inds{inds.begin(), inds.end()};
window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
auto idx = detail::RavelIndex(window_inds, ravel_shape);
return tvm::compute(
x->shape,
[&](const Array<Var>& inds) {
- Array<Expr> pad_inds {inds.begin(), inds.end()};
+ Array<PrimExpr> pad_inds {inds.begin(), inds.end()};
pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
auto idx = detail::RavelIndex(pad_inds, ravel_shape);
- Array<Expr> out_idx {inds.begin(), inds.end()};
+ Array<PrimExpr> out_idx {inds.begin(), inds.end()};
out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
- Expr out_idx_lower_h = ir::SelectNode::make(
+ PrimExpr out_idx_lower_h = ir::SelectNode::make(
pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
(pad_inds[height_axis] - kernel_height) / stride_height + 1);
- Expr out_idx_lower_w = ir::SelectNode::make(
+ PrimExpr out_idx_lower_w = ir::SelectNode::make(
pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
(pad_inds[width_axis] - kernel_width) / stride_width + 1);
return tvm::compute(
x->shape,
[&](const Array<Var>& inds) {
- Expr pad_h_idx = inds[height_axis] + pad_top;
- Expr pad_w_idx = inds[width_axis] + pad_left;
+ PrimExpr pad_h_idx = inds[height_axis] + pad_top;
+ PrimExpr pad_w_idx = inds[width_axis] + pad_left;
// output indices whose pooling windows cover current input element (can be out-of-bound)
- Array<Expr> out_idx{inds.begin(), inds.end()};
+ Array<PrimExpr> out_idx{inds.begin(), inds.end()};
out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
- Expr out_idx_lower_h = ir::SelectNode::make(
+ PrimExpr out_idx_lower_h = ir::SelectNode::make(
pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
(pad_h_idx - kernel_height) / stride_height + 1);
- Expr out_idx_lower_w = ir::SelectNode::make(
+ PrimExpr out_idx_lower_w = ir::SelectNode::make(
pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
(pad_w_idx - kernel_width) / stride_width + 1);
- Expr divide_factor; // number of pooled elements
+ PrimExpr divide_factor; // number of pooled elements
if (count_include_pad) {
divide_factor = kernel_height * kernel_width;
} else {
- Expr h_start = out_idx[height_axis] * stride_height - pad_top;
- Expr w_start = out_idx[width_axis] * stride_width - pad_left;
- Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
- Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+ PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
+ PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
+ PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height);
+ PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width);
h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
divide_factor =
* \return The output tensor in the same layout
*/
inline Tensor pool(const Tensor& x,
- const Array<Expr>& kernel_size,
- const Array<Expr>& stride_size,
- const Array<Expr>& padding_size,
+ const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::string& layout = "NCHW",
*
* \return The output tensor in the same layout
*/
-inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<Expr>& kernel_size,
- const Array<Expr>& stride_size, const Array<Expr>& padding_size,
+inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
bool count_include_pad = true) {
int height_axis = -1, width_axis = -1;
height_axis, width_axis, count_include_pad);
}
-inline Expr start_index(const Var& out_index,
- const Expr& odim,
- const Expr& idim) {
+inline PrimExpr start_index(const Var& out_index,
+ const PrimExpr& odim,
+ const PrimExpr& idim) {
return indexdiv(out_index * idim, odim);
}
-inline Expr end_index(const Var& out_index,
- const Expr& odim,
- const Expr& idim) {
- Expr tmp = indexdiv((out_index + 1) * idim, odim);
+inline PrimExpr end_index(const Var& out_index,
+ const PrimExpr& odim,
+ const PrimExpr& idim) {
+ PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
return tvm::ir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
tmp, tmp + 1);
}
* \return The output tensor in same layout order
*/
inline Tensor adaptive_pool_impl(const Tensor& x,
- const Array<Expr>& output_size,
+ const Array<PrimExpr>& output_size,
PoolType pool_type,
const size_t height_axis,
const size_t width_axis) {
auto out_height = cast(DataType::Int(32), output_size[0]);
auto out_width = cast(DataType::Int(32), output_size[1]);
- Array<Expr> out_shape = x->shape;
+ Array<PrimExpr> out_shape = x->shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
if (pool_type == kMaxPool) {
return tvm::compute(out_shape, [&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
auto i_end_h = end_index(output[height_axis], out_height, height);
}, "tensor", "adaptive_pool_max");
} else if (pool_type == kAvgPool) {
auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
auto i_end_h = end_index(output[height_axis], out_height, height);
}, "tensor", "adaptive_pool_sum");
return tvm::compute(out_shape, [&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
auto i_start_h = start_index(output[height_axis], out_height, height);
auto i_end_h = end_index(output[height_axis], out_height, height);
* \return The output tensor in same layout order
*/
inline Tensor adaptive_pool(const Tensor& x,
- const Array<Expr>& output_size,
+ const Array<PrimExpr>& output_size,
PoolType pool_type,
const std::string& layout = "NCHW") {
int height_axis = -1, width_axis = -1;
inline Tensor global_pool(const Tensor& x,
PoolType pool_type,
const std::string& layout = "NCHW") {
- return adaptive_pool(x, Array<Expr>{1, 1}, pool_type, layout);
+ return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
}
/*!
* \return The output tensor in same layout order
*/
inline Tensor pool_impl_nd(const Tensor& x,
- const Array<Expr>& kernel_size,
- const Array<Expr>& stride_size,
- const Array<Expr>& padding_size,
+ const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::vector<int>& axis,
CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";
Array<IterVar> daxis;
- std::vector<Expr> kernel(k_size);
- std::vector<Expr> stride(k_size);
- std::vector<Expr> pad_head(k_size);
- std::vector<Expr> pad_tail(k_size);
- Array<Expr> pad_before(std::vector<Expr>(x_size, 0));
- Array<Expr> pad_after(std::vector<Expr>(x_size, 0));
- Array<Expr> out_shape = x->shape;
+ std::vector<PrimExpr> kernel(k_size);
+ std::vector<PrimExpr> stride(k_size);
+ std::vector<PrimExpr> pad_head(k_size);
+ std::vector<PrimExpr> pad_tail(k_size);
+ Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
+ Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
+ Array<PrimExpr> out_shape = x->shape;
bool do_pad = false;
for (int i = 0; i < k_size; i++) {
auto temp = do_pad ? pad(
x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
return tvm::compute(out_shape, [&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
for (int i = 0; i < k_size; i++) {
// TVM compute for summing the pooling window.
auto pool_sum = tvm::compute(out_shape,
[&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
for (int i = 0; i < k_size; i++) {
// TVM compute for dividing the reduced window sum by kernel size.
return tvm::compute(out_shape,
[&](const Array<Var>& output) {
- Array<Expr> indices;
+ Array<PrimExpr> indices;
for (const Var& var : output) indices.push_back(var);
if (count_include_pad) {
auto kernel_size = make_const(DataType::Int(32), 1);
}
return div(pool_sum(indices), kernel_size);
} else {
- std::vector<Expr> start(k_size);
- std::vector<Expr> end(k_size);
+ std::vector<PrimExpr> start(k_size);
+ std::vector<PrimExpr> end(k_size);
auto kernel_size = make_const(DataType::Int(32), 1);
for (int i = 0; i < k_size; i++) {
int ii = axis[i];
kernel_size *= (end[i] - start[i]);
}
- Expr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
+ PrimExpr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
return div(pool_sum(indices), divide_factor);
}
}, "tensor", kElementWise);
* \return The output tensor in the same layout
*/
inline Tensor pool1d(const Tensor& x,
- const Array<Expr>& kernel_size,
- const Array<Expr>& stride_size,
- const Array<Expr>& padding_size,
+ const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::string& layout = "NCW",
* \return The output tensor in the same layout
*/
inline Tensor pool3d(const Tensor& x,
- const Array<Expr>& kernel_size,
- const Array<Expr>& stride_size,
- const Array<Expr>& padding_size,
+ const Array<PrimExpr>& kernel_size,
+ const Array<PrimExpr>& stride_size,
+ const Array<PrimExpr>& padding_size,
PoolType pool_type,
bool ceil_mode,
const std::string& layout = "NCDHW",
auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
const IterVar &reduce_index) {
- Array<Expr> eval_range;
+ Array<PrimExpr> eval_range;
int arg_counter = 0;
for (size_t i = 0; i < ndim; ++i) {
if (static_cast<int>(i) == axis)
};
auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
- Array<Expr> non_reduce_indices;
+ Array<PrimExpr> non_reduce_indices;
for (size_t i = 0; i < ndim; ++i) {
if (static_cast<int>(i) != axis)
non_reduce_indices.push_back(indices[i]);
std::string tag = "log_softmax_output") {
CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";
- Expr m = x->shape[0];
- Expr n = x->shape[1];
+ PrimExpr m = x->shape[0];
+ PrimExpr n = x->shape[1];
auto k = tvm::reduce_axis(Range(0, n), "k");
auto max_elem = tvm::compute(
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \return A Tensor upsampled to given shape
*/
inline Tensor upsampling(const Tensor& input,
- const Array<Expr> shape,
+ const Array<PrimExpr> shape,
std::string layout = "NCHW",
std::string mode = "NEAREST_NEIGHBOR",
std::string name = "tensor",
using namespace tvm;
/*! \brief The operation to use for CommReduce */
-using FReduce = std::function<Expr(Expr source, const Array<IterVar>& axis)>;
+using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
/*! \brief The operation to use for CommReduceIdx */
using FCommReduce = std::function<
- Array<Expr>(Array<Expr> exprs, const Array<IterVar>& axis, Expr* condition)>;
+ Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis, PrimExpr* condition)>;
/*!
* \brief Convert a reduction axis which could be empty or have negative
}
/*! \brief Calculate the target shape for a reduce op */
-inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
+inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis,
const Tensor& data,
bool keepdims,
bool atleast1d) {
auto ndim = data->shape.size();
- Array<Expr> target_shape;
+ Array<PrimExpr> target_shape;
if (keepdims) {
for (size_t i = 0; i < ndim; ++i) {
if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
*/
inline Tensor DoCommReduce(const Tensor& data,
FReduce func,
- const Array<Expr>& target_shape,
+ const Array<PrimExpr>& target_shape,
const std::vector<int>& reduce_axes,
const std::vector<int>& squeeze_axes) {
auto r_axes = MakeReduceAxes(reduce_axes, data);
auto compute = [&](const Array<Var>& indices) {
- Array<Expr> eval_range;
+ Array<PrimExpr> eval_range;
Array<Var> eval_indices;
int arg_counter = 0;
int red_counter = 0;
auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
(const Array<Var>& indices) {
- Array<Expr> eval_range;
- Array<Expr> eval_indices;
+ Array<PrimExpr> eval_range;
+ Array<PrimExpr> eval_indices;
int arg_counter = 0;
int red_counter = 0;
}
}
- Array<Expr> ravel_shape;
+ Array<PrimExpr> ravel_shape;
for (auto i : real_axis) {
ravel_shape.push_back(data->shape[i]);
}
}
/*! \brief A combiner function for a reduction */
-using FCombine = std::function<Array<Expr>(Array<Var> lhs, Array<Var> rhs)>;
+using FCombine = std::function<Array<PrimExpr>(Array<Var> lhs, Array<Var> rhs)>;
/*! \brief An initializer function for a reduction */
-using FIdentity = std::function<Array<Expr>(std::vector<DataType> types)>;
+using FIdentity = std::function<Array<PrimExpr>(std::vector<DataType> types)>;
/*!
* \brief Create a commutative reducer for a reduction
FIdentity fidentity,
std::string name = "reduce") {
return [fcombine, fidentity, name]
- (Array<Expr> exprs, const Array<IterVar>& axis, Expr* condition) {
+ (Array<PrimExpr> exprs, const Array<IterVar>& axis, PrimExpr* condition) {
Array<Var> lhs, rhs;
std::vector<DataType> dtypes;
auto cond = condition != nullptr ? *condition : tvm::const_true();
auto combiner = tvm::ir::CommReducerNode::make(lhs, rhs, result, id_elem);
- Array<Expr> outputs;
+ Array<PrimExpr> outputs;
for (size_t i = 0; i < exprs.size(); ++i) {
outputs.push_back(
tvm::ir::ReduceNode::make(combiner, exprs, axis, cond, static_cast<int>(i)));
}
/*! \brief Wrap tvm::min to ensure we get the correct overload */
-inline Expr MinOp(Expr source, Array<IterVar> axis) {
+inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) {
return tvm::min(source, axis);
}
/*! \brief Wrap tvm::max to ensure we get the correct overload */
-inline Expr MaxOp(Expr source, Array<IterVar> axis) {
+inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis) {
return tvm::max(source, axis); // NOLINT(*)
}
/*! \brief Wrap tvm::prod to ensure we get the correct overload */
-inline Expr ProdOp(Expr source, Array<IterVar> axis) {
+inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis) {
return tvm::prod(source, axis); // NOLINT(*)
}
return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
}
-inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
+inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
CHECK_GE(data->shape.size(), target_shape.size());
auto ishape = detail::GetConstIntValues(data->shape, "ishape");
auto oshape = detail::GetConstIntValues(target_shape, "oshape");
bool keepdims = false,
bool atleast1d = false) {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
- Array<Expr> result;
+ Array<PrimExpr> result;
result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
- Array<Expr> result;
+ Array<PrimExpr> result;
result.push_back(tvm::make_const(types[0], -1)); // idx
result.push_back(tvm::max_value(types[1])); // val
return result;
inline FCommReduce MakeArgmaxReducer() {
auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
- Array<Expr> result;
+ Array<PrimExpr> result;
result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx
result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val
return result;
};
auto fidentity = [](std::vector<DataType> types) {
- Array<Expr> result;
+ Array<PrimExpr> result;
result.push_back(tvm::make_const(types[0], -1)); // idx
result.push_back(tvm::min_value(types[1])); // val
return result;
// Calculate offset from last dimension
axis = ndim + axis + 1;
}
- Array<Expr> new_shape;
+ Array<PrimExpr> new_shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
new_shape.push_back(x->shape[i]);
}
return compute(
new_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
}
}
- Array<Expr> new_shape;
+ Array<PrimExpr> new_shape;
for (size_t i = 0; i < axes.size(); ++i) {
int axis = static_cast<int>(axes[i]->value);
int new_axis = axis;
return compute(
new_shape, [&](const Array<Var>& indices) {
- std::vector<Expr> idx;
+ std::vector<PrimExpr> idx;
for (size_t i = 0; i < axes.size(); ++i) {
idx.push_back(1);
}
// Reverse the Input Tensor in the axis specified
return compute(
x->shape, [&](const Array<Var>& indices) {
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
if (i == static_cast<size_t>(axis)) {
real_indices.push_back(x->shape[i] - indices[i] - 1);
* \return A Tensor whose op member is the reshape operation
*/
inline Tensor reshape(const Tensor& x,
- Array<Expr> newshape,
+ Array<PrimExpr> newshape,
std::string name = "T_reshape",
std::string tag = kInjective) {
auto x_shape = x->shape;
- Array<Expr> target_shape;
+ Array<PrimExpr> target_shape;
for (const auto &ele : newshape) {
if (ele.as<IntImmNode>()) {
return compute(
target_shape, [&](const Array<Var>& indices) {
return x(UnravelIndex(
- RavelIndex(Array<Expr>{indices.begin(), indices.end()}, target_shape),
+ RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape),
x_shape));
}, name, tag);
}
std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
for (size_t i = 0; i < ndim; ++i) {
if (axis_set.count(static_cast<int>(i)) == 0) {
out_shape.push_back(x->shape[i]);
return compute(
out_shape, [&](const Array<Var>& indices) {
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
int flag = 0;
for (size_t i = 0; i < ndim; ++i) {
if (axis_set.count(static_cast<int>(i)) == 0) {
CHECK_LT(axis, inputs[0]->shape.size()) <<
"axis out of bounds";
- Array<Expr> axis_sizes;
+ Array<PrimExpr> axis_sizes;
for (auto t : inputs) {
axis_sizes.push_back(t->shape[axis]);
}
- Expr join_size = axis_sizes[0];
+ PrimExpr join_size = axis_sizes[0];
for (size_t i = 1; i < axis_sizes.size(); ++i) {
join_size += axis_sizes[i];
}
join_size = tvm::ir::Simplify(join_size);
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
}
for (size_t i = 0; i < inputs.size() - 1; ++i) {
ind -= axis_sizes[i];
- Array<Expr> idx;
+ Array<PrimExpr> idx;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
"axis out of bounds";
const int stack_size = static_cast<int>(inputs.size());
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i)
out_shape.push_back(inputs[0]->shape[i]);
out_shape.push_back(stack_size);
return compute(
out_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
for (size_t i = 0; i < indices.size(); ++i)
if (i != static_cast<size_t>(axis))
idx.push_back(indices[i]);
begin_ids.push_back(val);
}
- Array< Array<Expr> > out_shapes;
+ Array< Array<PrimExpr> > out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
int out_axis_size;
if (i == begin_ids.size() - 1) {
out_axis_size = begin_ids[i + 1] - begin_ids[i];
}
- Array<Expr> shape;
+ Array<PrimExpr> shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
shape.push_back(x->shape[i]);
}
compute(
out_shapes[i], [&](const Array<Var>& indices) {
auto begin = begin_ids[i];
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(indices[j]);
}
end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
}
// Compute
- Array<Expr> out_shape;
- Array<Expr> begin_expr;
- Array<Expr> strides_expr;
+ Array<PrimExpr> out_shape;
+ Array<PrimExpr> begin_expr;
+ Array<PrimExpr> strides_expr;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
return compute(
out_shape, [&](const Array<Var>& indices) {
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
std::string mode = "clip",
std::string name = "T_take",
std::string tag = kInjective) {
- Array<Expr> a_shape = a->shape;
- Array<Expr> out_shape = indices->shape;
- Expr a_size = 1;
+ Array<PrimExpr> a_shape = a->shape;
+ Array<PrimExpr> out_shape = indices->shape;
+ PrimExpr a_size = 1;
for (size_t i = 0; i < a_shape.size(); ++i) {
a_size = a_size * a_shape[i];
}
CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
auto length_dim = data->shape[axis];
auto batch_dim = data->shape[1 - axis];
- Array<Expr> out_shape = data->shape;
+ Array<PrimExpr> out_shape = data->shape;
Tensor out = compute(
out_shape, [&](const Array<Var>& out_index) {
- Array<Expr> len_index;
+ Array<PrimExpr> len_index;
auto tid = out_index[axis];
auto bid = out_index[1 - axis];
len_index.push_back(bid);
- Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
- tvm::make_const(data->dtype, mask_value), data(out_index));
+ PrimExpr ret = tvm::if_then_else(
+ tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
+ tvm::make_const(data->dtype, mask_value), data(out_index));
return ret;
}, name, tag);
return out;
auto axis_dim = a->shape[axis];
int indices_len = static_cast<int>(indices->shape.size());
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
for (size_t i = 0; i < a->shape.size(); ++i) {
if (axis == static_cast<int>(i)) {
for (size_t j = 0; j < indices->shape.size(); ++j) {
if (mode == "clip") {
return compute(
out_shape, [&](const Array<Var>& out_index) {
- Array<Expr> indices_position;
+ Array<PrimExpr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
"Make sure input indices are in bound";
return compute(
out_shape, [&](const Array<Var>& out_index) {
- Array<Expr> indices_position;
+ Array<PrimExpr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
} else { // mode == "wrap"
return compute(
out_shape, [&](const Array<Var>& out_index) {
- Array<Expr> indices_position;
+ Array<PrimExpr> indices_position;
for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
indices_position.push_back(out_index[j]);
}
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
<< x->shape.size() << " vs " << y->shape.size();
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
<< x->dtype << " vs " << y->dtype;
- Array<Expr> oshape = x->shape;
+ Array<PrimExpr> oshape = x->shape;
Tensor out;
if (condition->shape.size() != 1) {
<< condition->shape[0] << " vs " << x->shape[0];
out = compute(
oshape, [&](const Array<Var>& indices) {
- Array<Expr> condition_idx{indices[0]};
+ Array<PrimExpr> condition_idx{indices[0]};
return tvm::ir::SelectNode::make(condition(condition_idx) != 0,
x(indices), y(indices));
}, name, tag);
// Calculate offset from last dimension
axis += ndim;
}
- Array<Expr> new_shape;
+ Array<PrimExpr> new_shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
new_shape.push_back(x->shape[i]);
}
return compute(
new_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
size_t ndim = x->shape.size();
size_t rdim = reps.size();
size_t tdim = (ndim > rdim) ? ndim : rdim;
- Array<Expr> data_shape;
- Array<Expr> reps_shape;
- Array<Expr> new_shape;
+ Array<PrimExpr> data_shape;
+ Array<PrimExpr> reps_shape;
+ Array<PrimExpr> new_shape;
if (ndim == rdim) {
for (size_t i = 0; i < ndim; ++i) {
data_shape.push_back(x->shape[i]);
} else {
return compute(
new_shape, [&](const Array<Var>& indices) {
- Array<Expr> idx;
+ Array<PrimExpr> idx;
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i)
idx.push_back(indexmod(indices[i], x->shape[i]));
size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
<< "than dimensions of data tensor";
- Array<Expr> out_shape;
+ Array<PrimExpr> out_shape;
for (size_t i = 1; i < ndim_i; ++i) {
out_shape.push_back(indices->shape[i]);
}
}
return compute(
out_shape, [&](const Array<Var>& out_index) {
- Array<Expr> indices_position;
+ Array<PrimExpr> indices_position;
indices_position.push_back(0);
for (size_t i = 0; i < ndim_i - 1; ++i) {
indices_position.push_back(out_index[i]);
}
- Array<Expr> real_indices;
+ Array<PrimExpr> real_indices;
for (size_t i = 0; i < indices_dim0; ++i) {
indices_position.Set(0, make_const(DataType::Int(32), i));
if (indices->dtype.is_int()) {
bool trans_b = false,
std::string name = "T_matmul",
std::string tag = kMatMul) {
- tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
+ tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0],
B->shape[trans_b ? 0 : 1]};
auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
auto l = [&](tvm::Var i, tvm::Var j) {
CHECK_GE(A->shape.size(), axes);
CHECK_GE(B->shape.size(), axes);
- Array<Expr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
+ Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it)
output_shape.push_back(*it);
auto func =
[&A, &B, &iter_vars, axes]
(const Array<Var>& input_indices) {
- Array<Expr> A_indices(
+ Array<PrimExpr> A_indices(
input_indices.begin(),
input_indices.begin() + (A->shape.size() - axes));
for (auto& v : iter_vars)
A_indices.push_back(v);
- Array<Expr> B_indices;
+ Array<PrimExpr> B_indices;
for (auto& v : iter_vars)
B_indices.push_back(v);
*/
inline Tensor tensordot(const Tensor& A,
const tvm::Tensor& B,
- Array<Expr> A_axes,
- Array<Expr> B_axes,
+ Array<PrimExpr> A_axes,
+ Array<PrimExpr> B_axes,
std::string name = "T_tensordot",
std::string tag = kMatMul) {
CHECK_EQ(A_axes.size(), B_axes.size());
auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
auto B_axes_val = GetConstIntValues(B_axes, "B_axes");
- Array<Expr> output_shape;
+ Array<PrimExpr> output_shape;
for (unsigned i = 0; i < A->shape.size(); ++i)
if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
output_shape.push_back(A->shape[i]);
[&A, &B, &iter_vars, A_axes_val, B_axes_val]
(const Array<Var>& input_indices) {
int idx_input = 0;
- Array<Expr> A_indices;
+ Array<PrimExpr> A_indices;
for (unsigned i = 0; i < A->shape.size(); ++i) {
auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
if (axes_pos == A_axes_val.end())
A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
}
- Array<Expr> B_indices;
+ Array<PrimExpr> B_indices;
for (unsigned i = 0; i < B->shape.size(); ++i) {
auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
if (axes_pos == B_axes_val.end())
return compute(output_shape, func, name, tag);
}
-inline Tensor arange(const Expr& start,
- const Expr& stop,
- const Expr& step,
+inline Tensor arange(const PrimExpr& start,
+ const PrimExpr& stop,
+ const PrimExpr& step,
DataType dtype,
std::string name = "T_arange",
std::string tag = kInjective) {
- Expr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil(
+ PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil(
tvm::cast(tvm::DataType::Float(32), stop - start) / step));
- Array<Expr> shape;
+ Array<PrimExpr> shape;
return compute({num_elem}, [&](const Array<Var>& indices) {
return tvm::cast(dtype, start + step * indices[0]);
}, name, tag);
CHECK(layout_converter.defined())
<< "cannot convert from " << src_layout << " to " << dst_layout;
- Array<Expr> dst_shape = layout_converter.ForwardShape(src->shape);
+ Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
- Array<Expr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
- Array<Expr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
+ Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
+ Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
return src(src_indices);
}, name, tag);
}
const std::string name = "T_shape",
const std::string tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
- Array<Expr> out_shape{ndim};
+ Array<PrimExpr> out_shape{ndim};
return compute(out_shape, [&](const Array<Var>& indices) {
auto idx = indices[0];
- Expr ret = 0;
+ PrimExpr ret = 0;
for (int i = 0; i < ndim; ++i) {
ret = tvm::if_then_else(idx == i, src->shape[i], ret);
}
const std::string& name = "ndarray_size",
const std::string& tag = kInjective) {
int ndim = static_cast<int>(src->shape.size());
- Array<Expr> out_ndarray_size = {1};
+ Array<PrimExpr> out_ndarray_size = {1};
return compute(out_ndarray_size, [&](const Array<Var>& indices) {
- Expr ret = 1;
+ PrimExpr ret = 1;
for (int i = 0; i < ndim; ++i) {
ret *= src->shape[i];
}
* \return one-hot tensor.
*/
inline Tensor one_hot(const Tensor& indices,
- const Expr on_value,
- const Expr off_value,
+ const PrimExpr on_value,
+ const PrimExpr off_value,
int depth,
int axis,
const DataType& dtype,
const std::string name = "T_one_hot",
const std::string tag = kInjective) {
- Array<Expr> oshape;
+ Array<PrimExpr> oshape;
int ndim = indices->shape.size() + 1;
int indices_index = 0;
int true_axis = (axis == -1) ? indices->shape.size() : axis;
}
}
- Expr on_value_cast = cast(dtype, on_value);
- Expr off_value_cast = cast(dtype, off_value);
+ PrimExpr on_value_cast = cast(dtype, on_value);
+ PrimExpr off_value_cast = cast(dtype, off_value);
return compute(oshape, [&](const Array<Var>& iter_vars) {
Array<Var> indices_indices;
for (size_t i = 0; i < iter_vars.size(); i++) {
int out_h = h_in / stride;
int out_w = w_in / stride;
- Array<Expr> out_shape = {batch, out_c, out_h, out_w};
+ Array<PrimExpr> out_shape = {batch, out_c, out_h, out_w};
return reshape(out, out_shape);
}
} // namespace vision
out_shape = tuple(
tvm.ir_pass.Simplify(
(data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n))
- pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr)
+ pad_value = (pad_value if isinstance(pad_value, tvm.expr.PrimExpr)
else tvm.const(pad_value, data.dtype))
def _pad(*indices):
not_zero = []
out : Expr or int
The simplified output
"""
- return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr
+ return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.PrimExpr) else expr
def ravel_index(indices, shape):
*rv = Op(args[0].operator tvm::Tensor(), \
args[1].operator tvm::Tensor()); \
} else if (!lhs_is_tensor && rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::Expr(), \
+ *rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::Tensor()); \
} else if (lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::Tensor(), \
- args[1].operator tvm::Expr()); \
+ args[1].operator tvm::PrimExpr()); \
} else if (!lhs_is_tensor && !rhs_is_tensor) { \
- *rv = Op(args[0].operator tvm::Expr(), \
- args[1].operator tvm::Expr()); \
+ *rv = Op(args[0].operator tvm::PrimExpr(), \
+ args[1].operator tvm::PrimExpr()); \
} \
}); \
} else if (args.size() == 3) {
*rv = tensordot(args[0], args[1], args[2]);
} else {
- Array<Expr> axes = args[3];
+ Array<PrimExpr> axes = args[3];
*rv = tensordot(args[0], args[1], args[2], axes);
}
});
B = (tvm.var("B", dtype=dtype) if rhs_shape is None
else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
C = ftopi(A, B)
- if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
- assert(isinstance(C, tvm.expr.Expr))
+ if isinstance(A, tvm.expr.PrimExpr) and isinstance(B, tvm.expr.PrimExpr):
+ assert(isinstance(C, tvm.expr.PrimExpr))
return
def gen_operand(shape, low, high, ctx):
# Build the logic and compile the function
A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
B = func(A)
- if isinstance(A, tvm.expr.Expr):
- assert (isinstance(B, tvm.expr.Expr))
+ if isinstance(A, tvm.expr.PrimExpr):
+ assert (isinstance(B, tvm.expr.PrimExpr))
return
def check_device(device):
A = (tvm.var("A", dtype=dtype))
B = (tvm.var("B", dtype=dtype))
C = func(A, B)
- if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
- assert (isinstance(C, tvm.expr.Expr))
+ if isinstance(A, tvm.expr.PrimExpr) and isinstance(B, tvm.expr.PrimExpr):
+ assert (isinstance(C, tvm.expr.PrimExpr))
return
def check_device(device):