[REFACTOR][IR] Unify IntImm and UIntImm (#4706)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 15 Jan 2020 03:03:26 +0000 (19:03 -0800)
committerGitHub <noreply@github.com>
Wed, 15 Jan 2020 03:03:26 +0000 (19:03 -0800)
* [REFACTOR][IR] Unify IntImm and UIntImm

This PR unifies UIntImm and IntImm to simplify the codebase.
Unsigned integer constants will also be stored as IntImm.

For uint constant that does not fit into int64(rare case), we introduced
an intrinsic tvm_big_uint_imm to construct such intgers by its
lower and higher 32bits.

* [REFACTOR][IR] Remove UIntImm to use IntImm

* rename big->large

74 files changed:
include/tvm/attrs.h
include/tvm/expr.h
include/tvm/expr_operator.h
include/tvm/ir.h
include/tvm/ir/expr.h
include/tvm/ir_functor_ext.h
python/tvm/api.py
python/tvm/autotvm/task/task.py
python/tvm/autotvm/util.py
python/tvm/expr.py
python/tvm/hybrid/calls.py
python/tvm/hybrid/parser.py
python/tvm/hybrid/util.py
python/tvm/relay/frontend/tensorflow.py
src/api/api_ir.cc
src/api/api_lang.cc
src/arithmetic/analyzer.cc
src/arithmetic/canonical_simplify.cc
src/arithmetic/const_fold.h
src/arithmetic/const_int_bound.cc
src/arithmetic/int_set.cc
src/arithmetic/modular_set.cc
src/arithmetic/pattern_match.h
src/arithmetic/rewrite_simplify.cc
src/autotvm/touch_extractor.cc
src/codegen/codegen_c.cc
src/codegen/codegen_c.h
src/codegen/codegen_opengl.cc
src/codegen/codegen_opengl.h
src/codegen/llvm/codegen_arm.cc
src/codegen/llvm/codegen_llvm.cc
src/codegen/llvm/codegen_llvm.h
src/codegen/llvm/codegen_x86_64.cc
src/codegen/llvm/intrin_rule_llvm.h
src/codegen/spirv/codegen_spirv.cc
src/codegen/spirv/codegen_spirv.h
src/codegen/spirv/intrin_rule_spirv.cc
src/codegen/spirv/ir_builder.cc
src/codegen/spirv/ir_builder.h
src/codegen/stackvm/codegen_stackvm.cc
src/codegen/stackvm/codegen_stackvm.h
src/contrib/hybrid/codegen_hybrid.cc
src/contrib/hybrid/codegen_hybrid.h
src/ir/expr.cc
src/lang/attr_functor.h
src/lang/attrs.cc
src/lang/expr.cc
src/lang/expr_operator.cc
src/lang/ir.cc
src/pass/arg_binder.cc
src/pass/ir_deep_compare.cc
src/pass/ir_functor.cc
src/pass/lift_attr_scope.cc
src/pass/lower_intrin.cc
src/pass/lower_thread_allreduce.cc
src/pass/lower_tvm_builtin.cc
src/pass/make_api.cc
src/pass/rewrite_unsafe_select.cc
src/pass/tensor_core.cc
src/pass/unroll_loop.cc
src/relay/backend/compile_engine.cc
src/relay/ir/expr.cc
src/relay/ir/pretty_printer.cc
src/relay/op/tensor/transform.cc
src/relay/pass/type_solver.cc
src/relay/qnn/util.h
tests/cpp/pattern_match_test.cc
tests/python/unittest/test_codegen_device.py
tests/python/unittest/test_codegen_llvm.py
tests/python/unittest/test_hybrid_script.py
tests/python/unittest/test_lang_constructor.py
tests/python/unittest/test_lang_operator.py
topi/include/topi/detail/constant_utils.h
topi/python/topi/util.py

index ab9a711..9d9f98e 100644 (file)
@@ -490,8 +490,6 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
     CHECK(expr.defined());
     if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<T>(op->value);
-    } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
-      *ptr = static_cast<T>(op->value);
     } else {
       LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
     }
@@ -523,8 +521,6 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
       *ptr = static_cast<double>(op->value);
     } else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<double>(op->value);
-    } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
-      *ptr = static_cast<double>(op->value);
     } else {
       LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
     }
index faae303..62806c6 100644 (file)
@@ -115,57 +115,39 @@ class Var : public PrimExpr {
   using ContainerType = VarNode;
 };
 
-class Integer;
-/*! \brief ExprNode: constant integer. */
-class IntImmNode : public PrimExprNode {
- public:
-  /*! \brief the Internal value. */
-  int64_t value;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("dtype", &dtype);
-    v->Visit("value", &value);
-  }
-
-  TVM_DLL static Integer make(DataType t, int64_t value);
-
-  static constexpr const char* _type_key = "IntImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
-};
-
 /*!
- * \brief Container of constant integer (IntImm).
+ * \brief Container of constant int that adds more constructors.
  *
  * This is used to store and automate type check
  * attributes that must be constant integer.
+ *
+ * \sa IntImm
  */
-class Integer : public PrimExpr {
+class Integer : public IntImm {
  public:
-  Integer() : PrimExpr() {}
+  Integer() {}
   /*!
    * \brief constructor from node.
    */
-  explicit Integer(ObjectPtr<Object> node) : PrimExpr(node) {}
+  explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
   /*!
    * \brief Construct integer from int value.
    */
-  Integer(int value) : PrimExpr(value) {}  // NOLINT(*)
+  Integer(int value) : IntImm(DataType::Int(32), value) {}  // NOLINT(*)
+  /*!
+   * \brief Construct integer from int imm.
+   * \param other The other value.
+   */
+  Integer(IntImm other) : IntImm(std::move(other)) {}  // NOLINT(*)
   /*!
    * \brief Assign an expression to integer.
    * \param other another expression.
    */
-  Integer& operator=(const Integer& other) {
-    data_ = other.data_;
+  Integer& operator=(const IntImm& other) {
+    data_ = ObjectRef::GetDataPtr<Object>(other);
     return *this;
   }
   /*!
-   * \brief Get pointer to the internal value.
-   * \return the content of the integer.
-   */
-  const IntImmNode* operator->() const {
-    return static_cast<const IntImmNode*>(get());
-  }
-  /*!
    * \brief convert to int64_t
    */
   operator int64_t() const {
@@ -173,8 +155,6 @@ class Integer : public PrimExpr {
         << " Trying to reference a null Integer";
     return (*this)->value;
   }
-  /*! \brief type indicate the container type */
-  using ContainerType = IntImmNode;
 };
 
 /*! \brief range over one dimension */
index 2d8f378..ff3b340 100644 (file)
@@ -30,6 +30,7 @@
 
 #include <algorithm>
 #include <type_traits>
+#include <limits>
 #include "expr.h"
 #include "ir.h"
 
@@ -83,21 +84,6 @@ inline const int64_t* as_const_int(const PrimExpr& x) {
 }
 
 /*!
- * \brief Get x as constant uint expression.
- * \param x The expression
- * \return the address to the int expression,
- *         return nullptr, if x is not UIntImm.
- */
-inline const uint64_t* as_const_uint(const PrimExpr& x) {
-  if (!x.defined()) return nullptr;
-  if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
-    return &(op->value);
-  } else {
-    return nullptr;
-  }
-}
-
-/*!
  * \brief Check whether x is a constant integer expression.
  * \param x The input argument
  * \param value the value to be compared against.
@@ -597,6 +583,15 @@ TVM_DLL PrimExpr nearbyint(PrimExpr x);
  */
 TVM_DLL PrimExpr trunc(PrimExpr x);
 
+/*!
+ * \brief Construct a large uint constant by its low 32 bits and high 32bits.
+ * \param dtype The final data type.
+ * \param low The lower 32 bits.
+ * \param high The higher 32 bits.
+ * \return The constructed expression.
+ */
+TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                             \
   inline PrimExpr OpName(PrimExpr x) {                                     \
@@ -617,11 +612,11 @@ TVM_DECLARE_INTRIN_UNARY(atan);
 
 // Implementation details after this
 inline bool is_const(const PrimExpr& x) {
-  if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
+  if (x.as<ir::IntImmNode>()) {
     return true;
   } else if (const auto* op = x.as<ir::BroadcastNode>()) {
     const PrimExpr& val = op->value;
-    if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
+    if (val.as<ir::IntImmNode>()) {
       return true;
     }
   }
@@ -631,8 +626,6 @@ inline bool is_const(const PrimExpr& x) {
 inline bool is_positive_const(const PrimExpr& a) {
   if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
     return op->value > 0;
-  } else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
-    return op->value > 0;
   } else {
     return false;
   }
@@ -649,14 +642,10 @@ inline bool is_negative_const(const PrimExpr& a) {
 inline bool is_const_int(const PrimExpr& x, int64_t value) {
   if (const auto* op = x.as<ir::IntImmNode>()) {
     return op->value == value;
-  } else if (const auto* op = x.as<ir::UIntImmNode>()) {
-    return op->value == static_cast<uint64_t>(value);
   } else if (const auto* op = x.as<ir::BroadcastNode>()) {
     const PrimExpr& val = op->value;
     if (const auto* opv = val.as<ir::IntImmNode>()) {
       return opv->value == value;
-    } else if (const auto* opv = val.as<ir::UIntImmNode>()) {
-      return opv->value == static_cast<uint64_t>(value);
     }
   }
   return false;
@@ -675,15 +664,27 @@ inline bool is_no_op(const Stmt& stmt) {
 
 template<typename ValueType>
 inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
-  if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
-  if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
+  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
+  if (t.is_uint()) {
+    // Use IntImm if it is a small integer
+    uint64_t uval = static_cast<uint64_t>(value);
+    if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
+      return IntImm(t, static_cast<int64_t>(value));
+    } else {
+      uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
+      uint64_t low = uval & mask;
+      uint64_t high = uval >> 32U;
+      return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
+    }
+  }
   if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
   // For now, we store const scalar values of custom datatypes within doubles; later, during the
   // datatypes lowering pass, we will lower the value to its true representation in the format
   // specified by the datatype.
   // TODO(gus) when do we need to start worrying about doubles not being precise enough?
-  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
+  if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
     return ir::FloatImmNode::make(t, static_cast<double>(value));
+  }
   LOG(FATAL) << "cannot make const for type " << t;
   return PrimExpr();
 }
index 8403948..9c14a31 100644 (file)
@@ -39,23 +39,6 @@ namespace ir {
 using IntImmNode = tvm::IntImmNode;
 using VarNode = tvm::VarNode;
 
-/*! \brief constant unsigned integer. */
-class UIntImmNode : public PrimExprNode {
- public:
-  /*! \brief The constant value content. */
-  uint64_t value;
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("dtype", &dtype);
-    v->Visit("value", &value);
-  }
-
-  TVM_DLL static PrimExpr make(DataType t, uint64_t value);
-
-  static constexpr const char* _type_key = "UIntImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode);
-};
-
 /*! \brief Floating point constants. */
 class FloatImmNode : public PrimExprNode {
  public:
@@ -1425,6 +1408,16 @@ namespace intrinsic {
 /*!
  * \brief See pesudo code
  *
+ *  Construct a big uint that may not be representable by int64
+ *
+ *  Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) {
+ *    return (v1 << 32) | v0;
+ *  }
+ */
+constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm";
+/*!
+ * \brief See pesudo code
+ *
  *  Handle tvm_address_of(Load *op) {
  *     return &op->buffer_var[index];
  *  }
index 12b34dd..12e505e 100644 (file)
@@ -132,6 +132,56 @@ class PrimExpr : public BaseExpr {
 };
 
 /*!
+ * \brief Constant integer literals in the program.
+ * \sa IntImm
+ */
+class IntImmNode : public PrimExprNode {
+ public:
+  /*! \brief the Internal value. */
+  int64_t value;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &dtype);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "IntImm";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference class to IntImmNode.
+ *
+ * \sa IntImmNode
+ */
+class IntImm : public PrimExpr {
+ public:
+  /*!
+   * \brief Constructor
+   */
+  IntImm() {}
+  /*!
+   * \brief constructor from node.
+   */
+  explicit IntImm(ObjectPtr<Object> node) : PrimExpr(node) {}
+  /*!
+   * \brief Constructor.
+   * \param dtype The data type of the value.
+   * \param value The internal value.
+   */
+  TVM_DLL IntImm(DataType dtype, int64_t value);
+  /*!
+   * \brief Get pointer to the internal value.
+   * \return the content of the integer.
+   */
+  const IntImmNode* operator->() const {
+    return static_cast<const IntImmNode*>(get());
+  }
+  /*! \brief type indicate the container type */
+  using ContainerType = IntImmNode;
+};
+
+/*!
  * \brief Base node of all non-primitive expressions.
  *
  * RelayExpr supports tensor types, functions and ADT as
index 7d57564..37a1fe4 100644 (file)
@@ -161,7 +161,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
   virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Object* op, Args ...) {
@@ -203,7 +202,6 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
     IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
     IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
-    IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode);
     IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
     IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
     return vtable;
@@ -327,7 +325,6 @@ class TVM_DLL ExprVisitor :
   void VisitExpr_(const BroadcastNode* op) override;
   void VisitExpr_(const ShuffleNode* op) override;
   void VisitExpr_(const IntImmNode* op) override;
-  void VisitExpr_(const UIntImmNode* op) override;
   void VisitExpr_(const FloatImmNode* op) override;
   void VisitExpr_(const StringImmNode* op) override;
 };
@@ -372,7 +369,6 @@ class TVM_DLL ExprMutator :
   PrimExpr VisitExpr_(const BroadcastNode* op) override;
   PrimExpr VisitExpr_(const ShuffleNode* op) override;
   PrimExpr VisitExpr_(const IntImmNode* op) override;
-  PrimExpr VisitExpr_(const UIntImmNode* op) override;
   PrimExpr VisitExpr_(const FloatImmNode* op) override;
   PrimExpr VisitExpr_(const StringImmNode* op) override;
 };
index 7395d35..4bfe794 100644 (file)
@@ -92,6 +92,9 @@ def const(value, dtype=None):
     """
     if dtype is None:
         dtype = _scalar_type_inference(value)
+    if dtype == "uint64" and value >= (1 << 63):
+        return _api_internal._LargeUIntImm(
+            dtype, value & ((1 << 32) - 1), value >> 32)
     return _api_internal._const(value, dtype)
 
 
index 7f36914..5067277 100644 (file)
@@ -221,7 +221,7 @@ def args_to_workload(x, topi_compute_func=None):
         workload = tuple([args_to_workload(a) for a in x])
     elif isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
         workload = x
-    elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
+    elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
         workload = x.value
     elif x is None:
         workload = 0
@@ -344,7 +344,7 @@ def compute_flop(sch):
             if len(source) != 1:
                 raise FlopCalculationError("Found multiple output in the source of reduce op")
             return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
-        if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
+        if isinstance(exp, (expr.FloatImm, expr.IntImm)):
             return 0
         if isinstance(exp, expr.Cast):
             return _count_flop(exp.value)
index 3026914..54001d3 100644 (file)
@@ -155,9 +155,9 @@ def get_const_int(exp):
     """
     if isinstance(exp, int):
         return exp
-    if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
+    if not isinstance(exp, (expr.IntImm,)):
         exp = ir_pass.Simplify(exp)
-    if not isinstance(exp, (expr.IntImm, expr.UIntImm)):
+    if not isinstance(exp, (expr.IntImm,)):
         raise ValueError("Expect value to be constant int")
     return exp.value
 
@@ -179,9 +179,9 @@ def get_const_tuple(in_tuple):
     for elem in in_tuple:
         if isinstance(elem, expr.Var):
             ret.append(elem)
-        elif not isinstance(elem, (expr.IntImm, expr.UIntImm, int)):
+        elif not isinstance(elem, (expr.IntImm, int)):
             elem = ir_pass.Simplify(elem)
-            if not isinstance(elem, (expr.IntImm, expr.UIntImm)):
+            if not isinstance(elem, (expr.IntImm)):
                 ret.append(elem)
         else:
             ret.append(get_const_int(elem))
index 71c0aec..2fd7b78 100644 (file)
@@ -342,23 +342,6 @@ class IntImm(ConstExpr):
 
 
 @register_object
-class UIntImm(ConstExpr):
-    """UInt constant.
-
-    Parameters
-    ----------
-    dtype : str
-        The data type
-
-    value : int
-        The constant value.
-    """
-    def __init__(self, dtype, value):
-        self.__init_handle_by_constructor__(
-            _make.UIntImm, dtype, value)
-
-
-@register_object
 class StringImm(ConstExpr):
     """String constant.
 
index 1d5612e..7038f61 100644 (file)
@@ -156,6 +156,6 @@ def max_num_threads(func_id, args):
     if args.__len__() == 0:
         res = _tgt.current_target().max_num_threads
     else:
-        _internal_assert(isinstance(args[0], _expr.UIntImm), "In tvm bool should be uint")
+        _internal_assert(isinstance(args[0], _expr.IntImm), "In tvm bool should be uint")
         res = _tgt.current_target(args[0].value).max_num_threads
     return _api.convert(res)
index 06bcbca..57d6363 100644 (file)
@@ -386,7 +386,7 @@ class HybridParser(ast.NodeVisitor):
                 if isinstance(i, numbers.Integral):
                     arr = arr[i]
                 else:
-                    _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
+                    _internal_assert(isinstance(i, (_expr.IntImm,)), \
                                      "All indices are supposed to be constants")
                     arr = arr[i.value]
             return arr
@@ -413,7 +413,7 @@ class HybridParser(ast.NodeVisitor):
         cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
 
         # Return no IfThenElse if proven
-        if isinstance(cond, _expr.UIntImm):
+        if isinstance(cond, _expr.IntImm):
             if cond.value:
                 return visit_list_to_block(self.visit, node.body)
             if node.orelse:
index 0dd1fa1..a08a380 100644 (file)
@@ -33,7 +33,7 @@ from ..container import Array
 #pylint: disable=invalid-name
 np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
 tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
-halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
+halide_imm_types = (_expr.IntImm, _expr.FloatImm)
 
 
 def _internal_assert(cond, err):
index 7e22d72..e7f4682 100644 (file)
@@ -931,7 +931,7 @@ def _shape():
     def _impl(inputs, attr, params):
         is_symbolic_shape = False
         for axis in attr['_input_shapes'][inputs[0]]:
-            if not isinstance(axis, (int, tvm.expr.IntImm, tvm.expr.UIntImm)):
+            if not isinstance(axis, (int, tvm.expr.IntImm)):
                 is_symbolic_shape = True
                 break
 
index ca4823b..30ca515 100644 (file)
@@ -130,8 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
 REGISTER_MAKE(Reduce);
 REGISTER_MAKE(AttrStmt);
 
-REGISTER_MAKE(IntImm);
-REGISTER_MAKE(UIntImm);
 REGISTER_MAKE(FloatImm);
 REGISTER_MAKE(StringImm);
 
index 6a8bc58..fa7b59d 100644 (file)
@@ -53,6 +53,9 @@ TVM_REGISTER_GLOBAL("_const")
     }
   });
 
+TVM_REGISTER_GLOBAL("_LargeUIntImm")
+.set_body_typed(LargeUIntImm);
+
 TVM_REGISTER_GLOBAL("_str")
 .set_body_typed(ir::StringImmNode::make);
 
index 7a3baa6..e03e5e2 100644 (file)
@@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
 }
 
 bool Analyzer::CanProve(const PrimExpr& expr) {
-  if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
+  if (const auto* ptr = expr.as<IntImmNode>()) {
     return ptr->value != 0;
   }
   auto res = this->rewrite_simplify(expr);
-  if (const auto* ptr = res.as<ir::UIntImmNode>()) {
+  if (const auto* ptr = res.as<IntImmNode>()) {
     return ptr->value != 0;
   }
   res = this->canonical_simplify(expr);
-  if (const auto* ptr = res.as<ir::UIntImmNode>()) {
+  if (const auto* ptr = res.as<IntImmNode>()) {
     return ptr->value != 0;
   }
   return false;
index 5f721d7..90c6e48 100644 (file)
@@ -737,7 +737,7 @@ VisitExpr_(const DivNode* op) {
   // const folding
   PrimExpr const_res = TryConstFold<DivNode>(a, b);
   if (const_res.defined()) return const_res;
-  PVar<Integer> c1;
+  PVar<IntImm> c1;
   // x / c1
   if (c1.Match(b) && c1.Eval()->value > 0) {
     int64_t cval = c1.Eval()->value;
@@ -797,7 +797,7 @@ VisitExpr_(const FloorDivNode* op) {
   // const folding
   PrimExpr const_res = TryConstFold<FloorDivNode>(a, b);
   if (const_res.defined()) return const_res;
-  PVar<Integer> c1;
+  PVar<IntImm> c1;
   // x / c1
   if (c1.Match(b) && c1.Eval()->value > 0) {
     int64_t cval = c1.Eval()->value;
@@ -905,7 +905,7 @@ VisitExpr_(const ModNode* op) {
   PrimExpr const_res = TryConstFold<ModNode>(a, b);
   if (const_res.defined()) return const_res;
 
-  PVar<Integer> c1;
+  PVar<IntImm> c1;
   // x % c1
   if (c1.Match(b) && c1.Eval()->value > 0) {
     int64_t cval = c1.Eval()->value;
@@ -975,7 +975,7 @@ VisitExpr_(const FloorModNode* op) {
   PrimExpr const_res = TryConstFold<FloorModNode>(a, b);
   if (const_res.defined()) return const_res;
 
-  PVar<Integer> c1;
+  PVar<IntImm> c1;
   // x % c1
   if (c1.Match(b) && c1.Eval()->value > 0) {
     int64_t cval = c1.Eval()->value;
index 55c156d..3b803ec 100644 (file)
@@ -76,8 +76,6 @@ inline bool IsIndexType(const DataType& type) {
 
 
 #define TVM_ARITH_CONST_PROPAGATION(BODY)                               \
-  using ir::IntImmNode;                                                 \
-  using ir::UIntImmNode;                                                \
   using ir::FloatImmNode;                                               \
   const IntImmNode* pa = a.as<IntImmNode>();                            \
   const IntImmNode* pb = b.as<IntImmNode>();                            \
@@ -87,8 +85,6 @@ inline bool IsIndexType(const DataType& type) {
 
 
 #define TVM_INDEX_CONST_PROPAGATION(BODY)                               \
-  using ir::IntImmNode;                                                 \
-  using ir::UIntImmNode;                                                \
   const IntImmNode* pa = a.as<IntImmNode>();                            \
   const IntImmNode* pb = b.as<IntImmNode>();                            \
   const DataType& ta = a.dtype();                                       \
@@ -103,7 +99,7 @@ template<>
 inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value);
+      if (pa && pb) return IntImm(rtype, pa->value + pb->value);
       if (pa && pa->value == 0) return b;
       if (pb && pb->value == 0) return a;
       if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value);
@@ -117,7 +113,7 @@ template<>
 inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value);
+      if (pa && pb) return IntImm(rtype, pa->value - pb->value);
       if (pb && pb->value == 0) return a;
       if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
       if (fb && fb->value == 0) return a;
@@ -129,7 +125,7 @@ template<>
 inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value);
+      if (pa && pb) return IntImm(rtype, pa->value * pb->value);
       if (pa) {
         if (pa->value == 1) return b;
         if (pa->value == 0) return a;
@@ -159,7 +155,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
         // due to division and mod can have different modes
         // NOTE: this will assumes truc div.
         CHECK_NE(pb->value, 0) << "Divide by zero";
-        return IntImmNode::make(rtype, pa->value / pb->value);
+        return IntImm(rtype, pa->value / pb->value);
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -185,7 +181,7 @@ inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
-        return IntImmNode::make(rtype, pa->value % pb->value);
+        return IntImm(rtype, pa->value % pb->value);
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -204,7 +200,7 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
       const DataType& rtype = a.dtype();
       if (pa && pb) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
-        return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value));
+        return IntImm(rtype, arith::floordiv(pa->value, pb->value));
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -230,7 +226,7 @@ inline PrimExpr TryConstFold<ir::FloorModNode>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
-        return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value));
+        return IntImm(rtype, arith::floormod(pa->value, pb->value));
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -247,7 +243,7 @@ template<>
 inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value));
+      if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
       if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
@@ -258,7 +254,7 @@ template<>
 inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value));
+      if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
       if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
@@ -268,8 +264,8 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
 template<>
 inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
-      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
+      if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
+      if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
     });
   return PrimExpr();
 }
@@ -277,8 +273,8 @@ inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
 template<>
 inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
-      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
+      if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
+      if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
     });
   return PrimExpr();
 }
@@ -286,8 +282,8 @@ inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
 template<>
 inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
-      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
+      if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
+      if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
     });
   return PrimExpr();
 }
@@ -295,8 +291,8 @@ inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
 template<>
 inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
-      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
+      if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
+      if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
     });
   return PrimExpr();
 }
@@ -304,8 +300,8 @@ inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
 template<>
 inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
-      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
+      if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
+      if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
     });
   return PrimExpr();
 }
@@ -313,17 +309,16 @@ inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
 template<>
 inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
-      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
+      if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
+      if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
     });
   return PrimExpr();
 }
 
 template<>
 inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
-  using ir::UIntImmNode;
-  const UIntImmNode* pa = a.as<UIntImmNode>();
-  const UIntImmNode* pb = b.as<UIntImmNode>();
+  const IntImmNode* pa = a.as<IntImmNode>();
+  const IntImmNode* pb = b.as<IntImmNode>();
   if (pa && pa->value) return b;
   if (pa && !pa->value) return a;
   if (pb && pb->value) return a;
@@ -333,9 +328,8 @@ inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
 
 template<>
 inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
-  using ir::UIntImmNode;
-  const UIntImmNode* pa = a.as<UIntImmNode>();
-  const UIntImmNode* pb = b.as<UIntImmNode>();
+  const IntImmNode* pa = a.as<IntImmNode>();
+  const IntImmNode* pb = b.as<IntImmNode>();
   if (pa && pa->value) return a;
   if (pa && !pa->value) return b;
   if (pb && pb->value) return b;
@@ -345,10 +339,9 @@ inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
 
 template<>
 inline PrimExpr TryConstFold<ir::NotNode>(PrimExpr a) {
-  using ir::UIntImmNode;
-  const UIntImmNode* pa = a.as<UIntImmNode>();
+  const IntImmNode* pa = a.as<IntImmNode>();
   if (pa) {
-    return UIntImmNode::make(DataType::UInt(1), !(pa->value));
+    return IntImm(DataType::UInt(1), !(pa->value));
   }
   return PrimExpr();
 }
index a041e40..25d88d3 100644 (file)
@@ -150,14 +150,6 @@ class ConstIntBoundAnalyzer::Impl :
     return MakeBound(op->value, op->value);
   }
 
-  Entry VisitExpr_(const UIntImmNode* op) final {
-    if (op->value <= static_cast<uint64_t>(kPosInf)) {
-      return MakeBound(op->value, op->value);
-    } else {
-      return Everything(op->dtype);
-    }
-  }
-
   Entry VisitExpr_(const AddNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
@@ -496,7 +488,7 @@ class ConstIntBoundAnalyzer::Impl :
    */
   static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
     PVar<PrimExpr> x, y;
-    PVar<Integer> c;
+    PVar<IntImm> c;
     // NOTE: canonical form always use <= or <
     if ((c <= x).Match(cond)) {
       return {BoundInfo(x.Eval(), MakeBound(c.Eval()->value, kPosInf))};
index ceaa976..37d5e9e 100644 (file)
@@ -384,10 +384,6 @@ class IntervalSetEvaluator :
     return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
   }
 
-  IntervalSet VisitExpr_(const UIntImmNode* op) final {
-    return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
-  }
-
   IntervalSet VisitExpr_(const VarNode* op) final {
     Var var = GetRef<Var>(op);
     auto it = dom_map_.find(var);
@@ -476,7 +472,7 @@ class IntervalSetEvaluator :
   IntervalSet VisitExpr_(const RampNode* op) final {
     CHECK(eval_vec_);
     IntervalSet base = Eval(op->base);
-    PVar<Integer> stride;
+    PVar<IntImm> stride;
     if (stride.Match(op->stride)) {
       DataType t = op->base.dtype();
       int64_t vstride = stride.Eval()->value;
index 01dd2e8..c818420 100644 (file)
@@ -109,7 +109,7 @@ class ModularSetAnalyzer::Impl :
   // Detect useful constraints and use them in the analysis scope.
   std::function<void()> EnterConstraint(const PrimExpr& constraint) {
     PVar<Var> var;
-    PVar<Integer> coeff, base;
+    PVar<IntImm> coeff, base;
     // pattern match interesting constraints
     if ((truncmod(var, coeff) == base).Match(constraint) ||
         (floormod(var, coeff) == base).Match(constraint)) {
@@ -132,14 +132,6 @@ class ModularSetAnalyzer::Impl :
     return Entry(0, op->value);
   }
 
-  Entry VisitExpr_(const UIntImmNode* op) final {
-    if (op->value < std::numeric_limits<int64_t>::max()) {
-      return Entry(0, static_cast<int>(op->value));
-    } else {
-      return Everything();
-    }
-  }
-
   Entry VisitExpr_(const AddNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
index 733dcf4..a236e65 100644 (file)
@@ -45,7 +45,7 @@
  *  }
  *
  *  tvm::Var tx, ty;
- *  arith::PVar<Integer> c;
+ *  arith::PVar<IntImm> c;
  *  arith::PVar<Var> v;
  *  // We can match integer and Var, both of which are
  *  // special case container of Expr
@@ -140,9 +140,9 @@ class PEqualChecker<PrimExpr> {
 };
 
 template<>
-class PEqualChecker<Integer> {
+class PEqualChecker<IntImm> {
  public:
-  bool operator()(const Integer& lhs, const Integer& rhs) const {
+  bool operator()(const IntImm& lhs, const IntImm& rhs) const {
     return lhs->value == rhs->value;
   }
 };
index 94d951d..e6e1524 100644 (file)
@@ -124,7 +124,7 @@ VisitExpr_(const AddNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2, c3;
+  PVar<IntImm> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
   // Vector rules
@@ -239,7 +239,7 @@ VisitExpr_(const SubNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2, c3;
+  PVar<IntImm> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
   // Vector rules
@@ -438,7 +438,7 @@ VisitExpr_(const MulNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
   // Vector rules
@@ -477,7 +477,7 @@ VisitExpr_(const DivNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2, c3;
+  PVar<IntImm> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
 
@@ -700,7 +700,7 @@ VisitExpr_(const ModNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
 
@@ -789,7 +789,7 @@ VisitExpr_(const FloorDivNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2, c3;
+  PVar<IntImm> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
 
@@ -934,7 +934,7 @@ VisitExpr_(const FloorModNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   // Pattern var for lanes in broadcast and ramp
   PVar<int> lanes;
 
@@ -1004,7 +1004,7 @@ VisitExpr_(const MinNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, s1, s2;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   PVar<int> lanes;
 
   // vector rule
@@ -1189,7 +1189,7 @@ VisitExpr_(const MaxNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, s1, s2;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   PVar<int> lanes;
 
   // vector rule
@@ -1362,7 +1362,7 @@ VisitExpr_(const EQNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y;
   // Pattern var match IntImm
-  PVar<Integer> c1;
+  PVar<IntImm> c1;
   PVar<int> lanes;
 
   // vector rule
@@ -1416,7 +1416,7 @@ VisitExpr_(const LTNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y, z, s1, s2;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   PVar<int> lanes;
 
   // vector rule
@@ -1597,7 +1597,7 @@ VisitExpr_(const AndNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   PVar<int> lanes;
 
   if (op->dtype.lanes() != 1) {
@@ -1646,7 +1646,7 @@ VisitExpr_(const OrNode* op) {
   // Pattern var to match any expression
   PVar<PrimExpr> x, y;
   // Pattern var match IntImm
-  PVar<Integer> c1, c2;
+  PVar<IntImm> c1, c2;
   PVar<int> lanes;
 
   if (op->dtype.lanes() != 1) {
index cf138ed..55ed36c 100644 (file)
@@ -256,7 +256,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
 
     Array<PrimExpr> attr{std::string("_attr_"),
                      FloatImmNode::make(DataType::Float(32), trans(fea.length)),
-                     IntImmNode::make(DataType::Int(32), fea.nest_level),
+                     IntImm(DataType::Int(32), fea.nest_level),
                      FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
                      FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
     };
index 777ad62..d9b7f7f 100644 (file)
@@ -372,16 +372,17 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { //
   }
 }
 
-inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
-  if (op->dtype == DataType::UInt(32)) {
+
+inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+  if (dtype == DataType::UInt(32)) {
     std::ostringstream temp;
-    temp << op->value << "U";
+    temp << val << "U";
     p->MarkConst(temp.str());
     os << temp.str();
   } else {
     os << "(";
-    p->PrintType(op->dtype, os);
-    os << ")" << op->value;
+    p->PrintType(dtype, os);
+    os << ")" << val;
   }
 }
 
@@ -408,9 +409,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) {
 void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
   PrintConst(op, os, this);
 }
-void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) {  // NOLINT(*)
-  PrintConst(op, os, this);
-}
+
 void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
   PrintConst(op, os, this);
 }
@@ -528,6 +527,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
     os << ")";
   } else if (op->is_intrinsic(CallNode::bitwise_and)) {
     PrintBinaryIntrinsic(op, " & ", os, this);
+  } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+    CHECK_EQ(op->args.size(), 2U);
+    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+    uint64_t val = (high << 32U) | low;
+    PrintUIntConst(op->dtype, val, os, this);
   } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
     PrintBinaryIntrinsic(op, " ^ ", os, this);
   } else if (op->is_intrinsic(CallNode::bitwise_or)) {
index cb092c5..7e5dd42 100644 (file)
@@ -128,7 +128,6 @@ class CodeGenC :
   void VisitExpr_(const ShuffleNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const IntImmNode* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const UIntImmNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const StringImmNode* op, std::ostream& os) override;  // NOLINT(*)
   // statment
index 7967c18..cea276d 100644 (file)
@@ -247,11 +247,6 @@ void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) {
   CodeGenC::VisitExpr_(op, os);
 }
 
-void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) {
-  CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
-  CodeGenC::VisitExpr_(op, os);
-}
-
 void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {
   CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
   CodeGenC::VisitExpr_(op, os);
index cd1ec83..19ca2ee 100644 (file)
@@ -50,7 +50,6 @@ class CodeGenOpenGL final : public CodeGenC {
 
   // Codegen for immediate values
   void VisitExpr_(const IntImmNode* op, std::ostream& os) final;  // NOLINT(*)
-  void VisitExpr_(const UIntImmNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const StringImmNode* op, std::ostream& os) final;  // NOLINT(*)
 
index 6879fd5..44862cf 100644 (file)
@@ -48,7 +48,7 @@ class CodeGenARM final : public CodeGenCPU {
 llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
   if (op->is_intrinsic("llvm_intrin")) {
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
-        op->args[0].as<UIntImmNode>()->value);
+        Downcast<IntImm>(op->args[0])->value);
     if (id == ::llvm::Intrinsic::ctpop) {
       PrimExpr e = ARMPopcount(op);
       return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
@@ -68,8 +68,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
   if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
      (total_size != 128 && total_size != 64)) {
     Array<PrimExpr> vcnt_args;
-    vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
-    vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+    vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
+    vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
     vcnt_args.push_back(e);
     return ir::CallNode::make(call->dtype,  "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
   }
@@ -93,16 +93,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
   const CallNode* c0 = input8.as<CallNode>();
   CHECK(c0 != nullptr);
   Array<PrimExpr> vcnt8_args;
-  vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
-  vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+  vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
+  vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt8_args.push_back(input8);
   PrimExpr vcnt8 = ir::CallNode::make(
     uint8_type,  "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
 
   // Accumulation 8->16bit
   Array<PrimExpr> vcnt16_args;
-  vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
-  vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+  vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
+  vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt16_args.push_back(vcnt8);
   PrimExpr vcnt16 = ir::CallNode::make(
     uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
@@ -112,8 +112,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
 
   // Accumulation 16->32bit
   Array<PrimExpr> vcnt32_args;
-  vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
-  vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+  vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
+  vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt32_args.push_back(vcnt16);
   PrimExpr vcnt32 = ir::CallNode::make(
     uint32_type,  "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
@@ -123,8 +123,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
 
   // Accumulation 32->64bit
   Array<PrimExpr> vcnt64_args;
-  vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
-  vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
+  vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
+  vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt64_args.push_back(vcnt32);
   return ir::CallNode::make(
     call->dtype,  "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
index c04a023..60d8146 100644 (file)
@@ -662,15 +662,13 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
-        op->args[0].as<UIntImmNode>()->value);
-    const uint64_t *num_signature = as_const_uint(op->args[1]);
-    CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
-                         << "but " << op->args[1] << " got!\n";
+        Downcast<IntImm>(op->args[0])->value);
+    int64_t num_signature  = Downcast<IntImm>(op->args[1])->value;
     std::vector<llvm::Value*> arg_value;
     std::vector<llvm::Type*> sig_type;
     for (size_t i = 2; i < op->args.size(); ++i) {
       arg_value.push_back(MakeValue(op->args[i]));
-      if (i - 2 < *num_signature) {
+      if (i - 2 < static_cast<size_t>(num_signature)) {
         sig_type.push_back(arg_value.back()->getType());
       }
     }
@@ -722,6 +720,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
     return llvm::Constant::getNullValue(t_void_p_);
   } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
     return builder_->CreateIsNull(MakeValue(op->args[0]));
+  } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+    CHECK_EQ(op->args.size(), 2U);
+    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+    uint64_t val = (high << 32U) | low;
+    return llvm::ConstantInt::get(LLVMType(op->dtype), val);
   } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
     CHECK_EQ(op->args[0].dtype().lanes(), 1)
         << "if_then_else can only take scalar condition";
@@ -804,10 +808,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
   return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) {
-  return llvm::ConstantInt::get(LLVMType(op->dtype), op->value);
-}
-
 llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
   return llvm::ConstantFP::get(LLVMType(op->dtype), op->value);
 }
index 34c3ee7..b269f24 100644 (file)
@@ -106,7 +106,6 @@ class CodeGenLLVM :
   llvm::Value* VisitExpr_(const VarNode* op) override;
   llvm::Value* VisitExpr_(const CastNode* op) override;
   llvm::Value* VisitExpr_(const IntImmNode* op) override;
-  llvm::Value* VisitExpr_(const UIntImmNode* op) override;
   llvm::Value* VisitExpr_(const FloatImmNode* op) override;
   llvm::Value* VisitExpr_(const StringImmNode* op) override;
   llvm::Value* VisitExpr_(const AddNode* op) override;
index 03656cc..11bda70 100644 (file)
@@ -96,8 +96,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
                 MakeValue(
                     ir::BroadcastNode::make(
                       ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
-                /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)),
-                /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)),
+                /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
+                /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
           });
     }
 
index b3ab557..1f839f3 100644 (file)
@@ -43,8 +43,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   CHECK(call != nullptr);
   Array<PrimExpr> cargs;
   // intrin id.
-  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
-  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
+  cargs.push_back(IntImm(DataType::UInt(32), id));
+  cargs.push_back(IntImm(DataType::UInt(32), num_signature));
 
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
@@ -60,8 +60,8 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   CHECK(call != nullptr);
   Array<PrimExpr> cargs;
   // intrin id.
-  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
-  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
+  cargs.push_back(IntImm(DataType::UInt(32), id));
+  cargs.push_back(IntImm(DataType::UInt(32), num_signature));
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
index a749424..985f681 100644 (file)
@@ -136,10 +136,6 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
   return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) {
-  return builder_->UIntImm(builder_->GetSType(op->dtype), op->value);
-}
-
 spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
   return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
 }
@@ -242,7 +238,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
 spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
   if (op->is_intrinsic("spirv_glsl450")) {
     CHECK_GE(op->args.size(), 2U);
-    uint32_t inst_id = op->args[0].as<UIntImmNode>()->value;
+    uint32_t inst_id = static_cast<uint32_t>(
+        op->args[0].as<IntImmNode>()->value);
     std::vector<spirv::Value> values;
     for (size_t i = 1; i < op->args.size(); ++i) {
       values.push_back(MakeValue(op->args[i]));
@@ -285,6 +282,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
   } else if (op->is_intrinsic(CallNode::reinterpret)) {
     return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
                                MakeValue(op->args[0]));
+  } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) {
+    CHECK_EQ(op->args.size(), 2U);
+    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
+    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
+    uint64_t val = (high << 32U) | low;
+    return builder_->UIntImm(builder_->GetSType(op->dtype), val);
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     return this->CreateStorageSync(op);
   } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
index 3804bda..5aa7f9c 100644 (file)
@@ -65,7 +65,6 @@ class CodeGenSPIRV:
   spirv::Value VisitExpr_(const VarNode* op) override;
   spirv::Value VisitExpr_(const CastNode* op) override;
   spirv::Value VisitExpr_(const IntImmNode* op) override;
-  spirv::Value VisitExpr_(const UIntImmNode* op) override;
   spirv::Value VisitExpr_(const FloatImmNode* op) override;
   spirv::Value VisitExpr_(const StringImmNode* op) override;
   spirv::Value VisitExpr_(const AddNode* op) override;
index d41d96d..d96883e 100644 (file)
@@ -39,7 +39,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   CHECK(call != nullptr);
   Array<PrimExpr> cargs;
   // intrin id.
-  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
+  cargs.push_back(IntImm(DataType::UInt(32), id));
 
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
index 6f8d96e..bf43f11 100644 (file)
@@ -342,9 +342,9 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) {
   if (dtype.type == DataType::UInt(1)) {
     // bool types.
     if (*pvalue) {
-      ib_.Begin(spv::OpConstantTrue).AddSeq(ret);
+      ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret);
     } else {
-      ib_.Begin(spv::OpConstantFalse).AddSeq(ret);
+      ib_.Begin(spv::OpConstantFalse).AddSeq(dtype, ret);
     }
   } else {
     // Integral/floating-point types.
index 3843cbb..5d25e86 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index eccff6c..01096ae 100644 (file)
@@ -280,12 +280,6 @@ void CodeGenStackVM::VisitExpr_(const IntImmNode* op) {
     this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
 }
 
-void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) {
-  CHECK(op->value <= std::numeric_limits<int>::max())
-      << "Int constant exceed bound";
-  this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
-}
-
 void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
   LOG(FATAL) << "Float Imm is not supported";
 }
index 07989b2..1360cc2 100644 (file)
@@ -136,7 +136,6 @@ class CodeGenStackVM
   void VisitExpr_(const RampNode* op) final;
   void VisitExpr_(const BroadcastNode* op) final;
   void VisitExpr_(const IntImmNode* op) final;
-  void VisitExpr_(const UIntImmNode* op) final;
   void VisitExpr_(const FloatImmNode* op) final;
   void VisitExpr_(const StringImmNode* op) final;
   // statment
index 7e3d44f..346ec38 100644 (file)
@@ -79,10 +79,7 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) {
 void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
   os << op->value;
 }
-void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) {  // NOLINT(*)
-  PrintType(op->dtype, os);
-  os << "(" << op->value << ")";
-}
+
 void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
   PrintType(op->dtype, os);
   os << "(" << std::setprecision(20) << op->value << ")";
index 89a1ece..33bd0ef 100644 (file)
@@ -117,7 +117,6 @@ class CodeGenHybrid :
   void VisitExpr_(const RampNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const IntImmNode* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const UIntImmNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const StringImmNode* op, std::ostream& os) override;  // NOLINT(*)
   // statment
index f698a5d..6d89967 100644 (file)
 
 namespace tvm {
 
+IntImm::IntImm(DataType dtype, int64_t value) {
+  CHECK(dtype.is_scalar())
+      << "ValueError: IntImm can only take scalar.";
+  CHECK(dtype.is_int() || dtype.is_uint())
+      << "ValueError: IntImm can only take scalar.";
+  if (dtype.is_uint()) {
+    CHECK_GE(value, 0U);
+  }
+  ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
+  node->dtype = dtype;
+  node->value = value;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("make.IntImm")
+.set_body_typed([](DataType dtype, int64_t value) {
+  return IntImm(dtype, value);
+});
+
 GlobalVar::GlobalVar(std::string name_hint) {
   ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
   n->name_hint = std::move(name_hint);
index 34ee4b3..4fffc47 100644 (file)
@@ -77,7 +77,6 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
   virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   // deep comparison of symbolic integer expressions.
@@ -113,7 +112,6 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
     ATTR_FUNCTOR_DISPATCH(StrMapNode);
     ATTR_FUNCTOR_DISPATCH(ArrayNode);
     ATTR_FUNCTOR_DISPATCH(IntImmNode);
-    ATTR_FUNCTOR_DISPATCH(UIntImmNode);
     ATTR_FUNCTOR_DISPATCH(FloatImmNode);
     ATTR_FUNCTOR_DISPATCH(StringImmNode);
     ATTR_FUNCTOR_DISPATCH(VarNode);
@@ -157,7 +155,6 @@ class AttrsEqualHandler :
   bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final;
@@ -198,7 +195,6 @@ class AttrsHashHandler :
  protected:
   size_t VisitAttrDefault_(const Object* lhs) final;
   size_t VisitAttr_(const ir::IntImmNode* lhs) final;
-  size_t VisitAttr_(const ir::UIntImmNode* lhs) final;
   size_t VisitAttr_(const ir::FloatImmNode* lhs) final;
   size_t VisitAttr_(const ir::StringImmNode* lhs) final;
   size_t VisitAttr_(const ArrayNode* lhs) final;
index 1d3e767..a590f10 100644 (file)
@@ -97,13 +97,6 @@ bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<UIntImmNode>()) {
-    return lhs->value == rhs->value;
-  }
-  return false;
-}
-
 bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
   if (const auto* rhs = other.as<FloatImmNode>()) {
     return lhs->value == rhs->value;
@@ -224,10 +217,6 @@ size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
   return std::hash<int64_t>()(op->value);
 }
 
-size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
-  return std::hash<uint64_t>()(op->value);
-}
-
 size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
   return std::hash<double>()(op->value);
 }
index a728936..55dfb89 100644 (file)
@@ -30,7 +30,7 @@
 namespace tvm {
 
 PrimExpr::PrimExpr(int32_t value)
-    : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {}
+    : PrimExpr(IntImm(DataType::Int(32), value)) {}
 
 PrimExpr::PrimExpr(float value)
     : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
@@ -54,15 +54,6 @@ Range::Range(PrimExpr begin, PrimExpr end)
           is_zero(begin) ? end : (end - begin))) {
 }
 
-Integer IntImmNode::make(DataType t, int64_t value) {
-  CHECK(t.is_int() && t.is_scalar())
-      << "ValueError: IntImm can only take scalar.";
-  ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
-  node->dtype = t;
-  node->value = value;
-  return Integer(node);
-}
-
 Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
   return Range(make_object<RangeNode>(min, extent));
 }
index d3875e2..bd43d89 100644 (file)
@@ -35,6 +35,14 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
   return ir::CastNode::make(t, value);
 }
 
+PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
+  return ir::CallNode::make(
+      t, ir::intrinsic::tvm_large_uint_imm,
+      {make_const(DataType::UInt(32), low),
+       make_const(DataType::UInt(32), high)},
+      ir::CallNode::PureIntrinsic);
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   if (lhs.dtype() == rhs.dtype()) return;
@@ -78,26 +86,25 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   }
 }
 
-
 // maximum and min limits
 PrimExpr max_value(const DataType& dtype) {
   using namespace ir;
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
     if (dtype.bits() == 64) {
-      return IntImmNode::make(dtype, std::numeric_limits<int64_t>::max());
+      return IntImm(dtype, std::numeric_limits<int64_t>::max());
     } else if (dtype.bits() < 64) {
       int64_t val = 1;
       val = (val << (dtype.bits() - 1)) - 1;
-      return IntImmNode::make(dtype, val);
+      return IntImm(dtype, val);
     }
   } else if (dtype.is_uint()) {
     if (dtype.bits() == 64) {
-      return UIntImmNode::make(dtype, std::numeric_limits<uint64_t>::max());
+      return make_const(dtype, std::numeric_limits<uint64_t>::max());
     } else if (dtype.bits() < 64) {
       uint64_t val = 1;
       val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
-      return UIntImmNode::make(dtype, val);
+      return IntImm(dtype, static_cast<int64_t>(val));
     }
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
@@ -117,14 +124,14 @@ PrimExpr min_value(const DataType& dtype) {
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
     if (dtype.bits() == 64) {
-      return IntImmNode::make(dtype, std::numeric_limits<int64_t>::lowest());
+      return IntImm(dtype, std::numeric_limits<int64_t>::lowest());
     } else if (dtype.bits() < 64) {
       int64_t val = 1;
       val = -(val << (dtype.bits() - 1));
-      return IntImmNode::make(dtype, val);
+      return IntImm(dtype, val);
     }
   } else if (dtype.is_uint()) {
-    return UIntImmNode::make(dtype, 0);
+    return IntImm(dtype, 0);
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
       return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest());
@@ -155,24 +162,18 @@ inline bool ConstPowerHelper(ValueType val, int *shift) {
 bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
   if (const auto* op = x.as<ir::IntImmNode>()) {
     return ConstPowerHelper(op->value, shift);
-  } else if (const auto* op = x.as<ir::UIntImmNode>()) {
-    return ConstPowerHelper(op->value, shift);
   } else {
     return false;
   }
 }
 
 PrimExpr cast(const DataType& t, PrimExpr value) {
-  using ir::IntImmNode;
-  using ir::UIntImmNode;
   using ir::FloatImmNode;
   if (value.dtype() == t) return value;
   // const fold IntImm as they are used in index computations
   if (t.lanes() == 1) {
     if (const IntImmNode* op = value.as<IntImmNode>()) {
       return make_const(t, op->value);
-    } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
-      return make_const(t, op->value);
     } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
       return make_const(t, op->value);
     }
@@ -184,8 +185,6 @@ PrimExpr cast(const DataType& t, PrimExpr value) {
       if (value.dtype() != vtype) {
         if (const IntImmNode* op = value.as<IntImmNode>()) {
           value = make_const(vtype, op->value);
-        } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
-          return make_const(t, op->value);
         } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
           value = make_const(vtype, op->value);
         } else {
@@ -219,7 +218,7 @@ PrimExpr operator-(PrimExpr a) {
   using ir::FloatImmNode;
   const IntImmNode* pa = a.as<IntImmNode>();
   const FloatImmNode* fa = a.as<FloatImmNode>();
-  if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value);
+  if (pa) return IntImm(a.dtype(), -pa->value);
   if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
   return make_zero(a.dtype()) - a;
 }
@@ -322,18 +321,10 @@ PrimExpr max(PrimExpr a, PrimExpr b) {
 }
 
 PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
-  using ir::IntImmNode;
-  using ir::UIntImmNode;
   CHECK(cond.dtype() == DataType::Bool(1))
       << "if_then_else only accept the condition to be boolean type.";
   BinaryOpMatchTypes(true_value, false_value);
-  if (const UIntImmNode* op = cond.as<UIntImmNode>()) {
-    if (op->value != 0) {
-      return true_value;
-    } else {
-      return false_value;
-    }
-  } else if (const IntImmNode* op = cond.as<IntImmNode>()) {
+  if (const IntImmNode* op = cond.as<IntImmNode>()) {
     if (op->value != 0) {
       return true_value;
     } else {
@@ -424,7 +415,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
       if (pb) {
         if (pb->value == 0) return a;
       }
@@ -437,7 +428,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
       if (pb) {
         if (pb->value == 0) return a;
       }
@@ -450,7 +441,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
     });
   return ir::CallNode::make(
     a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
@@ -460,7 +451,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
     });
   return ir::CallNode::make(
     a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
@@ -470,7 +461,7 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value));
+      if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
     });
   return ir::CallNode::make(
     a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
@@ -494,7 +485,7 @@ PrimExpr abs(PrimExpr x) {
     using ir::IntImmNode;
     const IntImmNode* px = x.as<IntImmNode>();
     if (px) {
-      return ir::IntImmNode::make(x.dtype(), std::abs(px->value));
+      return IntImm(x.dtype(), std::abs(px->value));
     }
     return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x);
   } else if (x.dtype().is_float()) {
index ad7f260..f06a6be 100644 (file)
@@ -31,14 +31,6 @@ namespace tvm {
 namespace ir {
 
 // constructors
-PrimExpr UIntImmNode::make(DataType t, uint64_t value) {
-  CHECK(t.is_uint() && t.lanes() == 1)
-      << "ValueError: UIntImm can only take scalar";
-  ObjectPtr<UIntImmNode> node = make_object<UIntImmNode>();
-  node->dtype = t;
-  node->value = value;
-  return PrimExpr(node);
-}
 
 PrimExpr FloatImmNode::make(DataType t, double value) {
   CHECK_EQ(t.lanes(), 1)
@@ -248,7 +240,7 @@ PrimExpr ShuffleNode::make_concat(Array<PrimExpr> vectors) {
   int index = 0;
   for (const PrimExpr& e : vectors) {
     for (int i = 0; i < e.dtype().lanes(); ++i) {
-      indices.push_back(IntImmNode::make(DataType::Int(32), index++));
+      indices.push_back(IntImm(DataType::Int(32), index++));
     }
   }
   return make(vectors, indices);
@@ -531,11 +523,6 @@ Stmt EvaluateNode::make(PrimExpr value) {
 }
 
 // Printers
-TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<UIntImmNode>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const UIntImmNode*>(node.get());
-    p->stream << "(" << op->dtype << ")" << op->value;
-  });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 .set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
@@ -1153,7 +1140,6 @@ TVM_REGISTER_NODE_TYPE(AnyNode);
 TVM_REGISTER_NODE_TYPE(AttrStmtNode);
 TVM_REGISTER_NODE_TYPE(FloatImmNode);
 TVM_REGISTER_NODE_TYPE(IntImmNode);
-TVM_REGISTER_NODE_TYPE(UIntImmNode);
 TVM_REGISTER_NODE_TYPE(StringImmNode);
 TVM_REGISTER_NODE_TYPE(CastNode);
 TVM_REGISTER_NODE_TYPE(VarNode);
index 2c04de3..0f350d2 100644 (file)
@@ -179,11 +179,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   std::ostringstream type_err_msg;
   type_err_msg << arg_name << ".dtype is expected to be " << dtype;
   PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
-               UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
+               IntImm(DataType::UInt(8), dtype.code()) &&
                TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
-               UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
+               IntImm(DataType::UInt(8), dtype.bits()) &&
                TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
-               UIntImmNode::make(DataType::UInt(16), dtype.lanes()));
+               IntImm(DataType::UInt(16), dtype.lanes()));
   asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -193,7 +193,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     // mark alignment of external bufs
     init_nest_.emplace_back(AttrStmtNode::make(
         vptr, ir::attr::storage_alignment,
-        IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop));
+        IntImm(DataType::Int(32), buffer->data_alignment), nop));
   }
 
   Var v_shape(arg_name + ".shape", DataType::Handle());
@@ -206,7 +206,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     Bind_(buffer->shape[k],
           cast(buffer->shape[k].dtype(),
                LoadNode::make(tvm_shape_type, v_shape,
-                          IntImmNode::make(DataType::Int(32), k), const_true(1))),
+                          IntImm(DataType::Int(32), k), const_true(1))),
           field_name.str(), true);
   }
   // strides field
@@ -228,7 +228,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       PrimExpr svalue = cast(
           stype,
           LoadNode::make(tvm_shape_type, v_strides,
-                     IntImmNode::make(DataType::Int(32), k), const_true(1)));
+                     IntImm(DataType::Int(32), k), const_true(1)));
       conds.push_back(expect_stride == svalue);
       expect_stride = expect_stride * buffer->shape[k];
     }
@@ -251,7 +251,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       field_name << v_strides->name_hint << '[' << k << ']';
       PrimExpr value = cast(buffer->shape[k].dtype(),
                         LoadNode::make(tvm_shape_type, v_strides,
-                                   IntImmNode::make(DataType::Int(32), k), const_true(1)));
+                                   IntImm(DataType::Int(32), k), const_true(1)));
       value = tvm::if_then_else(is_null, stride, value);
       value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
       Bind_(buffer->strides[k], value, field_name.str(), true);
@@ -270,7 +270,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       Bind_(buffer->strides[k],
             cast(buffer->shape[k].dtype(),
                  LoadNode::make(tvm_shape_type, v_strides,
-                            IntImmNode::make(DataType::Int(32), k), const_true(1))),
+                            IntImm(DataType::Int(32), k), const_true(1))),
             field_name.str(), true);
     }
   }
index 6eacb14..8c44151 100644 (file)
@@ -252,10 +252,6 @@ class IRDeepCompare :
     CompareValue(op->value, other.as<IntImmNode>()->value);
   }
 
-  void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final {
-    CompareValue(op->value, other.as<UIntImmNode>()->value);
-  }
-
   void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final {
     CompareValue(op->value, other.as<FloatImmNode>()->value);
   }
index 67acec6..857206f 100644 (file)
@@ -260,7 +260,6 @@ DEFINE_BINOP_VISIT_(AndNode);
 DEFINE_BINOP_VISIT_(OrNode);
 
 void ExprVisitor::VisitExpr_(const IntImmNode* op) {}
-void ExprVisitor::VisitExpr_(const UIntImmNode* op) {}
 void ExprVisitor::VisitExpr_(const FloatImmNode* op) {}
 void ExprVisitor::VisitExpr_(const StringImmNode* op) {}
 
@@ -640,7 +639,6 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
   }
 
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode)
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
 
index 7b760fa..5aba355 100644 (file)
@@ -180,9 +180,6 @@ class AttrScopeLifter : public StmtMutator {
     if (const IntImmNode* op = a.as<IntImmNode>()) {
       return op->value == b.as<IntImmNode>()->value;
     }
-    if (const UIntImmNode* op = a.as<UIntImmNode>()) {
-      return op->value == b.as<UIntImmNode>()->value;
-    }
     return false;
   }
 
index ed8be8b..5684f4e 100644 (file)
@@ -173,7 +173,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   PrimExpr VisitExpr_(const MaxNode* op) final {
     using namespace arith;
     PVar<PrimExpr> x, y;
-    PVar<Integer> c;
+    PVar<IntImm> c;
     auto e = GetRef<PrimExpr>(op);
     if (max(floordiv(x, y), c).Match(e) &&
         c.Eval()->value >= 0 &&
index a0b07c2..d509169 100644 (file)
@@ -120,7 +120,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     const CommReducerNode *combiner = reduce_combiner_.back();
     size_t size = combiner->result.size();
 
-    const UIntImmNode *size_of_args = call->args[0].as<UIntImmNode>();
+    const IntImmNode *size_of_args = call->args[0].as<IntImmNode>();
     CHECK(size_of_args) << call->args[0]->GetTypeKey();
     CHECK_EQ(size, size_of_args->value);
     Array<PrimExpr> inits = combiner->identity_element;
index 8e7f1d8..01a97b7 100644 (file)
@@ -129,8 +129,8 @@ class BuiltinLower : public StmtExprMutator {
                        {cast(DataType::Int(32), device_type_),
                         cast(DataType::Int(32), device_id_),
                         cast(DataType::UInt(64), total_bytes),
-                        IntImmNode::make(DataType::Int(32), op->dtype.code()),
-                        IntImmNode::make(DataType::Int(32), op->dtype.bits())},
+                        IntImm(DataType::Int(32), op->dtype.code()),
+                        IntImm(DataType::Int(32), op->dtype.bits())},
                        CallNode::Extern),
         body);
 
index d5c73a2..5df36d0 100644 (file)
@@ -69,8 +69,8 @@ LoweredFunc MakeAPI(Stmt body,
   // load i-th argument as type t
   auto f_arg_value = [&](DataType t, int i) {
     Array<PrimExpr> call_args{v_packed_args,
-                          IntImmNode::make(DataType::Int(32), i),
-                          IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
+                          IntImm(DataType::Int(32), i),
+                          IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
     // load 64 bit version
     DataType api_type = APIType(t);
     PrimExpr res = CallNode::make(
@@ -117,7 +117,7 @@ LoweredFunc MakeAPI(Stmt body,
       seq_init.emplace_back(LetStmtNode::make(
           tcode, LoadNode::make(
               DataType::Int(32), v_packed_arg_type_ids,
-              IntImmNode::make(DataType::Int(32), i), const_true(1)),
+              IntImm(DataType::Int(32), i), const_true(1)),
           nop));
       DataType t = v_arg.dtype();
       if (t.is_handle()) {
index 224a81c..9fb19cc 100644 (file)
@@ -96,7 +96,6 @@ class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
     return false;
   }
   bool VisitExpr_(const VarNode* op) final { return false; }
-  bool VisitExpr_(const UIntImmNode* op) final { return false; }
   bool VisitExpr_(const IntImmNode* op) final { return false; }
   bool VisitExpr_(const FloatImmNode* op) final { return false; }
   bool VisitExpr_(const StringImmNode* op) final { return false; }
index bb57fe8..956f27c 100644 (file)
@@ -462,7 +462,7 @@ class BufferAnalyser : public StmtExprVisitor {
       strides = bi.strides;
     } else {
       for (size_t i = 1; i < bi.shape.size(); ++i) {
-        PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+        PrimExpr stride = IntImm(DataType::Int(32), 1);
         for (size_t j = bi.shape.size() - 1; j >= i; --j) {
           stride = MulNode::make(stride, bi.shape[j]);
         }
@@ -575,7 +575,7 @@ class BufferAnalyser : public StmtExprVisitor {
         strides = bi.strides;
       } else {
         for (size_t i = 1; i < bi.shape.size(); ++i) {
-          PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+          PrimExpr stride = IntImm(DataType::Int(32), 1);
           for (size_t j = bi.shape.size() - 1; j >= i; --j) {
             stride = MulNode::make(stride, bi.shape[j]);
           }
@@ -765,7 +765,7 @@ class ThreadIdxMutator : public StmtExprMutator {
     op = expr.as<VarNode>();
     if (op != nullptr) {
       if (op->name_hint == "threadIdx.x") {
-        PrimExpr zero = IntImmNode::make(DataType::Int(32), 0);
+        PrimExpr zero = IntImm(DataType::Int(32), 0);
         return zero;
       }
       if (op->name_hint == "threadIdx.y") {
@@ -934,7 +934,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       PrimExpr stride = strides[strides.size()-2];
 
       // thread index unification inside a warp
-      PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+      PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       PrimExpr mutated_value = thread_idx_mutator(op->value);
       PrimExpr src = CallNode::make(value->dtype,
@@ -984,7 +984,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
       PrimExpr dst = it3->second;
       // thread index unification inside a warp
-      PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+      PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       dst = thread_idx_mutator(dst);
       dst = CallNode::make(DataType::Handle(),
@@ -1089,7 +1089,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
     Array<PrimExpr> strides;
     for (size_t i = 1; i < shape.size(); ++i) {
-      PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
+      PrimExpr stride = IntImm(DataType::Int(32), 1);
       for (size_t j = shape.size() - 1; j >= i; --j) {
         stride = MulNode::make(stride, shape[j]);
       }
@@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
     }
     strides.push_back(make_const(DataType::Int(32), 1));
 
-    PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0);
+    PrimExpr elem_offset = IntImm(DataType::Int(32), 0);
     CHECK_EQ(call->args.size(), min_bound.size());
     for (size_t i = 0; i < min_bound.size(); i++) {
       elem_offset = AddNode::make(
index b2c50f7..26ad591 100644 (file)
@@ -159,14 +159,10 @@ class LoopUnroller : public StmtExprMutator {
     // constant folding.
     PrimExpr extent = ir::Simplify(op->extent);
     const IntImmNode  *v1 = extent.as<IntImmNode>();
-    const UIntImmNode *v2 = extent.as<UIntImmNode>();
     int value = -1;
     if (v1 != nullptr) {
       value = static_cast<int>(v1->value);
     }
-    if (v2 != nullptr) {
-      value = static_cast<int>(v2->value);
-    }
     return value;
   }
 
index 00c40b2..5ee4ce3 100644 (file)
@@ -88,7 +88,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
     if (pval != nullptr) {
       CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
       CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
-      res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval));
+      res.push_back(IntImm(DataType::Int(32), *pval));
     } else if (val->IsInstance<ir::AnyNode>()) {
       res.push_back(val.as<ir::AnyNode>()->ToVar());
     } else {
@@ -395,7 +395,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     // set inputs
     for (auto param : prim_func->params) {
       int state = param_states_[param];
-      cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state));
+      cache_node->shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
       if (state & kNeedInputData) {
         for (auto t : param_data_[param]) {
           cache_node->inputs.push_back(t);
@@ -528,7 +528,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     auto ret_type = call_node->checked_type();
     Array<IndexExpr> out_ndims;
     if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
-      out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
+      out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
     } else {
       auto rtype = ret_type.as<TupleTypeNode>();
       // TODO(@icemelon): Allow recursive tuple
@@ -536,7 +536,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
       for (size_t i = 0; i < rtype->fields.size(); ++i) {
         auto ttype = rtype->fields[i].as<TensorTypeNode>();
         CHECK(ttype);
-        out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
+        out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
       }
     }
     // Call shape function
index f66cce6..9966d9c 100644 (file)
@@ -56,7 +56,7 @@ TensorType ConstantNode::tensor_type() const {
     CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
     CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
     shape.push_back(
-        tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i]));
+        tvm::IntImm(DataType::Int(32), data->shape[i]));
   }
 
   return TensorTypeNode::make(shape, dtype);
index 25650c7..400a6be 100644 (file)
@@ -857,10 +857,6 @@ class PrettyPrinter :
     return PrintConstScalar(op->dtype, &(op->value));
   }
 
-  Doc VisitAttr_(const ir::UIntImmNode* op) final {
-    return PrintConstScalar(op->dtype, &(op->value));
-  }
-
   Doc VisitAttr_(const ir::FloatImmNode* op) final {
     return PrintConstScalar(op->dtype, &(op->value));
   }
index 4d3a4b9..b5383cd 100644 (file)
@@ -852,7 +852,7 @@ bool ArgWhereRel(const Array<Type>& types,
   const auto& input_rank = input_shape.size();
   std::vector<IndexExpr> result_shape;
   result_shape.push_back(Any::make());
-  result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank));
+  result_shape.push_back(IntImm(DataType::Int(32), input_rank));
   reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
   return true;
 }
index 5946693..30a9a5c 100644 (file)
@@ -41,7 +41,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
   }
 
   bool Assert(const IndexExpr& cond) final {
-    if (const uint64_t* pdiff = as_const_uint(cond)) {
+    if (const int64_t* pdiff = as_const_int(cond)) {
       return pdiff[0];
     }
     return true;
index 378a5e3..2e33241 100644 (file)
@@ -47,14 +47,10 @@ static inline Array<IndexExpr> get_shape(const Type& type) {
 static inline const int32_t GetQmin(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";
-  if (dtype.is_int()) {
+  if (dtype.is_int() || dtype.is_uint()) {
     auto* min_value = as_const_int(tvm::min_value(dtype));
     CHECK(min_value != nullptr);
     return static_cast<int32_t>(min_value[0]);
-  } else if (dtype.is_uint()) {
-    auto* min_value = as_const_uint(tvm::min_value(dtype));
-    CHECK(min_value != nullptr);
-    return static_cast<int32_t>(min_value[0]);
   } else {
     LOG(FATAL) << "Type not supported " << dtype;
     return -1;  // To hide the warning
@@ -64,14 +60,10 @@ static inline const int32_t GetQmin(const DataType& dtype) {
 static inline const int32_t GetQmax(const DataType& dtype) {
   CHECK_LE(dtype.bits(), 32)
       << "QNN ops support int32 or lower precision";
-  if (dtype.is_int()) {
+  if (dtype.is_int() || dtype.is_uint()) {
     auto* max_value = as_const_int(tvm::max_value(dtype));
     CHECK(max_value != nullptr);
     return static_cast<int32_t>(max_value[0]);
-  } else if (dtype.is_uint()) {
-    auto* max_value = as_const_uint(tvm::max_value(dtype));
-    CHECK(max_value != nullptr);
-    return static_cast<int32_t>(max_value[0]);
   } else {
     LOG(FATAL) << "Type not supported " << dtype;
     return -1;  // To hide the warning
index 5392eae..193f2f2 100644 (file)
@@ -127,10 +127,10 @@ TEST(Pattern, Basic) {
   }
 }
 
-TEST(Pattern, Integer) {
+TEST(Pattern, IntImm) {
   using namespace tvm;
   tvm::Var tx, ty;
-  arith::PVar<Integer> c;
+  arith::PVar<IntImm> c;
   arith::PVar<Var> v;
   {
     // We can match integer and Var, both of which are
index 45ecf95..5a10618 100644 (file)
@@ -18,6 +18,32 @@ import tvm
 from tvm.contrib import util
 import numpy as np
 
+def test_large_uint_imm():
+    value =  (1 << 63) + 123
+    other = tvm.const(3, "uint64")
+    n = 12
+    num_thread = 2
+
+    A = tvm.compute((n,), lambda *i: tvm.const(value, "uint64") + other, name='A')
+    s = tvm.create_schedule(A.op)
+    xo, xi = s[A].split(A.op.axis[0], factor=num_thread)
+    s[A].bind(xi, tvm.thread_axis("threadIdx.x"))
+    s[A].bind(xo, tvm.thread_axis("blockIdx.x"))
+
+    def check_target(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        f = tvm.build(s, [A], device)
+        # launch the kernel.
+        a = tvm.nd.empty((n, ), dtype=A.dtype, ctx=ctx)
+        f(a)
+        assert a.asnumpy()[0] == value + 3
+
+    check_target("cuda")
+    check_target("vulkan")
+
+
 def test_add_pipeline():
     n = tvm.var('n')
     A = tvm.placeholder((n,), name='A')
@@ -112,4 +138,5 @@ def test_add_pipeline():
 
 
 if __name__ == "__main__":
+    test_large_uint_imm()
     test_add_pipeline()
index 0e595cd..4920206 100644 (file)
@@ -88,6 +88,25 @@ def test_llvm_lookup_intrin():
     fcode = tvm.build(func, None, "llvm")
 
 
+def test_llvm_large_uintimm():
+    value =  (1 << 63) + 123
+    other = tvm.const(3, "uint64")
+    A = tvm.compute((), lambda : tvm.const(value, "uint64") + other, name='A')
+    s = tvm.create_schedule(A.op)
+
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        f = tvm.build(s, [A], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        a = tvm.nd.empty((), dtype=A.dtype, ctx=ctx)
+        f(a)
+        assert a.asnumpy() == value + 3
+
+    check_llvm()
+
+
 def test_llvm_add_pipeline():
     nn = 1024
     n = tvm.convert(nn)
@@ -645,6 +664,7 @@ def test_llvm_shuffle():
         tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
 
 if __name__ == "__main__":
+    test_llvm_large_uintimm()
     test_llvm_import()
     test_alignment()
     test_rank_zero()
index c3c40cf..5f1facb 100644 (file)
@@ -24,7 +24,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
     def tvm_val_2_py_val(val):
         val = tvm.ir_pass.Substitute(val, var_dict)
         val = tvm.ir_pass.Simplify(val)
-        assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm))
+        assert isinstance(val, (tvm.expr.IntImm,))
         return val.value
 
     ctx = tvm.context(target, 0)
index fe32949..c418785 100644 (file)
@@ -38,16 +38,11 @@ def test_expr_constructor():
     assert x.value == 2
     assert x.dtype == "int64"
 
-    x = tvm.expr.UIntImm("uint16", 2)
-    assert isinstance(x, tvm.expr.UIntImm)
-    assert x.value == 2
-    assert x.dtype == "uint16"
-
     x = tvm.expr.StringImm("xyza")
     assert isinstance(x, tvm.expr.StringImm)
     assert x.value == "xyza"
 
-    x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1))
+    x = tvm.expr.Cast("float32", tvm.expr.IntImm("uint32", 1))
     assert isinstance(x, tvm.expr.Cast)
     assert x.dtype == "float32"
     assert x.value.value == 1
index c57f4a1..ac2ee6d 100644 (file)
@@ -29,7 +29,7 @@ def test_const_fold():
     def check(f, *args):
         x = f(*[tvm.const(x, "int32") for x in args])
         y = f(*args)
-        if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y):
+        if not isinstance(x, (tvm.expr.IntImm,)) or x.value != int(y):
             raise ValueError("check error: %s vs %s " % (x, y))
 
     tmod = tvm.truncmod
index 43ac3a2..e6de76f 100644 (file)
@@ -43,8 +43,7 @@ using namespace tvm;
  */
 inline bool IsConstInt(PrimExpr expr) {
   return
-    expr->IsInstance<tvm::ir::IntImmNode>() ||
-    expr->IsInstance<tvm::ir::UIntImmNode>();
+    expr->IsInstance<tvm::ir::IntImmNode>();
 }
 
 /*!
@@ -56,11 +55,8 @@ inline bool IsConstInt(PrimExpr expr) {
  * \return The integer value.
  */
 inline int64_t GetConstInt(PrimExpr expr) {
-  if (expr->IsInstance<tvm::ir::IntImmNode>()) {
-    return expr.as<tvm::ir::IntImmNode>()->value;
-  }
-  if (expr->IsInstance<tvm::ir::UIntImmNode>()) {
-    return expr.as<tvm::ir::UIntImmNode>()->value;
+  if (expr->IsInstance<tvm::IntImmNode>()) {
+    return expr.as<tvm::IntImmNode>()->value;
   }
   LOG(ERROR) << "expr must be a constant integer";
   return -1;
index 8f32a29..02d082b 100644 (file)
@@ -92,9 +92,9 @@ def get_const_int(expr):
     """
     if isinstance(expr, Integral):
         return expr
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
         expr = tvm.ir_pass.Simplify(expr)
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
         raise ValueError("Expect value to be constant int")
     return int(expr.value)
 
@@ -136,9 +136,9 @@ def equal_const_int(expr, value):
     """
     if isinstance(expr, Integral):
         return expr == value
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
         expr = tvm.ir_pass.Simplify(expr)
-    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+    if not isinstance(expr, tvm.expr.IntImm):
         return False
     return expr.value == value
 
@@ -160,9 +160,9 @@ def get_const_tuple(in_tuple):
     for elem in in_tuple:
         if isinstance(elem, tvm.expr.Var):
             ret.append(elem)
-        elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
+        elif not isinstance(elem, (tvm.expr.IntImm, int)):
             elem = tvm.ir_pass.Simplify(elem)
-            if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+            if not isinstance(elem, tvm.expr.IntImm):
                 ret.append(elem)
         else:
             ret.append(get_const_int(elem))