From ce807fe83825bce666ed1834ab24b5d6ddfa6bca Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Tue, 14 Jan 2020 19:03:26 -0800
Subject: [PATCH] [REFACTOR][IR] Unify IntImm and UIntImm (#4706)

* [REFACTOR][IR] Unify IntImm and UIntImm

This PR unifies UIntImm and IntImm to simplify the codebase. Unsigned
integer constants will also be stored as IntImm. For uint constants that
do not fit into int64 (a rare case), we introduce an intrinsic
tvm_big_uint_imm to construct such integers from their lower and higher
32 bits.

* [REFACTOR][IR] Remove UIntImm to use IntImm

* rename big->large
---
 include/tvm/attrs.h | 4 --
 include/tvm/expr.h | 48 +++++----------
 include/tvm/expr_operator.h | 53 ++++++++--------
 include/tvm/ir.h | 27 +++-----
 include/tvm/ir/expr.h | 50 +++++++++++++++
 include/tvm/ir_functor_ext.h | 4 --
 python/tvm/api.py | 3 +
 python/tvm/autotvm/task/task.py | 4 +-
 python/tvm/autotvm/util.py | 8 +--
 python/tvm/expr.py | 17 ------
 python/tvm/hybrid/calls.py | 2 +-
 python/tvm/hybrid/parser.py | 4 +-
 python/tvm/hybrid/util.py | 2 +-
 python/tvm/relay/frontend/tensorflow.py | 2 +-
 src/api/api_ir.cc | 2 -
 src/api/api_lang.cc | 3 +
 src/arithmetic/analyzer.cc | 6 +-
 src/arithmetic/canonical_simplify.cc | 8 +--
 src/arithmetic/const_fold.h | 61 ++++++-----------
 src/arithmetic/const_int_bound.cc | 10 +--
 src/arithmetic/int_set.cc | 6 +-
 src/arithmetic/modular_set.cc | 10 +--
 src/arithmetic/pattern_match.h | 6 +-
 src/arithmetic/rewrite_simplify.cc | 26 ++++----
 src/autotvm/touch_extractor.cc | 2 +-
 src/codegen/codegen_c.cc | 21 ++++---
 src/codegen/codegen_c.h | 1 -
 src/codegen/codegen_opengl.cc | 5 --
 src/codegen/codegen_opengl.h | 1 -
 src/codegen/llvm/codegen_arm.cc | 22 +++----
 src/codegen/llvm/codegen_llvm.cc | 18 +++---
 src/codegen/llvm/codegen_llvm.h | 1 -
 src/codegen/llvm/codegen_x86_64.cc | 4 +-
 src/codegen/llvm/intrin_rule_llvm.h | 8 +--
 src/codegen/spirv/codegen_spirv.cc | 13 ++--
 src/codegen/spirv/codegen_spirv.h | 1 -
 src/codegen/spirv/intrin_rule_spirv.cc | 2 +-
 src/codegen/spirv/ir_builder.cc | 4 +-
 src/codegen/spirv/ir_builder.h | 4 +-
 src/codegen/stackvm/codegen_stackvm.cc | 6 --
 src/codegen/stackvm/codegen_stackvm.h | 1 -
 src/contrib/hybrid/codegen_hybrid.cc | 5 +-
 src/contrib/hybrid/codegen_hybrid.h | 1 -
 src/ir/expr.cc | 19 ++++++
 src/lang/attr_functor.h | 4 --
 src/lang/attrs.cc | 11 ----
 src/lang/expr.cc | 11 +---
 src/lang/expr_operator.cc | 55 +++++++----------
 src/lang/ir.cc | 16 +----
 src/pass/arg_binder.cc | 16 ++---
 src/pass/ir_deep_compare.cc | 4 --
 src/pass/ir_functor.cc | 2 -
 src/pass/lift_attr_scope.cc | 3 -
 src/pass/lower_intrin.cc | 2 +-
 src/pass/lower_thread_allreduce.cc | 2 +-
 src/pass/lower_tvm_builtin.cc | 4 +-
 src/pass/make_api.cc | 6 +-
 src/pass/rewrite_unsafe_select.cc | 1 -
 src/pass/tensor_core.cc | 14 ++---
 src/pass/unroll_loop.cc | 4 --
 src/relay/backend/compile_engine.cc | 8 +--
 src/relay/ir/expr.cc | 2 +-
 src/relay/ir/pretty_printer.cc | 4 --
 src/relay/op/tensor/transform.cc | 2 +-
 src/relay/pass/type_solver.cc | 2 +-
 src/relay/qnn/util.h | 12 +---
 tests/cpp/pattern_match_test.cc | 4 +-
 tests/python/unittest/test_codegen_device.py | 27 ++++++++
 tests/python/unittest/test_codegen_llvm.py | 20 ++++++
 tests/python/unittest/test_hybrid_script.py | 2 +-
 .../python/unittest/test_lang_constructor.py | 7 +--
 tests/python/unittest/test_lang_operator.py | 2 +-
 topi/include/topi/detail/constant_utils.h | 10 +--
 topi/python/topi/util.py | 12 ++--
 74 files changed, 361 insertions(+), 413 deletions(-)

diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 
ab9a711d2..9d9f98e79 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { CHECK(expr.defined()); if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } @@ -523,8 +521,6 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { *ptr = static_cast(op->value); } else if (const ir::IntImmNode* op = expr.as()) { *ptr = static_cast(op->value); - } else if (const ir::UIntImmNode* op = expr.as()) { - *ptr = static_cast(op->value); } else { LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } diff --git a/include/tvm/expr.h b/include/tvm/expr.h index faae303d9..62806c667 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -115,56 +115,38 @@ class Var : public PrimExpr { using ContainerType = VarNode; }; -class Integer; -/*! \brief ExprNode: constant integer. */ -class IntImmNode : public PrimExprNode { - public: - /*! \brief the Internal value. */ - int64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static Integer make(DataType t, int64_t value); - - static constexpr const char* _type_key = "IntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); -}; - /*! - * \brief Container of constant integer (IntImm). + * \brief Container of constant int that adds more constructors. * * This is used to store and automate type check * attributes that must be constant integer. + * + * \sa IntImm */ -class Integer : public PrimExpr { +class Integer : public IntImm { public: - Integer() : PrimExpr() {} + Integer() {} /*! * \brief constructor from node. */ - explicit Integer(ObjectPtr node) : PrimExpr(node) {} + explicit Integer(ObjectPtr node) : IntImm(node) {} /*! * \brief Construct integer from int value. */ - Integer(int value) : PrimExpr(value) {} // NOLINT(*) + Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) + /*! + * \brief Construct integer from int imm. + * \param other The other value. + */ + Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) /*! * \brief Assign an expression to integer. * \param other another expression. */ - Integer& operator=(const Integer& other) { - data_ = other.data_; + Integer& operator=(const IntImm& other) { + data_ = ObjectRef::GetDataPtr(other); return *this; } - /*! - * \brief Get pointer to the internal value. - * \return the content of the integer. - */ - const IntImmNode* operator->() const { - return static_cast(get()); - } /*! * \brief convert to int64_t */ @@ -173,8 +155,6 @@ class Integer : public PrimExpr { << " Trying to reference a null Integer"; return (*this)->value; } - /*! \brief type indicate the container type */ - using ContainerType = IntImmNode; }; /*! \brief range over one dimension */ diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 2d8f37855..ff3b340bf 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -30,6 +30,7 @@ #include #include +#include #include "expr.h" #include "ir.h" @@ -82,21 +83,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) { } } -/*! - * \brief Get x as constant uint expression. - * \param x The expression - * \return the address to the int expression, - * return nullptr, if x is not UIntImm. 
- */ -inline const uint64_t* as_const_uint(const PrimExpr& x) { - if (!x.defined()) return nullptr; - if (const ir::UIntImmNode* op = x.as()) { - return &(op->value); - } else { - return nullptr; - } -} - /*! * \brief Check whether x is a constant integer expression. * \param x The input argument @@ -597,6 +583,15 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x); */ TVM_DLL PrimExpr trunc(PrimExpr x); +/*! + * \brief Construct a large uint constant by its low 32 bits and high 32bits. + * \param dtype The final data type. + * \param low The lower 32 bits. + * \param high The higher 32 bits. + * \return The constructed expression. + */ +TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline PrimExpr OpName(PrimExpr x) { \ @@ -617,11 +612,11 @@ TVM_DECLARE_INTRIN_UNARY(atan); // Implementation details after this inline bool is_const(const PrimExpr& x) { - if (x.as() || x.as()) { + if (x.as()) { return true; } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; - if (val.as() || val.as()) { + if (val.as()) { return true; } } @@ -631,8 +626,6 @@ inline bool is_const(const PrimExpr& x) { inline bool is_positive_const(const PrimExpr& a) { if (const ir::IntImmNode* op = a.as()) { return op->value > 0; - } else if (const ir::UIntImmNode* op = a.as()) { - return op->value > 0; } else { return false; } @@ -649,14 +642,10 @@ inline bool is_negative_const(const PrimExpr& a) { inline bool is_const_int(const PrimExpr& x, int64_t value) { if (const auto* op = x.as()) { return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); } else if (const auto* op = x.as()) { const PrimExpr& val = op->value; if (const auto* opv = val.as()) { return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); } } return false; @@ -675,15 +664,27 @@ inline bool is_no_op(const Stmt& stmt) { template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { - if (t.is_int()) return ir::IntImmNode::make(t, static_cast(value)); - if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast(value)); + if (t.is_int()) return IntImm(t, static_cast(value)); + if (t.is_uint()) { + // Use IntImm if it is a small integer + uint64_t uval = static_cast(value); + if (uval <= static_cast(std::numeric_limits::max())) { + return IntImm(t, static_cast(value)); + } else { + uint64_t mask = (static_cast(1) << 32U) - 1U; + uint64_t low = uval & mask; + uint64_t high = uval >> 32U; + return LargeUIntImm(t, static_cast(low), static_cast(high)); + } + } if (t.is_float()) return ir::FloatImmNode::make(t, static_cast(value)); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(kCustomBegin)) + if (static_cast(t.code()) >= static_cast(kCustomBegin)) { return ir::FloatImmNode::make(t, static_cast(value)); + } LOG(FATAL) << "cannot make const for type " << t; return PrimExpr(); } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 84039485a..9c14a31be 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -39,23 +39,6 @@ namespace ir { using IntImmNode = tvm::IntImmNode; using VarNode = tvm::VarNode; -/*! \brief constant unsigned integer. 
*/ -class UIntImmNode : public PrimExprNode { - public: - /*! \brief The constant value content. */ - uint64_t value; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - } - - TVM_DLL static PrimExpr make(DataType t, uint64_t value); - - static constexpr const char* _type_key = "UIntImm"; - TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode); -}; - /*! \brief Floating point constants. */ class FloatImmNode : public PrimExprNode { public: @@ -1422,6 +1405,16 @@ inline bool IsPragmaKey(const std::string& attr_key) { /*! \brief namespace of TVM Intrinsic functions */ namespace intrinsic { +/*! + * \brief See pesudo code + * + * Construct a big uint that may not be representable by int64 + * + * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) { + * return (v1 << 32) | v0; + * } + */ +constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm"; /*! * \brief See pesudo code * diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 12b34dd26..12e505ed9 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -131,6 +131,56 @@ class PrimExpr : public BaseExpr { using ContainerType = PrimExprNode; }; +/*! + * \brief Constant integer literals in the program. + * \sa IntImm + */ +class IntImmNode : public PrimExprNode { + public: + /*! \brief the Internal value. */ + int64_t value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "IntImm"; + TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); +}; + +/*! + * \brief Managed reference class to IntImmNode. + * + * \sa IntImmNode + */ +class IntImm : public PrimExpr { + public: + /*! + * \brief Constructor + */ + IntImm() {} + /*! + * \brief constructor from node. + */ + explicit IntImm(ObjectPtr node) : PrimExpr(node) {} + /*! + * \brief Constructor. + * \param dtype The data type of the value. + * \param value The internal value. + */ + TVM_DLL IntImm(DataType dtype, int64_t value); + /*! + * \brief Get pointer to the internal value. + * \return the content of the integer. + */ + const IntImmNode* operator->() const { + return static_cast(get()); + } + /*! \brief type indicate the container type */ + using ContainerType = IntImmNode; +}; + /*! * \brief Base node of all non-primitive expressions. * diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 7d57564fd..37a1fe4bf 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -161,7 +161,6 @@ class ExprFunctor { virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args ...) 
{ @@ -203,7 +202,6 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); - IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); return vtable; @@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const BroadcastNode* op) override; void VisitExpr_(const ShuffleNode* op) override; void VisitExpr_(const IntImmNode* op) override; - void VisitExpr_(const UIntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; }; @@ -372,7 +369,6 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const BroadcastNode* op) override; PrimExpr VisitExpr_(const ShuffleNode* op) override; PrimExpr VisitExpr_(const IntImmNode* op) override; - PrimExpr VisitExpr_(const UIntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; }; diff --git a/python/tvm/api.py b/python/tvm/api.py index 7395d3524..4bfe794c1 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -92,6 +92,9 @@ def const(value, dtype=None): """ if dtype is None: dtype = _scalar_type_inference(value) + if dtype == "uint64" and value >= (1 << 63): + return _api_internal._LargeUIntImm( + dtype, value & ((1 << 32) - 1), value >> 32) return _api_internal._const(value, dtype) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 7f36914eb..5067277d3 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None): workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)): workload = x - elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)): + elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): workload = x.value elif x is None: workload = 0 @@ -344,7 +344,7 @@ def compute_flop(sch): if len(source) != 1: raise FlopCalculationError("Found multiple output in the source of reduce op") return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0])) - if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)): + if isinstance(exp, (expr.FloatImm, expr.IntImm)): return 0 if isinstance(exp, expr.Cast): return _count_flop(exp.value) diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 3026914ae..54001d333 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -155,9 +155,9 @@ def get_const_int(exp): """ if isinstance(exp, int): return exp - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): exp = ir_pass.Simplify(exp) - if not isinstance(exp, (expr.IntImm, expr.UIntImm)): + if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -179,9 +179,9 @@ def get_const_tuple(in_tuple): for elem in in_tuple: if isinstance(elem, expr.Var): ret.append(elem) - elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)): + elif not isinstance(elem, (expr.IntImm, int)): elem = ir_pass.Simplify(elem) - if not isinstance(elem, (expr.IntImm, expr.UIntImm)): + if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: ret.append(get_const_int(elem)) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 71c0aecd1..2fd7b78d9 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -341,23 +341,6 @@ class 
IntImm(ConstExpr): return self.value -@register_object -class UIntImm(ConstExpr): - """UInt constant. - - Parameters - ---------- - dtype : str - The data type - - value : int - The constant value. - """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__( - _make.UIntImm, dtype, value) - - @register_object class StringImm(ConstExpr): """String constant. diff --git a/python/tvm/hybrid/calls.py b/python/tvm/hybrid/calls.py index 1d5612e67..7038f6144 100644 --- a/python/tvm/hybrid/calls.py +++ b/python/tvm/hybrid/calls.py @@ -156,6 +156,6 @@ def max_num_threads(func_id, args): if args.__len__() == 0: res = _tgt.current_target().max_num_threads else: - _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint") + _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint") res = _tgt.current_target(args[0].value).max_num_threads return _api.convert(res) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 06bcbcabe..57d636328 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -386,7 +386,7 @@ class HybridParser(ast.NodeVisitor): if isinstance(i, numbers.Integral): arr = arr[i] else: - _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + _internal_assert(isinstance(i, (_expr.IntImm,)), \ "All indices are supposed to be constants") arr = arr[i.value] return arr @@ -413,7 +413,7 @@ class HybridParser(ast.NodeVisitor): cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) # Return no IfThenElse if proven - if isinstance(cond, _expr.UIntImm): + if isinstance(cond, _expr.IntImm): if cond.value: return visit_list_to_block(self.visit, node.body) if node.orelse: diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 0dd1fa141..a08a380dd 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -33,7 +33,7 @@ from ..container import Array #pylint: disable=invalid-name np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) +halide_imm_types = (_expr.IntImm, _expr.FloatImm) def _internal_assert(cond, err): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7e22d7213..e7f4682e7 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -931,7 +931,7 @@ def _shape(): def _impl(inputs, attr, params): is_symbolic_shape = False for axis in attr['_input_shapes'][inputs[0]]: - if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)): + if not isinstance(axis, (int, tvm.expr.IntImm)): is_symbolic_shape = True break diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index ca4823bc6..30ca51592 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -130,8 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer") REGISTER_MAKE(Reduce); REGISTER_MAKE(AttrStmt); -REGISTER_MAKE(IntImm); -REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 6a8bc58ad..fa7b59d36 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -53,6 +53,9 @@ TVM_REGISTER_GLOBAL("_const") } }); +TVM_REGISTER_GLOBAL("_LargeUIntImm") +.set_body_typed(LargeUIntImm); + TVM_REGISTER_GLOBAL("_str") .set_body_typed(ir::StringImmNode::make); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 7a3baa678..e03e5e238 100644 --- a/src/arithmetic/analyzer.cc +++ 
b/src/arithmetic/analyzer.cc @@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { } bool Analyzer::CanProve(const PrimExpr& expr) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value != 0; } auto res = this->rewrite_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } res = this->canonical_simplify(expr); - if (const auto* ptr = res.as()) { + if (const auto* ptr = res.as()) { return ptr->value != 0; } return false; diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 5f721d7a1..90c6e48de 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -737,7 +737,7 @@ VisitExpr_(const DivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -797,7 +797,7 @@ VisitExpr_(const FloorDivNode* op) { // const folding PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x / c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -905,7 +905,7 @@ VisitExpr_(const ModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; @@ -975,7 +975,7 @@ VisitExpr_(const FloorModNode* op) { PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; - PVar c1; + PVar c1; // x % c1 if (c1.Match(b) && c1.Eval()->value > 0) { int64_t cval = c1.Eval()->value; diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 55c156d89..3b803ecd8 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ using ir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ @@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) { #define TVM_INDEX_CONST_PROPAGATION(BODY) \ - using ir::IntImmNode; \ - using ir::UIntImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const DataType& ta = a.dtype(); \ @@ -103,7 +99,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value); @@ -117,7 +113,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); if (pb && pb->value == 0) return a; if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value); if (fb && fb->value == 0) return a; @@ -129,7 +125,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, pa->value * 
pb->value); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; @@ -159,7 +155,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { // due to division and mod can have different modes // NOTE: this will assumes truc div. CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, pa->value / pb->value); + return IntImm(rtype, pa->value / pb->value); } if (pa) { if (pa->value == 0) return a; @@ -185,7 +181,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, pa->value % pb->value); + return IntImm(rtype, pa->value % pb->value); } if (pa) { if (pa->value == 0) return a; @@ -204,7 +200,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value)); + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -230,7 +226,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value)); + return IntImm(rtype, arith::floormod(pa->value, pb->value)); } if (pa) { if (pa->value == 0) return a; @@ -247,7 +243,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -258,7 +254,7 @@ template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value)); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; @@ -268,8 +264,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); return PrimExpr(); } @@ -277,8 +273,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); return PrimExpr(); } @@ -286,8 +282,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return 
UIntImmNode::make(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); return PrimExpr(); } @@ -295,8 +291,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); return PrimExpr(); } @@ -304,8 +300,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); return PrimExpr(); } @@ -313,17 +309,16 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); return PrimExpr(); } template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return b; if (pa && !pa->value) return a; if (pb && pb->value) return a; @@ -333,9 +328,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); - const UIntImmNode* pb = b.as(); + const IntImmNode* pa = a.as(); + const IntImmNode* pb = b.as(); if (pa && pa->value) return a; if (pa && !pa->value) return b; if (pb && pb->value) return b; @@ -345,10 +339,9 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { template<> inline PrimExpr TryConstFold(PrimExpr a) { - using ir::UIntImmNode; - const UIntImmNode* pa = a.as(); + const IntImmNode* pa = a.as(); if (pa) { - return UIntImmNode::make(DataType::UInt(1), !(pa->value)); + return IntImm(DataType::UInt(1), !(pa->value)); } return PrimExpr(); } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index a041e40ab..25d88d342 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -150,14 +150,6 @@ class ConstIntBoundAnalyzer::Impl : return MakeBound(op->value, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value <= static_cast(kPosInf)) { - return MakeBound(op->value, op->value); - } else { - return Everything(op->dtype); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); @@ -496,7 +488,7 @@ class 
ConstIntBoundAnalyzer::Impl : */ static std::vector DetectBoundInfo(const PrimExpr& cond) { PVar x, y; - PVar c; + PVar c; // NOTE: canonical form always use <= or < if ((c <= x).Match(cond)) { return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))}; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index ceaa97646..37d5e9eb5 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -384,10 +384,6 @@ class IntervalSetEvaluator : return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const UIntImmNode* op) final { - return IntervalSet::SinglePoint(GetRef(op)); - } - IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); @@ -476,7 +472,7 @@ class IntervalSetEvaluator : IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); IntervalSet base = Eval(op->base); - PVar stride; + PVar stride; if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 01dd2e8e4..c81842035 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -109,7 +109,7 @@ class ModularSetAnalyzer::Impl : // Detect useful constraints and use them in the analysis scope. std::function EnterConstraint(const PrimExpr& constraint) { PVar var; - PVar coeff, base; + PVar coeff, base; // pattern match interesting constraints if ((truncmod(var, coeff) == base).Match(constraint) || (floormod(var, coeff) == base).Match(constraint)) { @@ -132,14 +132,6 @@ class ModularSetAnalyzer::Impl : return Entry(0, op->value); } - Entry VisitExpr_(const UIntImmNode* op) final { - if (op->value < std::numeric_limits::max()) { - return Entry(0, static_cast(op->value)); - } else { - return Everything(); - } - } - Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 733dcf41c..a236e65a8 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -45,7 +45,7 @@ * } * * tvm::Var tx, ty; - * arith::PVar c; + * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are * // special case container of Expr @@ -140,9 +140,9 @@ class PEqualChecker { }; template<> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const Integer& lhs, const Integer& rhs) const { + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 94d951da5..e6e152460 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -124,7 +124,7 @@ VisitExpr_(const AddNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -239,7 +239,7 @@ VisitExpr_(const SubNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules @@ -438,7 +438,7 @@ VisitExpr_(const MulNode* op) { // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar 
lanes; // Vector rules @@ -477,7 +477,7 @@ VisitExpr_(const DivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -700,7 +700,7 @@ VisitExpr_(const ModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -789,7 +789,7 @@ VisitExpr_(const FloorDivNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -934,7 +934,7 @@ VisitExpr_(const FloorModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -1004,7 +1004,7 @@ VisitExpr_(const MinNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1189,7 +1189,7 @@ VisitExpr_(const MaxNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1362,7 +1362,7 @@ VisitExpr_(const EQNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1; + PVar c1; PVar lanes; // vector rule @@ -1416,7 +1416,7 @@ VisitExpr_(const LTNode* op) { // Pattern var to match any expression PVar x, y, z, s1, s2; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; // vector rule @@ -1597,7 +1597,7 @@ VisitExpr_(const AndNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { @@ -1646,7 +1646,7 @@ VisitExpr_(const OrNode* op) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2; PVar lanes; if (op->dtype.lanes() != 1) { diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index cf138edd4..55ed36ca9 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -256,7 +256,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > Array attr{std::string("_attr_"), FloatImmNode::make(DataType::Float(32), trans(fea.length)), - IntImmNode::make(DataType::Int(32), fea.nest_level), + IntImm(DataType::Int(32), fea.nest_level), FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)), FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)), }; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 777ad6203..d9b7f7f08 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -372,16 +372,17 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } -inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - if (op->dtype == DataType::UInt(32)) { + +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) + if (dtype == DataType::UInt(32)) { std::ostringstream temp; - temp << op->value << "U"; + temp << val << "U"; p->MarkConst(temp.str()); os << temp.str(); } else { os << "("; - p->PrintType(op->dtype, os); - os << ")" << op->value; + p->PrintType(dtype, os); + os << ")" << val; } } @@ -408,9 +409,7 @@ 
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintConst(op, os, this); -} + void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } @@ -528,6 +527,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << ")"; } else if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsic(op, " ^ ", os, this); } else if (op->is_intrinsic(CallNode::bitwise_or)) { diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index cb092c566..7e5dd4269 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -128,7 +128,6 @@ class CodeGenC : void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 7967c1847..cea276d5c 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -247,11 +247,6 @@ void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) { CodeGenC::VisitExpr_(op, os); } -void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints."; - CodeGenC::VisitExpr_(op, os); -} - void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; CodeGenC::VisitExpr_(op, os); diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index cd1ec8336..19ca2ee12 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -50,7 +50,6 @@ class CodeGenOpenGL final : public CodeGenC { // Codegen for immediate values void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/codegen/llvm/codegen_arm.cc b/src/codegen/llvm/codegen_arm.cc index 6879fd5f8..44862cf7a 100644 --- a/src/codegen/llvm/codegen_arm.cc +++ b/src/codegen/llvm/codegen_arm.cc @@ -48,7 +48,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); + Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = 
ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -68,8 +68,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { if (!call->dtype.is_vector() || call->dtype.bits() == 8 || (total_size != 128 && total_size != 64)) { Array vcnt_args; - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); return ir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } @@ -93,16 +93,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { const CallNode* c0 = input8.as(); CHECK(c0 != nullptr); Array vcnt8_args; - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id)); - vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); + vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = ir::CallNode::make( uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = ir::CallNode::make( uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); @@ -112,8 +112,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 16->32bit Array vcnt32_args; - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = ir::CallNode::make( uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); @@ -123,8 +123,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // Accumulation 32->64bit Array vcnt64_args; - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id)); - vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); + vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); return ir::CallNode::make( call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index c04a023ae..60d8146fc 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -662,15 +662,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast( - op->args[0].as()->value); - const uint64_t *num_signature = as_const_uint(op->args[1]); - CHECK(num_signature) << "The second argument should be a uint represents number of arguments, " - << "but " << op->args[1] << " got!\n"; + Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector sig_type; for (size_t i = 2; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); - if (i - 2 < *num_signature) { + 
if (i - 2 < static_cast(num_signature)) { sig_type.push_back(arg_value.back()->getType()); } } @@ -722,6 +720,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return llvm::ConstantInt::get(LLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; @@ -804,10 +808,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) { - return llvm::ConstantInt::get(LLVMType(op->dtype), op->value); -} - llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(LLVMType(op->dtype), op->value); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 34c3ee723..b269f2423 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -106,7 +106,6 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const VarNode* op) override; llvm::Value* VisitExpr_(const CastNode* op) override; llvm::Value* VisitExpr_(const IntImmNode* op) override; - llvm::Value* VisitExpr_(const UIntImmNode* op) override; llvm::Value* VisitExpr_(const FloatImmNode* op) override; llvm::Value* VisitExpr_(const StringImmNode* op) override; llvm::Value* VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index 03656cc70..11bda70fb 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -96,8 +96,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { MakeValue( ir::BroadcastNode::make( ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } diff --git a/src/codegen/llvm/intrin_rule_llvm.h b/src/codegen/llvm/intrin_rule_llvm.h index b3ab557ee..1f839f362 100644 --- a/src/codegen/llvm/intrin_rule_llvm.h +++ b/src/codegen/llvm/intrin_rule_llvm.h @@ -43,8 +43,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); @@ -60,8 +60,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. 
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); - cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature)); + cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), num_signature)); for (PrimExpr arg : call->args) { cargs.push_back(arg); } diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index a74942489..985f6816a 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -136,10 +136,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { return builder_->IntImm(builder_->GetSType(op->dtype), op->value); } -spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) { - return builder_->UIntImm(builder_->GetSType(op->dtype), op->value); -} - spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { return builder_->FloatImm(builder_->GetSType(op->dtype), op->value); } @@ -242,7 +238,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = op->args[0].as()->value; + uint32_t inst_id = static_cast( + op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); @@ -285,6 +282,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->is_intrinsic(CallNode::reinterpret)) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + return builder_->UIntImm(builder_->GetSType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3804bda0f..5aa7f9c49 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -65,7 +65,6 @@ class CodeGenSPIRV: spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; spirv::Value VisitExpr_(const IntImmNode* op) override; - spirv::Value VisitExpr_(const UIntImmNode* op) override; spirv::Value VisitExpr_(const FloatImmNode* op) override; spirv::Value VisitExpr_(const StringImmNode* op) override; spirv::Value VisitExpr_(const AddNode* op) override; diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index d41d96db5..d96883ed0 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -39,7 +39,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { CHECK(call != nullptr); Array cargs; // intrin id. 
- cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id)); + cargs.push_back(IntImm(DataType::UInt(32), id)); for (PrimExpr arg : call->args) { cargs.push_back(arg); diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc index 6f8d96e14..bf43f11cc 100644 --- a/src/codegen/spirv/ir_builder.cc +++ b/src/codegen/spirv/ir_builder.cc @@ -342,9 +342,9 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type == DataType::UInt(1)) { // bool types. if (*pvalue) { - ib_.Begin(spv::OpConstantTrue).AddSeq(ret); + ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); } else { - ib_.Begin(spv::OpConstantFalse).AddSeq(ret); + ib_.Begin(spv::OpConstantFalse).AddSeq(dtype, ret); } } else { // Integral/floating-point types. diff --git a/src/codegen/spirv/ir_builder.h b/src/codegen/spirv/ir_builder.h index 3843cbb3c..5d25e8634 100644 --- a/src/codegen/spirv/ir_builder.h +++ b/src/codegen/spirv/ir_builder.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index eccff6c74..01096ae1d 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -280,12 +280,6 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) { - CHECK(op->value <= std::numeric_limits::max()) - << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); -} - void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { LOG(FATAL) << "Float Imm is not supported"; } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 07989b206..1360cc2d7 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -136,7 +136,6 @@ class CodeGenStackVM void VisitExpr_(const RampNode* op) final; void VisitExpr_(const BroadcastNode* op) final; void VisitExpr_(const IntImmNode* op) final; - void VisitExpr_(const UIntImmNode* op) final; void VisitExpr_(const FloatImmNode* op) final; void VisitExpr_(const StringImmNode* op) final; // statment diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7e3d44f26..346ec3808 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -79,10 +79,7 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "(" << op->value << ")"; -} + void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 89a1ece57..33bd0efae 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -117,7 +117,6 @@ class 
CodeGenHybrid : void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const UIntImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f698a5d18..6d8996741 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -26,6 +26,25 @@ namespace tvm { +IntImm::IntImm(DataType dtype, int64_t value) { + CHECK(dtype.is_scalar()) + << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) + << "ValueError: IntImm can only take scalar."; + if (dtype.is_uint()) { + CHECK_GE(value, 0U); + } + ObjectPtr node = make_object(); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("make.IntImm") +.set_body_typed([](DataType dtype, int64_t value) { + return IntImm(dtype, value); +}); + GlobalVar::GlobalVar(std::string name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 34ee4b315..4fffc475a 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -77,7 +77,6 @@ class AttrFunctor { virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. 
diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h
index 34ee4b315..4fffc475a 100644
--- a/src/lang/attr_functor.h
+++ b/src/lang/attr_functor.h
@@ -77,7 +76,6 @@ class AttrFunctor {
   virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   // deep comparison of symbolic integer expressions.
@@ -113,7 +112,6 @@ class AttrFunctor {
     ATTR_FUNCTOR_DISPATCH(StrMapNode);
     ATTR_FUNCTOR_DISPATCH(ArrayNode);
     ATTR_FUNCTOR_DISPATCH(IntImmNode);
-    ATTR_FUNCTOR_DISPATCH(UIntImmNode);
    ATTR_FUNCTOR_DISPATCH(FloatImmNode);
     ATTR_FUNCTOR_DISPATCH(StringImmNode);
     ATTR_FUNCTOR_DISPATCH(VarNode);
@@ -157,7 +155,6 @@ class AttrsEqualHandler :
   bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final;
@@ -198,7 +195,6 @@ class AttrsHashHandler :
 protected:
   size_t VisitAttrDefault_(const Object* lhs) final;
   size_t VisitAttr_(const ir::IntImmNode* lhs) final;
-  size_t VisitAttr_(const ir::UIntImmNode* lhs) final;
   size_t VisitAttr_(const ir::FloatImmNode* lhs) final;
   size_t VisitAttr_(const ir::StringImmNode* lhs) final;
   size_t VisitAttr_(const ArrayNode* lhs) final;
diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc
index 1d3e767a5..a590f10e7 100644
--- a/src/lang/attrs.cc
+++ b/src/lang/attrs.cc
@@ -97,13 +97,6 @@ bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as()) {
-    return lhs->value == rhs->value;
-  }
-  return false;
-}
-
 bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as()) {
     return lhs->value == rhs->value;
@@ -224,10 +217,6 @@ size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
   return std::hash()(op->value);
 }
 
-size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
-  return std::hash()(op->value);
-}
-
 size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
   return std::hash()(op->value);
 }
diff --git a/src/lang/expr.cc b/src/lang/expr.cc
index a7289369b..55dfb8934 100644
--- a/src/lang/expr.cc
+++ b/src/lang/expr.cc
@@ -30,7 +30,7 @@ namespace tvm {
 PrimExpr::PrimExpr(int32_t value)
-    : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {}
+    : PrimExpr(IntImm(DataType::Int(32), value)) {}
 
 PrimExpr::PrimExpr(float value)
     : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
@@ -54,15 +54,6 @@ Range::Range(PrimExpr begin, PrimExpr end)
       is_zero(begin) ? end : (end - begin))) {
 }
 
-Integer IntImmNode::make(DataType t, int64_t value) {
-  CHECK(t.is_int() && t.is_scalar())
-      << "ValueError: IntImm can only take scalar.";
-  ObjectPtr node = make_object();
-  node->dtype = t;
-  node->value = value;
-  return Integer(node);
-}
-
 Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
   return Range(make_object(min, extent));
 }
diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc
index d3875e28c..bd43d89d8 100644
--- a/src/lang/expr_operator.cc
+++ b/src/lang/expr_operator.cc
@@ -35,6 +35,14 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
   return ir::CastNode::make(t, value);
 }
 
+PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
+  return ir::CallNode::make(
+      t, ir::intrinsic::tvm_large_uint_imm,
+      {make_const(DataType::UInt(32), low),
+       make_const(DataType::UInt(32), high)},
+      ir::CallNode::PureIntrinsic);
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   if (lhs.dtype() == rhs.dtype()) return;
@@ -78,26 +86,25 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   }
 }
-
 // maximum and min limits
 PrimExpr max_value(const DataType& dtype) {
   using namespace ir;
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
     if (dtype.bits() == 64) {
-      return IntImmNode::make(dtype, std::numeric_limits::max());
+      return IntImm(dtype, std::numeric_limits::max());
     } else if (dtype.bits() < 64) {
       int64_t val = 1;
       val = (val << (dtype.bits() - 1)) - 1;
-      return IntImmNode::make(dtype, val);
+      return IntImm(dtype, val);
     }
   } else if (dtype.is_uint()) {
     if (dtype.bits() == 64) {
-      return UIntImmNode::make(dtype, std::numeric_limits::max());
+      return make_const(dtype, std::numeric_limits::max());
     } else if (dtype.bits() < 64) {
       uint64_t val = 1;
       val = (val << static_cast(dtype.bits())) - 1;
-      return UIntImmNode::make(dtype, val);
+      return IntImm(dtype, static_cast(val));
     }
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
@@ -117,14 +124,14 @@ PrimExpr min_value(const DataType& dtype) {
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
     if (dtype.bits() == 64) {
-      return IntImmNode::make(dtype, std::numeric_limits::lowest());
+      return IntImm(dtype, std::numeric_limits::lowest());
     } else if (dtype.bits() < 64) {
       int64_t val = 1;
       val = -(val << (dtype.bits() - 1));
-      return IntImmNode::make(dtype, val);
+      return IntImm(dtype, val);
     }
   } else if (dtype.is_uint()) {
-    return UIntImmNode::make(dtype, 0);
+    return IntImm(dtype, 0);
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
       return FloatImmNode::make(dtype, std::numeric_limits::lowest());
@@ -155,24 +162,18 @@ inline bool ConstPowerHelper(ValueType val, int *shift) {
 bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
   if (const auto* op = x.as()) {
     return ConstPowerHelper(op->value, shift);
-  } else if (const auto* op = x.as()) {
-    return ConstPowerHelper(op->value, shift);
   } else {
     return false;
   }
 }
 
 PrimExpr cast(const DataType& t, PrimExpr value) {
-  using ir::IntImmNode;
-  using ir::UIntImmNode;
   using ir::FloatImmNode;
   if (value.dtype() == t) return value;
   // const fold IntImm as they are used in index computations
   if (t.lanes() == 1) {
     if (const IntImmNode* op = value.as()) {
       return make_const(t, op->value);
-    } else if (const UIntImmNode* op = value.as()) {
-      return make_const(t, op->value);
     } else if (const FloatImmNode* op = value.as()) {
       return make_const(t, op->value);
     }
@@ -184,8 +185,6 @@ PrimExpr cast(const DataType& t, PrimExpr value) {
     if (value.dtype() != vtype) {
       if (const IntImmNode* op = value.as()) {
         value = make_const(vtype, op->value);
-      } else if (const UIntImmNode* op = value.as()) {
-        return make_const(t, op->value);
       } else if (const FloatImmNode* op = value.as()) {
         value = make_const(vtype, op->value);
       } else {
@@ -219,7 +218,7 @@ PrimExpr operator-(PrimExpr a) {
   using ir::FloatImmNode;
   const IntImmNode* pa = a.as();
   const FloatImmNode* fa = a.as();
-  if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value);
+  if (pa) return IntImm(a.dtype(), -pa->value);
   if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
   return make_zero(a.dtype()) - a;
 }
@@ -322,18 +321,10 @@ PrimExpr max(PrimExpr a, PrimExpr b) {
 }
 
 PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
-  using ir::IntImmNode;
-  using ir::UIntImmNode;
   CHECK(cond.dtype() == DataType::Bool(1))
       << "if_then_else only accept the condition to be boolean type.";
   BinaryOpMatchTypes(true_value, false_value);
-  if (const UIntImmNode* op = cond.as()) {
-    if (op->value != 0) {
-      return true_value;
-    } else {
-      return false_value;
-    }
-  } else if (const IntImmNode* op = cond.as()) {
+  if (const IntImmNode* op = cond.as()) {
     if (op->value != 0) {
       return true_value;
     } else {
@@ -424,7 +415,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
       if (pb) {
         if (pb->value == 0) return a;
       }
@@ -437,7 +428,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
       if (pb) {
        if (pb->value == 0) return a;
       }
@@ -450,7 +441,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
     });
   return ir::CallNode::make(
       a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
@@ -460,7 +451,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
     });
   return ir::CallNode::make(
       a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
@@ -470,7 +461,7 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
     });
   return ir::CallNode::make(
       a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
@@ -494,7 +485,7 @@ PrimExpr abs(PrimExpr x) {
     using ir::IntImmNode;
     const IntImmNode* px = x.as();
     if (px) {
-      return ir::IntImmNode::make(x.dtype(), std::abs(px->value));
+      return IntImm(x.dtype(), std::abs(px->value));
     }
     return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x);
   } else if (x.dtype().is_float()) {
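An illustrative sketch (not part of the patch) of the arithmetic behind the LargeUIntImm helper added above: a uint64 literal that overflows int64 is split into two uint32 halves, and the intrinsic reconstructs it as (high << 32) | low.

    # Hypothetical helper names for illustration only.
    value = (1 << 63) + 123             # does not fit in int64
    low = value & 0xFFFFFFFF            # lower 32 bits
    high = value >> 32                  # upper 32 bits
    assert (high << 32) | low == value  # what tvm_large_uint_imm encodes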
diff --git a/src/lang/ir.cc b/src/lang/ir.cc
index ad7f26022..f06a6be5e 100644
--- a/src/lang/ir.cc
+++ b/src/lang/ir.cc
@@ -31,14 +31,6 @@ namespace tvm {
 namespace ir {
 
 // constructors
-PrimExpr UIntImmNode::make(DataType t, uint64_t value) {
-  CHECK(t.is_uint() && t.lanes() == 1)
-      << "ValueError: UIntImm can only take scalar";
-  ObjectPtr node = make_object();
-  node->dtype = t;
-  node->value = value;
-  return PrimExpr(node);
-}
 
 PrimExpr FloatImmNode::make(DataType t, double value) {
   CHECK_EQ(t.lanes(), 1)
@@ -248,7 +240,7 @@ PrimExpr ShuffleNode::make_concat(Array vectors) {
   int index = 0;
   for (const PrimExpr& e : vectors) {
     for (int i = 0; i < e.dtype().lanes(); ++i) {
-      indices.push_back(IntImmNode::make(DataType::Int(32), index++));
+      indices.push_back(IntImm(DataType::Int(32), index++));
     }
   }
   return make(vectors, indices);
@@ -531,11 +523,6 @@ Stmt EvaluateNode::make(PrimExpr value) {
 }
 
 // Printers
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast(node.get());
-    p->stream << "(" << op->dtype << ")" << op->value;
-  });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 .set_dispatch([](const ObjectRef& node, NodePrinter* p) {
@@ -1153,7 +1140,6 @@ TVM_REGISTER_NODE_TYPE(AnyNode);
 TVM_REGISTER_NODE_TYPE(AttrStmtNode);
 TVM_REGISTER_NODE_TYPE(FloatImmNode);
 TVM_REGISTER_NODE_TYPE(IntImmNode);
-TVM_REGISTER_NODE_TYPE(UIntImmNode);
 TVM_REGISTER_NODE_TYPE(StringImmNode);
 TVM_REGISTER_NODE_TYPE(CastNode);
 TVM_REGISTER_NODE_TYPE(VarNode);
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index 2c04de371..0f350d2d7 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -179,11 +179,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   std::ostringstream type_err_msg;
   type_err_msg << arg_name << ".dtype is expected to be " << dtype;
   PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
-                   UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
+                   IntImm(DataType::UInt(8), dtype.code()) &&
                    TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
-                   UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
+                   IntImm(DataType::UInt(8), dtype.bits()) &&
                    TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
-                   UIntImmNode::make(DataType::UInt(16), dtype.lanes()));
+                   IntImm(DataType::UInt(16), dtype.lanes()));
   asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -193,7 +193,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     // mark alignment of external bufs
     init_nest_.emplace_back(AttrStmtNode::make(
         vptr, ir::attr::storage_alignment,
-        IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop));
+        IntImm(DataType::Int(32), buffer->data_alignment), nop));
   }
 
   Var v_shape(arg_name + ".shape", DataType::Handle());
@@ -206,7 +206,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     Bind_(buffer->shape[k],
           cast(buffer->shape[k].dtype(),
                LoadNode::make(tvm_shape_type, v_shape,
-                              IntImmNode::make(DataType::Int(32), k), const_true(1))),
+                              IntImm(DataType::Int(32), k), const_true(1))),
           field_name.str(), true);
   }
   // strides field
@@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       PrimExpr svalue = cast(
           stype,
           LoadNode::make(tvm_shape_type, v_strides,
-                         IntImmNode::make(DataType::Int(32), k), const_true(1)));
+                         IntImm(DataType::Int(32), k), const_true(1)));
       conds.push_back(expect_stride == svalue);
       expect_stride = expect_stride * buffer->shape[k];
     }
@@ -251,7 +251,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       field_name << v_strides->name_hint << '[' << k << ']';
       PrimExpr value = cast(buffer->shape[k].dtype(),
                             LoadNode::make(tvm_shape_type, v_strides,
-                                           IntImmNode::make(DataType::Int(32), k), const_true(1)));
+                                           IntImm(DataType::Int(32), k), const_true(1)));
       value = tvm::if_then_else(is_null, stride, value);
       value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
       Bind_(buffer->strides[k], value, field_name.str(), true);
@@ -270,7 +270,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       Bind_(buffer->strides[k],
             cast(buffer->shape[k].dtype(),
                  LoadNode::make(tvm_shape_type, v_strides,
-                                IntImmNode::make(DataType::Int(32), k), const_true(1))),
+                                IntImm(DataType::Int(32), k), const_true(1))),
             field_name.str(), true);
     }
   }
diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc
index 6eacb145b..8c441510c 100644
--- a/src/pass/ir_deep_compare.cc
+++ b/src/pass/ir_deep_compare.cc
@@ -252,10 +252,6 @@ class IRDeepCompare :
     CompareValue(op->value, other.as()->value);
   }
 
-  void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final {
-    CompareValue(op->value, other.as()->value);
-  }
-
   void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final {
     CompareValue(op->value, other.as()->value);
   }
diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc
index 67acec674..857206f8d 100644
--- a/src/pass/ir_functor.cc
+++ b/src/pass/ir_functor.cc
@@ -260,7 +260,6 @@ DEFINE_BINOP_VISIT_(AndNode);
 DEFINE_BINOP_VISIT_(OrNode);
 
 void ExprVisitor::VisitExpr_(const IntImmNode* op) {}
-void ExprVisitor::VisitExpr_(const UIntImmNode* op) {}
 void ExprVisitor::VisitExpr_(const FloatImmNode* op) {}
 void ExprVisitor::VisitExpr_(const StringImmNode* op) {}
 
@@ -640,7 +639,6 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
 }
 
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode)
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
 
diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc
index 7b760fa4a..5aba355b7 100644
--- a/src/pass/lift_attr_scope.cc
+++ b/src/pass/lift_attr_scope.cc
@@ -180,9 +180,6 @@ class AttrScopeLifter : public StmtMutator {
     if (const IntImmNode* op = a.as()) {
       return op->value == b.as()->value;
     }
-    if (const UIntImmNode* op = a.as()) {
-      return op->value == b.as()->value;
-    }
     return false;
   }
 
diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc
index ed8be8bb3..5684f4ef7 100644
--- a/src/pass/lower_intrin.cc
+++ b/src/pass/lower_intrin.cc
@@ -173,7 +173,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   PrimExpr VisitExpr_(const MaxNode* op) final {
     using namespace arith;
     PVar x, y;
-    PVar c;
+    PVar c;
     auto e = GetRef(op);
     if (max(floordiv(x, y), c).Match(e) &&
         c.Eval()->value >= 0 &&
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index a0b07c293..d509169df 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     const CommReducerNode *combiner = reduce_combiner_.back();
     size_t size = combiner->result.size();
 
-    const UIntImmNode *size_of_args = call->args[0].as();
+    const IntImmNode *size_of_args = call->args[0].as();
     CHECK(size_of_args) << call->args[0]->GetTypeKey();
     CHECK_EQ(size, size_of_args->value);
     Array inits = combiner->identity_element;
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index 8e7f1d86d..01a97b787 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -129,8 +129,8 @@ class BuiltinLower : public StmtExprMutator {
                        {cast(DataType::Int(32), device_type_),
                         cast(DataType::Int(32), device_id_),
                         cast(DataType::UInt(64), total_bytes),
-                        IntImmNode::make(DataType::Int(32), op->dtype.code()),
-                        IntImmNode::make(DataType::Int(32), op->dtype.bits())},
+                        IntImm(DataType::Int(32), op->dtype.code()),
+                        IntImm(DataType::Int(32), op->dtype.bits())},
                        CallNode::Extern),
         body);
 
diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc
index d5c73a2e8..5df36d0b2 100644
--- a/src/pass/make_api.cc
+++ b/src/pass/make_api.cc
@@ -69,8 +69,8 @@ LoweredFunc MakeAPI(Stmt body,
   // load i-th argument as type t
   auto f_arg_value = [&](DataType t, int i) {
     Array call_args{v_packed_args,
-                    IntImmNode::make(DataType::Int(32), i),
-                    IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
+                    IntImm(DataType::Int(32), i),
+                    IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
     // load 64 bit version
     DataType api_type = APIType(t);
     PrimExpr res = CallNode::make(
@@ -117,7 +117,7 @@ LoweredFunc MakeAPI(Stmt body,
     seq_init.emplace_back(LetStmtNode::make(
         tcode, LoadNode::make(
             DataType::Int(32), v_packed_arg_type_ids,
-            IntImmNode::make(DataType::Int(32), i), const_true(1)),
+            IntImm(DataType::Int(32), i), const_true(1)),
         nop));
     DataType t = v_arg.dtype();
     if (t.is_handle()) {
diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc
index 224a81c12..9fb19cc4b 100644
--- a/src/pass/rewrite_unsafe_select.cc
+++ b/src/pass/rewrite_unsafe_select.cc
@@ -96,7 +96,6 @@ class UnsafeExprDetector : public ExprFunctor {
     return false;
   }
   bool VisitExpr_(const VarNode* op) final { return false; }
-  bool VisitExpr_(const UIntImmNode* op) final { return false; }
   bool VisitExpr_(const IntImmNode* op) final { return false; }
   bool VisitExpr_(const FloatImmNode* op) final { return false; }
   bool VisitExpr_(const StringImmNode* op) final { return false; }
diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc
index bb57fe8c3..956f27c93 100644
--- a/src/pass/tensor_core.cc
+++ b/src/pass/tensor_core.cc
@@ -462,7 +462,7 @@ class BufferAnalyser : public StmtExprVisitor {
       strides = bi.strides;
     } else {
       for (size_t i = 1; i < bi.shape.size(); ++i) {
-        PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+        PrimExpr stride = IntImm(DataType::Int(32), 1);
         for (size_t j = bi.shape.size() - 1; j >= i; --j) {
           stride = MulNode::make(stride, bi.shape[j]);
         }
@@ -575,7 +575,7 @@ class BufferAnalyser : public StmtExprVisitor {
       strides = bi.strides;
     } else {
       for (size_t i = 1; i < bi.shape.size(); ++i) {
-        PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+        PrimExpr stride = IntImm(DataType::Int(32), 1);
         for (size_t j = bi.shape.size() - 1; j >= i; --j) {
           stride = MulNode::make(stride, bi.shape[j]);
         }
@@ -765,7 +765,7 @@ class ThreadIdxMutator : public StmtExprMutator {
     op = expr.as();
     if (op != nullptr) {
       if (op->name_hint == "threadIdx.x") {
-        PrimExpr zero = IntImmNode::make(DataType::Int(32), 0);
+        PrimExpr zero = IntImm(DataType::Int(32), 0);
         return zero;
       }
       if (op->name_hint == "threadIdx.y") {
@@ -934,7 +934,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       PrimExpr stride = strides[strides.size()-2];
 
       // thread index unification inside a warp
-      PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+      PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       PrimExpr mutated_value = thread_idx_mutator(op->value);
       PrimExpr src = CallNode::make(value->dtype,
@@ -984,7 +984,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       PrimExpr dst = it3->second;
 
       // thread index unification inside a warp
-      PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+      PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       dst = thread_idx_mutator(dst);
       dst = CallNode::make(DataType::Handle(),
@@ -1089,7 +1089,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
     Array strides;
     for (size_t i = 1; i < shape.size(); ++i) {
-      PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+      PrimExpr stride = IntImm(DataType::Int(32), 1);
       for (size_t j = shape.size() - 1; j >= i; --j) {
         stride = MulNode::make(stride, shape[j]);
       }
@@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
     }
     strides.push_back(make_const(DataType::Int(32), 1));
 
-    PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0);
+    PrimExpr elem_offset = IntImm(DataType::Int(32), 0);
     CHECK_EQ(call->args.size(), min_bound.size());
     for (size_t i = 0; i < min_bound.size(); i++) {
       elem_offset = AddNode::make(
diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc
index b2c50f7a8..26ad59189 100644
--- a/src/pass/unroll_loop.cc
+++ b/src/pass/unroll_loop.cc
@@ -159,14 +159,10 @@ class LoopUnroller : public StmtExprMutator {
     // constant folding.
     PrimExpr extent = ir::Simplify(op->extent);
     const IntImmNode *v1 = extent.as();
-    const UIntImmNode *v2 = extent.as();
     int value = -1;
     if (v1 != nullptr) {
       value = static_cast(v1->value);
     }
-    if (v2 != nullptr) {
-      value = static_cast(v2->value);
-    }
     return value;
   }
 
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 00c40b256..5ee4ce30c 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -88,7 +88,7 @@ Array GetShape(const Array& shape) {
     if (pval != nullptr) {
       CHECK_LE(pval[0], std::numeric_limits::max());
       CHECK_GE(pval[0], std::numeric_limits::min());
      res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval));
-      res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval));
+      res.push_back(IntImm(DataType::Int(32), *pval));
     } else if (val->IsInstance()) {
       res.push_back(val.as()->ToVar());
     } else {
@@ -395,7 +395,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> {
     // set inputs
     for (auto param : prim_func->params) {
       int state = param_states_[param];
-      cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state));
+      cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
       if (state & kNeedInputData) {
         for (auto t : param_data_[param]) {
           cache_node->inputs.push_back(t);
@@ -528,7 +528,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> {
     auto ret_type = call_node->checked_type();
     Array out_ndims;
     if (const auto* ttype = ret_type.as()) {
-      out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
+      out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
     } else {
       auto rtype = ret_type.as();
       // TODO(@icemelon): Allow recursive tuple
@@ -536,7 +536,7 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> {
       for (size_t i = 0; i < rtype->fields.size(); ++i) {
         auto ttype = rtype->fields[i].as();
         CHECK(ttype);
-        out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
+        out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
       }
     }
     // Call shape function
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index f66cce6b7..9966d9cc5 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -56,7 +56,7 @@ TensorType ConstantNode::tensor_type() const {
     CHECK_LE(data->shape[i], std::numeric_limits::max());
     CHECK_GE(data->shape[i], std::numeric_limits::min());
     shape.push_back(
-        tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i]));
+        tvm::IntImm(DataType::Int(32), data->shape[i]));
   }
 
   return TensorTypeNode::make(shape, dtype);
diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc
index 25650c776..400a6bea2 100644
--- a/src/relay/ir/pretty_printer.cc
+++ b/src/relay/ir/pretty_printer.cc
@@ -857,10 +857,6 @@ class PrettyPrinter :
     return PrintConstScalar(op->dtype, &(op->value));
   }
 
-  Doc VisitAttr_(const ir::UIntImmNode* op) final {
-    return PrintConstScalar(op->dtype, &(op->value));
-  }
-
   Doc VisitAttr_(const ir::FloatImmNode* op) final {
     return PrintConstScalar(op->dtype, &(op->value));
   }
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 4d3a4b958..b5383cd33 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -852,7 +852,7 @@ bool ArgWhereRel(const Array& types,
   const auto& input_rank = input_shape.size();
   std::vector result_shape;
   result_shape.push_back(Any::make());
-  result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank));
+  result_shape.push_back(IntImm(DataType::Int(32), input_rank));
   reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
   return true;
 }
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index 594669343..30a9a5c80 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -41,7 +41,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
   }
 
   bool Assert(const IndexExpr& cond) final {
-    if (const uint64_t* pdiff = as_const_uint(cond)) {
+    if (const int64_t* pdiff = as_const_int(cond)) {
       return pdiff[0];
     }
     return true;
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h
index 378a5e372..2e332413c 100644
--- a/src/relay/qnn/util.h
+++ b/src/relay/qnn/util.h
@@ -47,14 +47,10 @@ static inline Array get_shape(const Type& type) {
 static inline const int32_t GetQmin(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";
-  if (dtype.is_int()) {
+  if (dtype.is_int() || dtype.is_uint()) {
     auto* min_value = as_const_int(tvm::min_value(dtype));
     CHECK(min_value != nullptr);
     return static_cast(min_value[0]);
-  } else if (dtype.is_uint()) {
-    auto* min_value = as_const_uint(tvm::min_value(dtype));
-    CHECK(min_value != nullptr);
-    return static_cast(min_value[0]);
   } else {
     LOG(FATAL) << "Type not supported " << dtype;
     return -1;  // To hide the warning
@@ -64,14 +60,10 @@ static inline const int32_t GetQmax(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";
-  if (dtype.is_int()) {
+  if (dtype.is_int() || dtype.is_uint()) {
     auto* max_value = as_const_int(tvm::max_value(dtype));
     CHECK(max_value != nullptr);
     return static_cast(max_value[0]);
-  } else if (dtype.is_uint()) {
-    auto* max_value = as_const_uint(tvm::max_value(dtype));
-    CHECK(max_value != nullptr);
-    return static_cast(max_value[0]);
   } else {
     LOG(FATAL) << "Type not supported " << dtype;
     return -1;  // To hide the warning
diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc
index 5392eaeac..193f2f206 100644
--- a/tests/cpp/pattern_match_test.cc
+++ b/tests/cpp/pattern_match_test.cc
@@ -127,10 +127,10 @@ TEST(Pattern, Basic) {
   }
 }
 
-TEST(Pattern, Integer) {
+TEST(Pattern, IntImm) {
   using namespace tvm;
   tvm::Var tx, ty;
-  arith::PVar c;
+  arith::PVar c;
   arith::PVar v;
   {
     // We can match integer and Var, both of which are
diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py
index 45ecf9539..5a10618fb 100644
--- a/tests/python/unittest/test_codegen_device.py
+++ b/tests/python/unittest/test_codegen_device.py
@@ -18,6 +18,32 @@ import tvm
 from tvm.contrib import util
 import numpy as np
 
+def test_large_uint_imm():
+    value = (1 << 63) + 123
+    other = tvm.const(3, "uint64")
+    n = 12
+    num_thread = 2
+
+    A = tvm.compute((n,), lambda *i: tvm.const(value, "uint64") + other, name='A')
+    s = tvm.create_schedule(A.op)
+    xo, xi = s[A].split(A.op.axis[0], factor=num_thread)
+    s[A].bind(xi, tvm.thread_axis("threadIdx.x"))
+    s[A].bind(xo, tvm.thread_axis("blockIdx.x"))
+
+    def check_target(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        f = tvm.build(s, [A], device)
+        # launch the kernel.
+        a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx)
+        f(a)
+        assert a.asnumpy()[0] == value + 3
+
+    check_target("cuda")
+    check_target("vulkan")
+
+
 def test_add_pipeline():
     n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
@@ -112,4 +138,5 @@ def test_add_pipeline():
 
 if __name__ == "__main__":
+    test_large_uint_imm()
     test_add_pipeline()
diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py
index 0e595cd79..4920206ee 100644
--- a/tests/python/unittest/test_codegen_llvm.py
+++ b/tests/python/unittest/test_codegen_llvm.py
@@ -88,6 +88,25 @@ def test_llvm_lookup_intrin():
     fcode = tvm.build(func, None, "llvm")
 
 
+def test_llvm_large_uintimm():
+    value = (1 << 63) + 123
+    other = tvm.const(3, "uint64")
+    A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A')
+    s = tvm.create_schedule(A.op)
+
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        f = tvm.build(s, [A], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.empty((), dtype=A.dtype, ctx=ctx)
+        f(a)
+        assert a.asnumpy() == value + 3
+
+    check_llvm()
+
+
 def test_llvm_add_pipeline():
     nn = 1024
     n = tvm.convert(nn)
@@ -645,6 +664,7 @@ def test_llvm_shuffle():
         tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
 
 if __name__ == "__main__":
+    test_llvm_large_uintimm()
     test_llvm_import()
     test_alignment()
     test_rank_zero()
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index c3c40cf74..5f1facb2b 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
     def tvm_val_2_py_val(val):
         val = tvm.ir_pass.Substitute(val, var_dict)
         val = tvm.ir_pass.Simplify(val)
-        assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm))
+        assert isinstance(val, (tvm.expr.IntImm,))
         return val.value
 
     ctx = tvm.context(target, 0)
diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py
index fe329494e..c4187858a 100644
--- a/tests/python/unittest/test_lang_constructor.py
+++ b/tests/python/unittest/test_lang_constructor.py
@@ -38,16 +38,11 @@ def test_expr_constructor():
     assert x.value == 2
     assert x.dtype == "int64"
 
-    x = tvm.expr.UIntImm("uint16", 2)
-    assert isinstance(x, tvm.expr.UIntImm)
-    assert x.value == 2
-    assert x.dtype == "uint16"
-
     x = tvm.expr.StringImm("xyza")
     assert isinstance(x, tvm.expr.StringImm)
     assert x.value == "xyza"
 
-    x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1))
+    x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1))
     assert isinstance(x, tvm.expr.Cast)
     assert x.dtype == "float32"
     assert x.value.value == 1
diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py
index c57f4a110..ac2ee6d88 100644
--- a/tests/python/unittest/test_lang_operator.py
+++ b/tests/python/unittest/test_lang_operator.py
@@ -29,7 +29,7 @@ def test_const_fold():
     def check(f, *args):
         x = f(*[tvm.const(x, "int32") for x in args])
         y = f(*args)
-        if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
+        if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y):
             raise ValueError("check error: %s vs %s " % (x, y))
 
     tmod = tvm.truncmod
diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h
index 43ac3a29c..e6de76f20 100644
--- a/topi/include/topi/detail/constant_utils.h
+++ b/topi/include/topi/detail/constant_utils.h
@@ -43,8 +43,7 @@ using namespace tvm;
 */
 inline bool IsConstInt(PrimExpr expr) {
   return
-    expr->IsInstance() ||
-    expr->IsInstance();
+    expr->IsInstance();
 }
 
 /*!
@@ -56,11 +55,8 @@ inline bool IsConstInt(PrimExpr expr) {
 * \return The integer value.
 */
 inline int64_t GetConstInt(PrimExpr expr) {
-  if (expr->IsInstance()) {
-    return expr.as()->value;
-  }
-  if (expr->IsInstance()) {
-    return expr.as()->value;
+  if (expr->IsInstance()) {
+    return expr.as()->value;
   }
   LOG(ERROR) << "expr must be a constant integer";
   return -1;
diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py
index 8f32a297d..02d082b8b 100644
--- a/topi/python/topi/util.py
+++ b/topi/python/topi/util.py
@@ -92,9 +92,9 @@ def get_const_int(expr):
     """
     if isinstance(expr, Integral):
         return expr
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
         expr = tvm.ir_pass.Simplify(expr)
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
         raise ValueError("Expect value to be constant int")
     return int(expr.value)
 
@@ -136,9 +136,9 @@ def equal_const_int(expr, value):
     """
     if isinstance(expr, Integral):
        return expr == value
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
        expr = tvm.ir_pass.Simplify(expr)
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
        return False
     return expr.value == value
 
@@ -160,9 +160,9 @@ def get_const_tuple(in_tuple):
     for elem in in_tuple:
         if isinstance(elem, tvm.expr.Var):
             ret.append(elem)
-        elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
+        elif not isinstance(elem, (tvm.expr.IntImm, int)):
             elem = tvm.ir_pass.Simplify(elem)
-            if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+            if not isinstance(elem, tvm.expr.IntImm):
                 ret.append(elem)
         else:
             ret.append(get_const_int(elem))
-- 
2.34.1
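For reference, a minimal downstream sketch of what the topi and test simplifications above rely on once this patch is applied: unsigned constants round-trip through IntImm, so a single isinstance check replaces the old (IntImm, UIntImm) pairs. Module paths follow the 0.6-era Python API and may differ in later releases.

    import tvm

    c = tvm.const(2, "uint32")
    assert isinstance(c, tvm.expr.IntImm)   # no tvm.expr.UIntImm anymore
    assert c.dtype == "uint32" and c.value == 2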