CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
- } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
- *ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
- } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
- *ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
}
using ContainerType = VarNode;
};
-class Integer;
-/*! \brief ExprNode: constant integer. */
-class IntImmNode : public PrimExprNode {
- public:
- /*! \brief the Internal value. */
- int64_t value;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &dtype);
- v->Visit("value", &value);
- }
-
- TVM_DLL static Integer make(DataType t, int64_t value);
-
- static constexpr const char* _type_key = "IntImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
-};
-
/*!
- * \brief Container of constant integer (IntImm).
+ * \brief Container of constant int that adds more constructors.
*
* This is used to store and automate type check
* attributes that must be constant integer.
+ *
+ * \sa IntImm
*/
-class Integer : public PrimExpr {
+class Integer : public IntImm {
public:
- Integer() : PrimExpr() {}
+ Integer() {}
/*!
* \brief constructor from node.
*/
- explicit Integer(ObjectPtr<Object> node) : PrimExpr(node) {}
+ explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
/*!
* \brief Construct integer from int value.
*/
- Integer(int value) : PrimExpr(value) {} // NOLINT(*)
+ Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*)
+ /*!
+ * \brief Construct integer from int imm.
+ * \param other The other value.
+ */
+ Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
- Integer& operator=(const Integer& other) {
- data_ = other.data_;
+ Integer& operator=(const IntImm& other) {
+ data_ = ObjectRef::GetDataPtr<Object>(other);
return *this;
}
/*!
- * \brief Get pointer to the internal value.
- * \return the content of the integer.
- */
- const IntImmNode* operator->() const {
- return static_cast<const IntImmNode*>(get());
- }
- /*!
* \brief convert to int64_t
*/
operator int64_t() const {
<< " Trying to reference a null Integer";
return (*this)->value;
}
- /*! \brief type indicate the container type */
- using ContainerType = IntImmNode;
};
/*! \brief range over one dimension */
#include <algorithm>
#include <type_traits>
+#include <limits>
#include "expr.h"
#include "ir.h"
}
/*!
- * \brief Get x as constant uint expression.
- * \param x The expression
- * \return the address to the int expression,
- * return nullptr, if x is not UIntImm.
- */
-inline const uint64_t* as_const_uint(const PrimExpr& x) {
- if (!x.defined()) return nullptr;
- if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
- return &(op->value);
- } else {
- return nullptr;
- }
-}
-
-/*!
* \brief Check whether x is a constant integer expression.
* \param x The input argument
* \param value the value to be compared against.
*/
TVM_DLL PrimExpr trunc(PrimExpr x);
+/*!
+ * \brief Construct a large uint constant by its low 32 bits and high 32 bits.
+ * \param dtype The final data type.
+ * \param low The lower 32 bits.
+ * \param high The higher 32 bits.
+ * \return The constructed expression.
+ */
+TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
+
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x) { \
// Implementation details after this
inline bool is_const(const PrimExpr& x) {
- if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
+ if (x.as<ir::IntImmNode>()) {
return true;
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const PrimExpr& val = op->value;
- if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
+ if (val.as<ir::IntImmNode>()) {
return true;
}
}
inline bool is_positive_const(const PrimExpr& a) {
if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
return op->value > 0;
- } else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
- return op->value > 0;
} else {
return false;
}
inline bool is_const_int(const PrimExpr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return op->value == value;
- } else if (const auto* op = x.as<ir::UIntImmNode>()) {
- return op->value == static_cast<uint64_t>(value);
} else if (const auto* op = x.as<ir::BroadcastNode>()) {
const PrimExpr& val = op->value;
if (const auto* opv = val.as<ir::IntImmNode>()) {
return opv->value == value;
- } else if (const auto* opv = val.as<ir::UIntImmNode>()) {
- return opv->value == static_cast<uint64_t>(value);
}
}
return false;
template<typename ValueType>
inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
- if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
- if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
+ if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
+ if (t.is_uint()) {
+ // Use IntImm if it is a small integer
+ uint64_t uval = static_cast<uint64_t>(value);
+ if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
+ return IntImm(t, static_cast<int64_t>(value));
+ } else {
+ uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
+ uint64_t low = uval & mask;
+ uint64_t high = uval >> 32U;
+ return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
+ }
+ }
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
- if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
+ if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
return ir::FloatImmNode::make(t, static_cast<double>(value));
+ }
LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr();
}
using IntImmNode = tvm::IntImmNode;
using VarNode = tvm::VarNode;
-/*! \brief constant unsigned integer. */
-class UIntImmNode : public PrimExprNode {
- public:
- /*! \brief The constant value content. */
- uint64_t value;
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("dtype", &dtype);
- v->Visit("value", &value);
- }
-
- TVM_DLL static PrimExpr make(DataType t, uint64_t value);
-
- static constexpr const char* _type_key = "UIntImm";
- TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode);
-};
-
/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
/*!
* \brief See pesudo code
*
+ * Construct a big uint that may not be representable by int64
+ *
+ * Expr tvm_large_uint_imm(uint32_t v0, uint32_t v1) {
+ * return (v1 << 32) | v0;
+ * }
+ */
+constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm";
+/*!
+ * \brief See pseudo code
+ *
* Handle tvm_address_of(Load *op) {
* return &op->buffer_var[index];
* }
};
/*!
+ * \brief Constant integer literals in the program.
+ * \sa IntImm
+ */
+class IntImmNode : public PrimExprNode {
+ public:
+ /*! \brief the Internal value. */
+ int64_t value;
+
+ // Reflection hook: registers dtype and value with TVM's attribute
+ // visitor so the node can be serialized/inspected generically.
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("dtype", &dtype);
+ v->Visit("value", &value);
+ }
+
+ // Runtime type key used by the object protocol for dynamic type checks.
+ static constexpr const char* _type_key = "IntImm";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference class to IntImmNode.
+ *
+ * \sa IntImmNode
+ */
+class IntImm : public PrimExpr {
+ public:
+ /*!
+ * \brief Default constructor; produces a null (undefined) reference.
+ */
+ IntImm() {}
+ /*!
+ * \brief constructor from node.
+ */
+ explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
+ /*!
+ * \brief Constructor.
+ * \param dtype The data type of the value.
+ * \param value The internal value.
+ */
+ TVM_DLL IntImm(DataType dtype, int64_t value);
+ /*!
+ * \brief Get pointer to the internal value.
+ * \return the content of the integer.
+ * \note Returns the underlying node unchecked; caller must ensure
+ * the reference is defined before dereferencing.
+ */
+ const IntImmNode* operator->() const {
+ return static_cast<const IntImmNode*>(get());
+ }
+ /*! \brief type indicate the container type */
+ using ContainerType = IntImmNode;
+};
+
+/*!
* \brief Base node of all non-primitive expressions.
*
* RelayExpr supports tensor types, functions and ADT as
virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
- virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Object* op, Args ...) {
IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
- IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode);
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
return vtable;
void VisitExpr_(const BroadcastNode* op) override;
void VisitExpr_(const ShuffleNode* op) override;
void VisitExpr_(const IntImmNode* op) override;
- void VisitExpr_(const UIntImmNode* op) override;
void VisitExpr_(const FloatImmNode* op) override;
void VisitExpr_(const StringImmNode* op) override;
};
PrimExpr VisitExpr_(const BroadcastNode* op) override;
PrimExpr VisitExpr_(const ShuffleNode* op) override;
PrimExpr VisitExpr_(const IntImmNode* op) override;
- PrimExpr VisitExpr_(const UIntImmNode* op) override;
PrimExpr VisitExpr_(const FloatImmNode* op) override;
PrimExpr VisitExpr_(const StringImmNode* op) override;
};
"""
if dtype is None:
dtype = _scalar_type_inference(value)
+ if dtype == "uint64" and value >= (1 << 63):
+ return _api_internal._LargeUIntImm(
+ dtype, value & ((1 << 32) - 1), value >> 32)
return _api_internal._const(value, dtype)
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
workload = x
- elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
+ elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
if len(source) != 1:
raise FlopCalculationError("Found multiple output in the source of reduce op")
return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
- if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
+ if isinstance(exp, (expr.FloatImm, expr.IntImm)):
return 0
if isinstance(exp, expr.Cast):
return _count_flop(exp.value)
"""
if isinstance(exp, int):
return exp
- if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
+ if not isinstance(exp, (expr.IntImm,)):
exp = ir_pass.Simplify(exp)
- if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
+ if not isinstance(exp, (expr.IntImm,)):
raise ValueError("Expect value to be constant int")
return exp.value
for elem in in_tuple:
if isinstance(elem, expr.Var):
ret.append(elem)
- elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
+ elif not isinstance(elem, (expr.IntImm, int)):
elem = ir_pass.Simplify(elem)
- if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
+ if not isinstance(elem, (expr.IntImm)):
ret.append(elem)
else:
ret.append(get_const_int(elem))
@register_object
-class UIntImm(ConstExpr):
- """UInt constant.
-
- Parameters
- ----------
- dtype : str
- The data type
-
- value : int
- The constant value.
- """
- def __init__(self, dtype, value):
- self.__init_handle_by_constructor__(
- _make.UIntImm, dtype, value)
-
-
-@register_object
class StringImm(ConstExpr):
"""String constant.
if args.__len__() == 0:
res = _tgt.current_target().max_num_threads
else:
- _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
+ _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
res = _tgt.current_target(args[0].value).max_num_threads
return _api.convert(res)
if isinstance(i, numbers.Integral):
arr = arr[i]
else:
- _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
+ _internal_assert(isinstance(i, (_expr.IntImm,)), \
"All indices are supposed to be constants")
arr = arr[i.value]
return arr
cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
# Return no IfThenElse if proven
- if isinstance(cond, _expr.UIntImm):
+ if isinstance(cond, _expr.IntImm):
if cond.value:
return visit_list_to_block(self.visit, node.body)
if node.orelse:
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
-halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
+halide_imm_types = (_expr.IntImm, _expr.FloatImm)
def _internal_assert(cond, err):
def _impl(inputs, attr, params):
is_symbolic_shape = False
for axis in attr['_input_shapes'][inputs[0]]:
- if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)):
+ if not isinstance(axis, (int, tvm.expr.IntImm)):
is_symbolic_shape = True
break
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
-REGISTER_MAKE(IntImm);
-REGISTER_MAKE(UIntImm);
REGISTER_MAKE(FloatImm);
REGISTER_MAKE(StringImm);
}
});
+TVM_REGISTER_GLOBAL("_LargeUIntImm")
+.set_body_typed(LargeUIntImm);
+
TVM_REGISTER_GLOBAL("_str")
.set_body_typed(ir::StringImmNode::make);
}
bool Analyzer::CanProve(const PrimExpr& expr) {
- if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
+ if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
}
auto res = this->rewrite_simplify(expr);
- if (const auto* ptr = res.as<ir::UIntImmNode>()) {
+ if (const auto* ptr = res.as<IntImmNode>()) {
return ptr->value != 0;
}
res = this->canonical_simplify(expr);
- if (const auto* ptr = res.as<ir::UIntImmNode>()) {
+ if (const auto* ptr = res.as<IntImmNode>()) {
return ptr->value != 0;
}
return false;
// const folding
PrimExpr const_res = TryConstFold<DivNode>(a, b);
if (const_res.defined()) return const_res;
- PVar<Integer> c1;
+ PVar<IntImm> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
// const folding
PrimExpr const_res = TryConstFold<FloorDivNode>(a, b);
if (const_res.defined()) return const_res;
- PVar<Integer> c1;
+ PVar<IntImm> c1;
// x / c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
PrimExpr const_res = TryConstFold<ModNode>(a, b);
if (const_res.defined()) return const_res;
- PVar<Integer> c1;
+ PVar<IntImm> c1;
// x % c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
PrimExpr const_res = TryConstFold<FloorModNode>(a, b);
if (const_res.defined()) return const_res;
- PVar<Integer> c1;
+ PVar<IntImm> c1;
// x % c1
if (c1.Match(b) && c1.Eval()->value > 0) {
int64_t cval = c1.Eval()->value;
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
- using ir::IntImmNode; \
- using ir::UIntImmNode; \
using ir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
const IntImmNode* pb = b.as<IntImmNode>(); \
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
- using ir::IntImmNode; \
- using ir::UIntImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
const IntImmNode* pb = b.as<IntImmNode>(); \
const DataType& ta = a.dtype(); \
inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value);
+ if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value);
inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value);
+ if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return a;
inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value);
+ if (pa && pb) return IntImm(rtype, pa->value * pb->value);
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImmNode::make(rtype, pa->value / pb->value);
+ return IntImm(rtype, pa->value / pb->value);
}
if (pa) {
if (pa->value == 0) return a;
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
- return IntImmNode::make(rtype, pa->value % pb->value);
+ return IntImm(rtype, pa->value % pb->value);
}
if (pa) {
if (pa->value == 0) return a;
const DataType& rtype = a.dtype();
if (pa && pb) {
CHECK_NE(pb->value, 0) << "Divide by zero";
- return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value));
+ return IntImm(rtype, arith::floordiv(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) {
- return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value));
+ return IntImm(rtype, arith::floormod(pa->value, pb->value));
}
if (pa) {
if (pa->value == 0) return a;
inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value));
+ if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value));
+ if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
template<>
inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
});
return PrimExpr();
}
template<>
inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
});
return PrimExpr();
}
template<>
inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
});
return PrimExpr();
}
template<>
inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
});
return PrimExpr();
}
template<>
inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
- if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
});
return PrimExpr();
}
template<>
inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
- if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
+ if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
+ if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
});
return PrimExpr();
}
template<>
inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
- using ir::UIntImmNode;
- const UIntImmNode* pa = a.as<UIntImmNode>();
- const UIntImmNode* pb = b.as<UIntImmNode>();
+ const IntImmNode* pa = a.as<IntImmNode>();
+ const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return b;
if (pa && !pa->value) return a;
if (pb && pb->value) return a;
template<>
inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
- using ir::UIntImmNode;
- const UIntImmNode* pa = a.as<UIntImmNode>();
- const UIntImmNode* pb = b.as<UIntImmNode>();
+ const IntImmNode* pa = a.as<IntImmNode>();
+ const IntImmNode* pb = b.as<IntImmNode>();
if (pa && pa->value) return a;
if (pa && !pa->value) return b;
if (pb && pb->value) return b;
template<>
inline PrimExpr TryConstFold<ir::NotNode>(PrimExpr a) {
- using ir::UIntImmNode;
- const UIntImmNode* pa = a.as<UIntImmNode>();
+ const IntImmNode* pa = a.as<IntImmNode>();
if (pa) {
- return UIntImmNode::make(DataType::UInt(1), !(pa->value));
+ return IntImm(DataType::UInt(1), !(pa->value));
}
return PrimExpr();
}
return MakeBound(op->value, op->value);
}
- Entry VisitExpr_(const UIntImmNode* op) final {
- if (op->value <= static_cast<uint64_t>(kPosInf)) {
- return MakeBound(op->value, op->value);
- } else {
- return Everything(op->dtype);
- }
- }
-
Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
*/
static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
PVar<PrimExpr> x, y;
- PVar<Integer> c;
+ PVar<IntImm> c;
// NOTE: canonical form always use <= or <
if ((c <= x).Match(cond)) {
return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
}
- IntervalSet VisitExpr_(const UIntImmNode* op) final {
- return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
- }
-
IntervalSet VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto it = dom_map_.find(var);
IntervalSet VisitExpr_(const RampNode* op) final {
CHECK(eval_vec_);
IntervalSet base = Eval(op->base);
- PVar<Integer> stride;
+ PVar<IntImm> stride;
if (stride.Match(op->stride)) {
DataType t = op->base.dtype();
int64_t vstride = stride.Eval()->value;
// Detect useful constraints and use them in the analysis scope.
std::function<void()> EnterConstraint(const PrimExpr& constraint) {
PVar<Var> var;
- PVar<Integer> coeff, base;
+ PVar<IntImm> coeff, base;
// pattern match interesting constraints
if ((truncmod(var, coeff) == base).Match(constraint) ||
(floormod(var, coeff) == base).Match(constraint)) {
return Entry(0, op->value);
}
- Entry VisitExpr_(const UIntImmNode* op) final {
- if (op->value < std::numeric_limits<int64_t>::max()) {
- return Entry(0, static_cast<int>(op->value));
- } else {
- return Everything();
- }
- }
-
Entry VisitExpr_(const AddNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
* }
*
* tvm::Var tx, ty;
- * arith::PVar<Integer> c;
+ * arith::PVar<IntImm> c;
* arith::PVar<Var> v;
* // We can match integer and Var, both of which are
* // special case container of Expr
};
template<>
-class PEqualChecker<Integer> {
+class PEqualChecker<IntImm> {
public:
- bool operator()(const Integer& lhs, const Integer& rhs) const {
+ bool operator()(const IntImm& lhs, const IntImm& rhs) const {
return lhs->value == rhs->value;
}
};
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
- PVar<Integer> c1, c2, c3;
+ PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
- PVar<Integer> c1, c2, c3;
+ PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
- PVar<Integer> c1, c2, c3;
+ PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
- PVar<Integer> c1, c2, c3;
+ PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, b1;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
PVar<int> lanes;
// vector rule
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
PVar<int> lanes;
// vector rule
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
- PVar<Integer> c1;
+ PVar<IntImm> c1;
PVar<int> lanes;
// vector rule
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
PVar<int> lanes;
// vector rule
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
- PVar<Integer> c1, c2;
+ PVar<IntImm> c1, c2;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
Array<PrimExpr> attr{std::string("_attr_"),
FloatImmNode::make(DataType::Float(32), trans(fea.length)),
- IntImmNode::make(DataType::Int(32), fea.nest_level),
+ IntImm(DataType::Int(32), fea.nest_level),
FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
};
}
}
-inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
- if (op->dtype == DataType::UInt(32)) {
+
+inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+ if (dtype == DataType::UInt(32)) {
std::ostringstream temp;
- temp << op->value << "U";
+ temp << val << "U";
p->MarkConst(temp.str());
os << temp.str();
} else {
os << "(";
- p->PrintType(op->dtype, os);
- os << ")" << op->value;
+ p->PrintType(dtype, os);
+ os << ")" << val;
}
}
void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
-void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*)
- PrintConst(op, os, this);
-}
+
void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
os << ")";
} else if (op->is_intrinsic(CallNode::bitwise_and)) {
PrintBinaryIntrinsic(op, " & ", os, this);
+ } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+ CHECK_EQ(op->args.size(), 2U);
+ uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+ uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+ uint64_t val = (high << 32U) | low;
+ PrintUIntConst(op->dtype, val, os, this);
} else if (op->is_intrinsic(CallNode::bitwise_xor)) {
PrintBinaryIntrinsic(op, " ^ ", os, this);
} else if (op->is_intrinsic(CallNode::bitwise_or)) {
void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
CodeGenC::VisitExpr_(op, os);
}
-void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) {
- CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
- CodeGenC::VisitExpr_(op, os);
-}
-
void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {
CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
CodeGenC::VisitExpr_(op, os);
// Codegen for immediate values
void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)
- void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*)
llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
- op->args[0].as<UIntImmNode>()->value);
+ Downcast<IntImm>(op->args[0])->value);
if (id == ::llvm::Intrinsic::ctpop) {
PrimExpr e = ARMPopcount(op);
return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<PrimExpr> vcnt_args;
- vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
- vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+ vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
+ vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt_args.push_back(e);
return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
}
const CallNode* c0 = input8.as<CallNode>();
CHECK(c0 != nullptr);
Array<PrimExpr> vcnt8_args;
- vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
- vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+ vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
+ vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
PrimExpr vcnt8 = ir::CallNode::make(
uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
// Accumulation 8->16bit
Array<PrimExpr> vcnt16_args;
- vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
- vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+ vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
+ vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
PrimExpr vcnt16 = ir::CallNode::make(
uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
// Accumulation 16->32bit
Array<PrimExpr> vcnt32_args;
- vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
- vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+ vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
+ vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
PrimExpr vcnt32 = ir::CallNode::make(
uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
// Accumulation 32->64bit
Array<PrimExpr> vcnt64_args;
- vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
- vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+ vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
+ vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
return ir::CallNode::make(
call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
- op->args[0].as<UIntImmNode>()->value);
- const uint64_t *num_signature = as_const_uint(op->args[1]);
- CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
- << "but " << op->args[1] << " got!\n";
+ Downcast<IntImm>(op->args[0])->value);
+ int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> sig_type;
for (size_t i = 2; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
- if (i - 2 < *num_signature) {
+ if (i - 2 < static_cast<size_t>(num_signature)) {
sig_type.push_back(arg_value.back()->getType());
}
}
return llvm::Constant::getNullValue(t_void_p_);
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
return builder_->CreateIsNull(MakeValue(op->args[0]));
+ } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+ CHECK_EQ(op->args.size(), 2U);
+ uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+ uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+ uint64_t val = (high << 32U) | low;
+ return llvm::ConstantInt::get(LLVMType(op->dtype), val);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
CHECK_EQ(op->args[0].dtype().lanes(), 1)
<< "if_then_else can only take scalar condition";
return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value);
}
-llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) {
- return llvm::ConstantInt::get(LLVMType(op->dtype), op->value);
-}
-
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
return llvm::ConstantFP::get(LLVMType(op->dtype), op->value);
}
llvm::Value* VisitExpr_(const VarNode* op) override;
llvm::Value* VisitExpr_(const CastNode* op) override;
llvm::Value* VisitExpr_(const IntImmNode* op) override;
- llvm::Value* VisitExpr_(const UIntImmNode* op) override;
llvm::Value* VisitExpr_(const FloatImmNode* op) override;
llvm::Value* VisitExpr_(const StringImmNode* op) override;
llvm::Value* VisitExpr_(const AddNode* op) override;
MakeValue(
ir::BroadcastNode::make(
ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
- /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)),
- /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)),
+ /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
+ /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
});
}
CHECK(call != nullptr);
Array<PrimExpr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
+ cargs.push_back(IntImm(DataType::UInt(32), id));
+ cargs.push_back(IntImm(DataType::UInt(32), num_signature));
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
CHECK(call != nullptr);
Array<PrimExpr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
+ cargs.push_back(IntImm(DataType::UInt(32), id));
+ cargs.push_back(IntImm(DataType::UInt(32), num_signature));
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
}
-spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) {
- return builder_->UIntImm(builder_->GetSType(op->dtype), op->value);
-}
-
spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
}
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
if (op->is_intrinsic("spirv_glsl450")) {
CHECK_GE(op->args.size(), 2U);
- uint32_t inst_id = op->args[0].as<UIntImmNode>()->value;
+ uint32_t inst_id = static_cast<uint32_t>(
+ op->args[0].as<IntImmNode>()->value);
std::vector<spirv::Value> values;
for (size_t i = 1; i < op->args.size(); ++i) {
values.push_back(MakeValue(op->args[i]));
} else if (op->is_intrinsic(CallNode::reinterpret)) {
return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
MakeValue(op->args[0]));
+ } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+ CHECK_EQ(op->args.size(), 2U);
+ uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+ uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+ uint64_t val = (high << 32U) | low;
+ return builder_->UIntImm(builder_->GetSType(op->dtype), val);
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return this->CreateStorageSync(op);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
spirv::Value VisitExpr_(const VarNode* op) override;
spirv::Value VisitExpr_(const CastNode* op) override;
spirv::Value VisitExpr_(const IntImmNode* op) override;
- spirv::Value VisitExpr_(const UIntImmNode* op) override;
spirv::Value VisitExpr_(const FloatImmNode* op) override;
spirv::Value VisitExpr_(const StringImmNode* op) override;
spirv::Value VisitExpr_(const AddNode* op) override;
CHECK(call != nullptr);
Array<PrimExpr> cargs;
// intrin id.
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
+ cargs.push_back(IntImm(DataType::UInt(32), id));
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
if (dtype.type == DataType::UInt(1)) {
// bool types.
if (*pvalue) {
- ib_.Begin(spv::OpConstantTrue).AddSeq(ret);
+ ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret);
} else {
- ib_.Begin(spv::OpConstantFalse).AddSeq(ret);
+ ib_.Begin(spv::OpConstantFalse).AddSeq(dtype, ret);
}
} else {
// Integral/floating-point types.
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}
-void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) {
- CHECK(op->value <= std::numeric_limits<int>::max())
- << "Int constant exceed bound";
- this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
-}
-
void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
LOG(FATAL) << "Float Imm is not supported";
}
void VisitExpr_(const RampNode* op) final;
void VisitExpr_(const BroadcastNode* op) final;
void VisitExpr_(const IntImmNode* op) final;
- void VisitExpr_(const UIntImmNode* op) final;
void VisitExpr_(const FloatImmNode* op) final;
void VisitExpr_(const StringImmNode* op) final;
// statment
void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*)
os << op->value;
}
-void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*)
- PrintType(op->dtype, os);
- os << "(" << op->value << ")";
-}
+
void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << std::setprecision(20) << op->value << ")";
void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*)
- void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*)
// statment
namespace tvm {
+// Constructor of IntImm. After the UIntImm removal, IntImm represents
+// both signed and unsigned scalar integer immediates; the value payload
+// is an int64_t, so unsigned values must be non-negative when viewed
+// as signed (values above INT64_MAX go through tvm_large_uint_imm).
+//
+// \param dtype The data type; must be a scalar int or uint type.
+// \param value The immediate value.
+IntImm::IntImm(DataType dtype, int64_t value) {
+  CHECK(dtype.is_scalar())
+      << "ValueError: IntImm can only take scalar.";
+  CHECK(dtype.is_int() || dtype.is_uint())
+      << "ValueError: IntImm supports only int or uint type.";
+  // Reject uint values that would read back negative from the
+  // signed int64 payload.
+  if (dtype.is_uint()) {
+    CHECK_GE(value, 0U);
+  }
+  ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
+  node->dtype = dtype;
+  node->value = value;
+  data_ = std::move(node);
+}
+
+// Expose the unified IntImm constructor to the FFI as "make.IntImm"
+// (replaces the separate UIntImm maker removed in this change).
+TVM_REGISTER_GLOBAL("make.IntImm")
+.set_body_typed([](DataType dtype, int64_t value) {
+  return IntImm(dtype, value);
+});
+
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
- virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
// deep comparison of symbolic integer expressions.
ATTR_FUNCTOR_DISPATCH(StrMapNode);
ATTR_FUNCTOR_DISPATCH(ArrayNode);
ATTR_FUNCTOR_DISPATCH(IntImmNode);
- ATTR_FUNCTOR_DISPATCH(UIntImmNode);
ATTR_FUNCTOR_DISPATCH(FloatImmNode);
ATTR_FUNCTOR_DISPATCH(StringImmNode);
ATTR_FUNCTOR_DISPATCH(VarNode);
bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final;
bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final;
protected:
size_t VisitAttrDefault_(const Object* lhs) final;
size_t VisitAttr_(const ir::IntImmNode* lhs) final;
- size_t VisitAttr_(const ir::UIntImmNode* lhs) final;
size_t VisitAttr_(const ir::FloatImmNode* lhs) final;
size_t VisitAttr_(const ir::StringImmNode* lhs) final;
size_t VisitAttr_(const ArrayNode* lhs) final;
return false;
}
-bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<UIntImmNode>()) {
- return lhs->value == rhs->value;
- }
- return false;
-}
-
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
return std::hash<int64_t>()(op->value);
}
-size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
- return std::hash<uint64_t>()(op->value);
-}
-
size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
return std::hash<double>()(op->value);
}
namespace tvm {
PrimExpr::PrimExpr(int32_t value)
- : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {}
+ : PrimExpr(IntImm(DataType::Int(32), value)) {}
PrimExpr::PrimExpr(float value)
: PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
is_zero(begin) ? end : (end - begin))) {
}
-Integer IntImmNode::make(DataType t, int64_t value) {
- CHECK(t.is_int() && t.is_scalar())
- << "ValueError: IntImm can only take scalar.";
- ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
- node->dtype = t;
- node->value = value;
- return Integer(node);
-}
-
Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}
return ir::CastNode::make(t, value);
}
+// Build a uint64 immediate that cannot be stored in IntImm's signed
+// int64 payload. The value is split into two 32-bit halves carried as
+// UInt(32) constants; backends recombine them as (high << 32) | low
+// when lowering the tvm_large_uint_imm intrinsic.
+PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
+  return ir::CallNode::make(
+      t, ir::intrinsic::tvm_large_uint_imm,
+      {make_const(DataType::UInt(32), low),
+       make_const(DataType::UInt(32), high)},
+      ir::CallNode::PureIntrinsic);
+}
+
// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*)
if (lhs.dtype() == rhs.dtype()) return;
}
}
-
// maximum and min limits
PrimExpr max_value(const DataType& dtype) {
using namespace ir;
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_int()) {
if (dtype.bits() == 64) {
- return IntImmNode::make(dtype, std::numeric_limits<int64_t>::max());
+ return IntImm(dtype, std::numeric_limits<int64_t>::max());
} else if (dtype.bits() < 64) {
int64_t val = 1;
val = (val << (dtype.bits() - 1)) - 1;
- return IntImmNode::make(dtype, val);
+ return IntImm(dtype, val);
}
} else if (dtype.is_uint()) {
if (dtype.bits() == 64) {
- return UIntImmNode::make(dtype, std::numeric_limits<uint64_t>::max());
+ return make_const(dtype, std::numeric_limits<uint64_t>::max());
} else if (dtype.bits() < 64) {
uint64_t val = 1;
val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
- return UIntImmNode::make(dtype, val);
+ return IntImm(dtype, static_cast<int64_t>(val));
}
} else if (dtype.is_float()) {
if (dtype.bits() == 64) {
CHECK_EQ(dtype.lanes(), 1);
if (dtype.is_int()) {
if (dtype.bits() == 64) {
- return IntImmNode::make(dtype, std::numeric_limits<int64_t>::lowest());
+ return IntImm(dtype, std::numeric_limits<int64_t>::lowest());
} else if (dtype.bits() < 64) {
int64_t val = 1;
val = -(val << (dtype.bits() - 1));
- return IntImmNode::make(dtype, val);
+ return IntImm(dtype, val);
}
} else if (dtype.is_uint()) {
- return UIntImmNode::make(dtype, 0);
+ return IntImm(dtype, 0);
} else if (dtype.is_float()) {
if (dtype.bits() == 64) {
return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest());
bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
if (const auto* op = x.as<ir::IntImmNode>()) {
return ConstPowerHelper(op->value, shift);
- } else if (const auto* op = x.as<ir::UIntImmNode>()) {
- return ConstPowerHelper(op->value, shift);
} else {
return false;
}
}
PrimExpr cast(const DataType& t, PrimExpr value) {
- using ir::IntImmNode;
- using ir::UIntImmNode;
using ir::FloatImmNode;
if (value.dtype() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImmNode* op = value.as<IntImmNode>()) {
return make_const(t, op->value);
- } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
- return make_const(t, op->value);
} else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
return make_const(t, op->value);
}
if (value.dtype() != vtype) {
if (const IntImmNode* op = value.as<IntImmNode>()) {
value = make_const(vtype, op->value);
- } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
- return make_const(t, op->value);
} else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
value = make_const(vtype, op->value);
} else {
using ir::FloatImmNode;
const IntImmNode* pa = a.as<IntImmNode>();
const FloatImmNode* fa = a.as<FloatImmNode>();
- if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value);
+ if (pa) return IntImm(a.dtype(), -pa->value);
if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
return make_zero(a.dtype()) - a;
}
}
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
- using ir::IntImmNode;
- using ir::UIntImmNode;
CHECK(cond.dtype() == DataType::Bool(1))
<< "if_then_else only accept the condition to be boolean type.";
BinaryOpMatchTypes(true_value, false_value);
- if (const UIntImmNode* op = cond.as<UIntImmNode>()) {
- if (op->value != 0) {
- return true_value;
- } else {
- return false_value;
- }
- } else if (const IntImmNode* op = cond.as<IntImmNode>()) {
+ if (const IntImmNode* op = cond.as<IntImmNode>()) {
if (op->value != 0) {
return true_value;
} else {
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value));
+ if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
if (pb) {
if (pb->value == 0) return a;
}
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value));
+ if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
if (pb) {
if (pb->value == 0) return a;
}
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value));
+ if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
});
return ir::CallNode::make(
a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value));
+ if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
});
return ir::CallNode::make(
a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
- if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value));
+ if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
});
return ir::CallNode::make(
a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
using ir::IntImmNode;
const IntImmNode* px = x.as<IntImmNode>();
if (px) {
- return ir::IntImmNode::make(x.dtype(), std::abs(px->value));
+ return IntImm(x.dtype(), std::abs(px->value));
}
return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x);
} else if (x.dtype().is_float()) {
namespace ir {
// constructors
-PrimExpr UIntImmNode::make(DataType t, uint64_t value) {
- CHECK(t.is_uint() && t.lanes() == 1)
- << "ValueError: UIntImm can only take scalar";
- ObjectPtr<UIntImmNode> node = make_object<UIntImmNode>();
- node->dtype = t;
- node->value = value;
- return PrimExpr(node);
-}
PrimExpr FloatImmNode::make(DataType t, double value) {
CHECK_EQ(t.lanes(), 1)
int index = 0;
for (const PrimExpr& e : vectors) {
for (int i = 0; i < e.dtype().lanes(); ++i) {
- indices.push_back(IntImmNode::make(DataType::Int(32), index++));
+ indices.push_back(IntImm(DataType::Int(32), index++));
}
}
return make(vectors, indices);
}
// Printers
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<UIntImmNode>([](const ObjectRef& node, NodePrinter* p) {
- auto* op = static_cast<const UIntImmNode*>(node.get());
- p->stream << "(" << op->dtype << ")" << op->value;
- });
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_REGISTER_NODE_TYPE(AttrStmtNode);
TVM_REGISTER_NODE_TYPE(FloatImmNode);
TVM_REGISTER_NODE_TYPE(IntImmNode);
-TVM_REGISTER_NODE_TYPE(UIntImmNode);
TVM_REGISTER_NODE_TYPE(StringImmNode);
TVM_REGISTER_NODE_TYPE(CastNode);
TVM_REGISTER_NODE_TYPE(VarNode);
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
- UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
+ IntImm(DataType::UInt(8), dtype.code()) &&
TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
- UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
+ IntImm(DataType::UInt(8), dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
- UIntImmNode::make(DataType::UInt(16), dtype.lanes()));
+ IntImm(DataType::UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmtNode::make(
vptr, ir::attr::storage_alignment,
- IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop));
+ IntImm(DataType::Int(32), buffer->data_alignment), nop));
}
Var v_shape(arg_name + ".shape", DataType::Handle());
Bind_(buffer->shape[k],
cast(buffer->shape[k].dtype(),
LoadNode::make(tvm_shape_type, v_shape,
- IntImmNode::make(DataType::Int(32), k), const_true(1))),
+ IntImm(DataType::Int(32), k), const_true(1))),
field_name.str(), true);
}
// strides field
PrimExpr svalue = cast(
stype,
LoadNode::make(tvm_shape_type, v_strides,
- IntImmNode::make(DataType::Int(32), k), const_true(1)));
+ IntImm(DataType::Int(32), k), const_true(1)));
conds.push_back(expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
field_name << v_strides->name_hint << '[' << k << ']';
PrimExpr value = cast(buffer->shape[k].dtype(),
LoadNode::make(tvm_shape_type, v_strides,
- IntImmNode::make(DataType::Int(32), k), const_true(1)));
+ IntImm(DataType::Int(32), k), const_true(1)));
value = tvm::if_then_else(is_null, stride, value);
value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
Bind_(buffer->strides[k], value, field_name.str(), true);
Bind_(buffer->strides[k],
cast(buffer->shape[k].dtype(),
LoadNode::make(tvm_shape_type, v_strides,
- IntImmNode::make(DataType::Int(32), k), const_true(1))),
+ IntImm(DataType::Int(32), k), const_true(1))),
field_name.str(), true);
}
}
CompareValue(op->value, other.as<IntImmNode>()->value);
}
- void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final {
- CompareValue(op->value, other.as<UIntImmNode>()->value);
- }
-
void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final {
CompareValue(op->value, other.as<FloatImmNode>()->value);
}
DEFINE_BINOP_VISIT_(OrNode);
void ExprVisitor::VisitExpr_(const IntImmNode* op) {}
-void ExprVisitor::VisitExpr_(const UIntImmNode* op) {}
void ExprVisitor::VisitExpr_(const FloatImmNode* op) {}
void ExprVisitor::VisitExpr_(const StringImmNode* op) {}
}
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
if (const IntImmNode* op = a.as<IntImmNode>()) {
return op->value == b.as<IntImmNode>()->value;
}
- if (const UIntImmNode* op = a.as<UIntImmNode>()) {
- return op->value == b.as<UIntImmNode>()->value;
- }
return false;
}
PrimExpr VisitExpr_(const MaxNode* op) final {
using namespace arith;
PVar<PrimExpr> x, y;
- PVar<Integer> c;
+ PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) &&
c.Eval()->value >= 0 &&
const CommReducerNode *combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
- const UIntImmNode *size_of_args = call->args[0].as<UIntImmNode>();
+ const IntImmNode *size_of_args = call->args[0].as<IntImmNode>();
CHECK(size_of_args) << call->args[0]->GetTypeKey();
CHECK_EQ(size, size_of_args->value);
Array<PrimExpr> inits = combiner->identity_element;
{cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_),
cast(DataType::UInt(64), total_bytes),
- IntImmNode::make(DataType::Int(32), op->dtype.code()),
- IntImmNode::make(DataType::Int(32), op->dtype.bits())},
+ IntImm(DataType::Int(32), op->dtype.code()),
+ IntImm(DataType::Int(32), op->dtype.bits())},
CallNode::Extern),
body);
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
Array<PrimExpr> call_args{v_packed_args,
- IntImmNode::make(DataType::Int(32), i),
- IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
+ IntImm(DataType::Int(32), i),
+ IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
PrimExpr res = CallNode::make(
seq_init.emplace_back(LetStmtNode::make(
tcode, LoadNode::make(
DataType::Int(32), v_packed_arg_type_ids,
- IntImmNode::make(DataType::Int(32), i), const_true(1)),
+ IntImm(DataType::Int(32), i), const_true(1)),
nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
return false;
}
bool VisitExpr_(const VarNode* op) final { return false; }
- bool VisitExpr_(const UIntImmNode* op) final { return false; }
bool VisitExpr_(const IntImmNode* op) final { return false; }
bool VisitExpr_(const FloatImmNode* op) final { return false; }
bool VisitExpr_(const StringImmNode* op) final { return false; }
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
- PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+ PrimExpr stride = IntImm(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = MulNode::make(stride, bi.shape[j]);
}
strides = bi.strides;
} else {
for (size_t i = 1; i < bi.shape.size(); ++i) {
- PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+ PrimExpr stride = IntImm(DataType::Int(32), 1);
for (size_t j = bi.shape.size() - 1; j >= i; --j) {
stride = MulNode::make(stride, bi.shape[j]);
}
op = expr.as<VarNode>();
if (op != nullptr) {
if (op->name_hint == "threadIdx.x") {
- PrimExpr zero = IntImmNode::make(DataType::Int(32), 0);
+ PrimExpr zero = IntImm(DataType::Int(32), 0);
return zero;
}
if (op->name_hint == "threadIdx.y") {
PrimExpr stride = strides[strides.size()-2];
// thread index unification inside a warp
- PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+ PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
PrimExpr mutated_value = thread_idx_mutator(op->value);
PrimExpr src = CallNode::make(value->dtype,
PrimExpr dst = it3->second;
// thread index unification inside a warp
- PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+ PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
ThreadIdxMutator thread_idx_mutator(warp_y);
dst = thread_idx_mutator(dst);
dst = CallNode::make(DataType::Handle(),
Array<PrimExpr> strides;
for (size_t i = 1; i < shape.size(); ++i) {
- PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+ PrimExpr stride = IntImm(DataType::Int(32), 1);
for (size_t j = shape.size() - 1; j >= i; --j) {
stride = MulNode::make(stride, shape[j]);
}
}
strides.push_back(make_const(DataType::Int(32), 1));
- PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0);
+ PrimExpr elem_offset = IntImm(DataType::Int(32), 0);
CHECK_EQ(call->args.size(), min_bound.size());
for (size_t i = 0; i < min_bound.size(); i++) {
elem_offset = AddNode::make(
// constant folding.
PrimExpr extent = ir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
- const UIntImmNode *v2 = extent.as<UIntImmNode>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
- if (v2 != nullptr) {
- value = static_cast<int>(v2->value);
- }
return value;
}
if (pval != nullptr) {
CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
- res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval));
+ res.push_back(IntImm(DataType::Int(32), *pval));
} else if (val->IsInstance<ir::AnyNode>()) {
res.push_back(val.as<ir::AnyNode>()->ToVar());
} else {
// set inputs
for (auto param : prim_func->params) {
int state = param_states_[param];
- cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state));
+ cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
if (state & kNeedInputData) {
for (auto t : param_data_[param]) {
cache_node->inputs.push_back(t);
auto ret_type = call_node->checked_type();
Array<IndexExpr> out_ndims;
if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
- out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
+ out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
} else {
auto rtype = ret_type.as<TupleTypeNode>();
// TODO(@icemelon): Allow recursive tuple
for (size_t i = 0; i < rtype->fields.size(); ++i) {
auto ttype = rtype->fields[i].as<TensorTypeNode>();
CHECK(ttype);
- out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
+ out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
}
}
// Call shape function
CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
shape.push_back(
- tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i]));
+ tvm::IntImm(DataType::Int(32), data->shape[i]));
}
return TensorTypeNode::make(shape, dtype);
return PrintConstScalar(op->dtype, &(op->value));
}
- Doc VisitAttr_(const ir::UIntImmNode* op) final {
- return PrintConstScalar(op->dtype, &(op->value));
- }
-
Doc VisitAttr_(const ir::FloatImmNode* op) final {
return PrintConstScalar(op->dtype, &(op->value));
}
const auto& input_rank = input_shape.size();
std::vector<IndexExpr> result_shape;
result_shape.push_back(Any::make());
- result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank));
+ result_shape.push_back(IntImm(DataType::Int(32), input_rank));
reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
return true;
}
}
bool Assert(const IndexExpr& cond) final {
- if (const uint64_t* pdiff = as_const_uint(cond)) {
+ if (const int64_t* pdiff = as_const_int(cond)) {
return pdiff[0];
}
return true;
static inline const int32_t GetQmin(const DataType& dtype) {
CHECK_LE(dtype.bits(), 32)
<< "QNN ops support int32 or lower precision";
- if (dtype.is_int()) {
+ if (dtype.is_int() || dtype.is_uint()) {
auto* min_value = as_const_int(tvm::min_value(dtype));
CHECK(min_value != nullptr);
return static_cast<int32_t>(min_value[0]);
- } else if (dtype.is_uint()) {
- auto* min_value = as_const_uint(tvm::min_value(dtype));
- CHECK(min_value != nullptr);
- return static_cast<int32_t>(min_value[0]);
} else {
LOG(FATAL) << "Type not supported " << dtype;
return -1; // To hide the warning
static inline const int32_t GetQmax(const DataType& dtype) {
CHECK_LE(dtype.bits(), 32)
<< "QNN ops support int32 or lower precision";
- if (dtype.is_int()) {
+ if (dtype.is_int() || dtype.is_uint()) {
auto* max_value = as_const_int(tvm::max_value(dtype));
CHECK(max_value != nullptr);
return static_cast<int32_t>(max_value[0]);
- } else if (dtype.is_uint()) {
- auto* max_value = as_const_uint(tvm::max_value(dtype));
- CHECK(max_value != nullptr);
- return static_cast<int32_t>(max_value[0]);
} else {
LOG(FATAL) << "Type not supported " << dtype;
return -1; // To hide the warning
}
}
-TEST(Pattern, Integer) {
+TEST(Pattern, IntImm) {
using namespace tvm;
tvm::Var tx, ty;
- arith::PVar<Integer> c;
+ arith::PVar<IntImm> c;
arith::PVar<Var> v;
{
// We can match integer and Var, both of which are
from tvm.contrib import util
import numpy as np
+def test_large_uint_imm():
+    # Regression test for uint64 immediates that do not fit in an int64
+    # payload: (1 << 63) + 123 forces the tvm_large_uint_imm lowering.
+    value = (1 << 63) + 123
+    other = tvm.const(3, "uint64")
+    n = 12
+    num_thread = 2
+
+    A = tvm.compute((n,), lambda *i: tvm.const(value, "uint64") + other, name='A')
+    s = tvm.create_schedule(A.op)
+    # Bind a trivial thread/block split so the kernel is valid on GPU targets.
+    xo, xi = s[A].split(A.op.axis[0], factor=num_thread)
+    s[A].bind(xi, tvm.thread_axis("threadIdx.x"))
+    s[A].bind(xo, tvm.thread_axis("blockIdx.x"))
+
+    def check_target(device):
+        # Skip silently when the target device is not present.
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        f = tvm.build(s, [A], device)
+        # launch the kernel.
+        a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx)
+        f(a)
+        assert a.asnumpy()[0] == value + 3
+
+    check_target("cuda")
+    check_target("vulkan")
+
+
def test_add_pipeline():
n = tvm.var('n')
A = tvm.placeholder((n,), name='A')
if __name__ == "__main__":
+ test_large_uint_imm()
test_add_pipeline()
fcode = tvm.build(func, None, "llvm")
+def test_llvm_large_uintimm():
+    # uint64 constant with the high bit set is not representable as int64;
+    # checks the LLVM backend lowers tvm_large_uint_imm correctly.
+    value = (1 << 63) + 123
+    other = tvm.const(3, "uint64")
+    A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A')
+    s = tvm.create_schedule(A.op)
+
+    def check_llvm():
+        # Skip when the LLVM backend is not compiled in.
+        if not tvm.module.enabled("llvm"):
+            return
+        f = tvm.build(s, [A], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.empty((), dtype=A.dtype, ctx=ctx)
+        f(a)
+        assert a.asnumpy() == value + 3
+
+    check_llvm()
+
+
def test_llvm_add_pipeline():
nn = 1024
n = tvm.convert(nn)
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
if __name__ == "__main__":
+ test_llvm_large_uintimm()
test_llvm_import()
test_alignment()
test_rank_zero()
def tvm_val_2_py_val(val):
val = tvm.ir_pass.Substitute(val, var_dict)
val = tvm.ir_pass.Simplify(val)
- assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm))
+ assert isinstance(val, tvm.expr.IntImm)
return val.value
ctx = tvm.context(target, 0)
assert x.value == 2
assert x.dtype == "int64"
- x = tvm.expr.UIntImm("uint16", 2)
- assert isinstance(x, tvm.expr.UIntImm)
- assert x.value == 2
- assert x.dtype == "uint16"
-
x = tvm.expr.StringImm("xyza")
assert isinstance(x, tvm.expr.StringImm)
assert x.value == "xyza"
- x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1))
+ x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1))
assert isinstance(x, tvm.expr.Cast)
assert x.dtype == "float32"
assert x.value.value == 1
def check(f, *args):
x = f(*[tvm.const(x, "int32") for x in args])
y = f(*args)
- if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
+ if not isinstance(x, tvm.expr.IntImm) or x.value != int(y):
raise ValueError("check error: %s vs %s " % (x, y))
tmod = tvm.truncmod
*/
inline bool IsConstInt(PrimExpr expr) {
return
- expr->IsInstance<tvm::ir::IntImmNode>() ||
- expr->IsInstance<tvm::ir::UIntImmNode>();
+ expr->IsInstance<tvm::IntImmNode>();
}
/*!
* \return The integer value.
*/
inline int64_t GetConstInt(PrimExpr expr) {
- if (expr->IsInstance<tvm::ir::IntImmNode>()) {
- return expr.as<tvm::ir::IntImmNode>()->value;
- }
- if (expr->IsInstance<tvm::ir::UIntImmNode>()) {
- return expr.as<tvm::ir::UIntImmNode>()->value;
+ if (expr->IsInstance<tvm::IntImmNode>()) {
+ return expr.as<tvm::IntImmNode>()->value;
}
LOG(ERROR) << "expr must be a constant integer";
return -1;
"""
if isinstance(expr, Integral):
return expr
- if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+ if not isinstance(expr, tvm.expr.IntImm):
expr = tvm.ir_pass.Simplify(expr)
- if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+ if not isinstance(expr, tvm.expr.IntImm):
raise ValueError("Expect value to be constant int")
return int(expr.value)
"""
if isinstance(expr, Integral):
return expr == value
- if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+ if not isinstance(expr, tvm.expr.IntImm):
expr = tvm.ir_pass.Simplify(expr)
- if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+ if not isinstance(expr, tvm.expr.IntImm):
return False
return expr.value == value
for elem in in_tuple:
if isinstance(elem, tvm.expr.Var):
ret.append(elem)
- elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
+ elif not isinstance(elem, (tvm.expr.IntImm, int)):
elem = tvm.ir_pass.Simplify(elem)
- if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+ if not isinstance(elem, tvm.expr.IntImm):
ret.append(elem)
else:
ret.append(get_const_int(elem))