[REFACTOR][IR] tvm::Expr -> PrimExpr(Primitive Expr) (#4669)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 9 Jan 2020 23:30:23 +0000 (15:30 -0800)
committerGitHub <noreply@github.com>
Thu, 9 Jan 2020 23:30:23 +0000 (15:30 -0800)
* [REFACTOR][IR] tvm::Expr -> PrimExpr(Primitive Expr)

As part of unified IR, we will need to unify relay::Expr
and the current tvm::Expr under the same base type.

From the techinical point of view. tvm::Expr is a "primitive"
expression that only contains POD types and handles and does
not do life-cycle management.

This PR renames Expr->PrimExpr to clarify that.
We will send a subsequent PR to introduce the base expr class.

* Remove legacy VarExpr and ExprHash/Equal

207 files changed:
include/tvm/arithmetic.h
include/tvm/attrs.h
include/tvm/buffer.h
include/tvm/build_module.h
include/tvm/data_layout.h
include/tvm/expr.h
include/tvm/expr_operator.h
include/tvm/ir.h
include/tvm/ir_functor_ext.h
include/tvm/ir_pass.h
include/tvm/lowered_func.h
include/tvm/operation.h
include/tvm/packed_func_ext.h
include/tvm/relay/attrs/transform.h
include/tvm/relay/base.h
include/tvm/relay/transform.h
include/tvm/runtime/packed_func.h
include/tvm/schedule.h
include/tvm/target_info.h
include/tvm/tensor.h
include/tvm/tensor_intrin.h
python/tvm/api.py
python/tvm/contrib/sparse.py
python/tvm/expr.py
python/tvm/hybrid/calls.py
python/tvm/hybrid/parser.py
python/tvm/schedule.py
python/tvm/tensor.py
python/tvm/tensor_intrin.py
src/api/api_arith.cc
src/api/api_ir.cc
src/api/api_lang.cc
src/api/api_pass.cc
src/api/api_test.cc
src/arithmetic/analyzer.cc
src/arithmetic/bound_deducer.cc
src/arithmetic/canonical_simplify.cc
src/arithmetic/compute_expr.h
src/arithmetic/const_fold.h
src/arithmetic/const_int_bound.cc
src/arithmetic/detect_linear_equation.cc
src/arithmetic/domain_touched.cc
src/arithmetic/int_set.cc
src/arithmetic/int_set.h
src/arithmetic/ir_mutator_with_analyzer.cc
src/arithmetic/ir_mutator_with_analyzer.h
src/arithmetic/ir_visitor_with_analyzer.h
src/arithmetic/modular_set.cc
src/arithmetic/pattern_match.h
src/arithmetic/rewrite_simplify.cc
src/arithmetic/rewrite_simplify.h
src/arithmetic/stmt_simplify.cc
src/autotvm/feature_visitor.cc
src/autotvm/feature_visitor.h
src/autotvm/touch_extractor.cc
src/autotvm/touch_extractor.h
src/codegen/build_module.cc
src/codegen/codegen_c.cc
src/codegen/codegen_c.h
src/codegen/codegen_cuda.cc
src/codegen/codegen_cuda.h
src/codegen/codegen_opencl.cc
src/codegen/codegen_opencl.h
src/codegen/codegen_opengl.cc
src/codegen/codegen_opengl.h
src/codegen/codegen_vhls.cc
src/codegen/intrin_rule.cc
src/codegen/intrin_rule.h
src/codegen/llvm/codegen_amdgpu.cc
src/codegen/llvm/codegen_arm.cc
src/codegen/llvm/codegen_cpu.cc
src/codegen/llvm/codegen_cpu.h
src/codegen/llvm/codegen_llvm.cc
src/codegen/llvm/codegen_llvm.h
src/codegen/llvm/intrin_rule_llvm.cc
src/codegen/llvm/intrin_rule_llvm.h
src/codegen/llvm/intrin_rule_nvptx.cc
src/codegen/llvm/intrin_rule_rocm.cc
src/codegen/spirv/codegen_spirv.cc
src/codegen/spirv/codegen_spirv.h
src/codegen/spirv/intrin_rule_spirv.cc
src/codegen/stackvm/codegen_stackvm.cc
src/codegen/stackvm/codegen_stackvm.h
src/contrib/hybrid/codegen_hybrid.h
src/lang/attrs.cc
src/lang/buffer.cc
src/lang/data_layout.cc
src/lang/expr.cc
src/lang/expr_operator.cc
src/lang/ir.cc
src/lang/tensor.cc
src/op/compute_op.cc
src/op/compute_op.h
src/op/cross_thread_reduction.cc
src/op/extern_op.cc
src/op/hybrid_op.cc
src/op/op_util.cc
src/op/op_util.h
src/op/placeholder_op.cc
src/op/scan_op.cc
src/op/tensor_compute_op.cc
src/op/tensorize.cc
src/pass/arg_binder.cc
src/pass/arg_binder.h
src/pass/bound_checker.cc
src/pass/combine_context_call.cc
src/pass/coproc_sync.cc
src/pass/hoist_if_then_else.cc
src/pass/infer_fragment.cc
src/pass/inject_copy_intrin.cc
src/pass/inject_double_buffer.cc
src/pass/inject_prefetch.cc
src/pass/inject_virtual_thread.cc
src/pass/inline.cc
src/pass/ir_deep_compare.cc
src/pass/ir_functor.cc
src/pass/ir_util.h
src/pass/lift_attr_scope.cc
src/pass/loop_partition.cc
src/pass/lower_custom_datatypes.cc
src/pass/lower_intrin.cc
src/pass/lower_thread_allreduce.cc
src/pass/lower_tvm_builtin.cc
src/pass/lower_warp_memory.cc
src/pass/make_api.cc
src/pass/remap_thread_axis.cc
src/pass/remove_no_op.cc
src/pass/rewrite_unsafe_select.cc
src/pass/simple_passes.cc
src/pass/split_host_device.cc
src/pass/ssa.cc
src/pass/storage_access.cc
src/pass/storage_flatten.cc
src/pass/storage_rewrite.cc
src/pass/storage_sync.cc
src/pass/tensor_core.cc
src/pass/unroll_loop.cc
src/pass/vectorize_loop.cc
src/pass/verify_gpu_code.cc
src/pass/verify_memory.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/removed_unused_funcs.cc
src/relay/ir/expr.cc
src/relay/ir/hash.cc
src/relay/ir/op.cc
src/relay/op/nn/bitserial.cc
src/relay/op/nn/nn.cc
src/relay/op/nn/nn.h
src/relay/op/nn/pad.cc
src/relay/op/tensor/transform.cc
src/relay/pass/infer_layout_util.h
src/relay/pass/pass_manager.cc
src/relay/pass/type_solver.cc
src/relay/qnn/op/dequantize.cc
src/relay/qnn/op/quantize.cc
src/relay/qnn/op/requantize.cc
src/relay/qnn/util.h
src/schedule/auto_inline_elem_wise.cc
src/schedule/graph.cc
src/schedule/graph.h
src/schedule/message_passing.cc
src/schedule/message_passing.h
src/schedule/schedule_dataflow_rewrite.cc
src/schedule/schedule_lang.cc
src/schedule/schedule_ops.cc
tests/cpp/attrs_test.cc
tests/cpp/build_module_test.cc
tests/cpp/container_test.cc
tests/cpp/expr_test.cc
tests/cpp/ir_functor_test.cc
tests/cpp/ir_ssa_test.cc
tests/cpp/packed_func_test.cc
tests/cpp/pattern_match_test.cc
tests/cpp/simple_passes_test.cc
tests/python/unittest/test_lang_tensor_overload_op.py
tests/python/unittest/test_pass_lower_intrin.py
topi/include/topi/broadcast.h
topi/include/topi/contrib/cublas.h
topi/include/topi/contrib/rocblas.h
topi/include/topi/cuda/dense.h
topi/include/topi/cuda/reduction.h
topi/include/topi/detail/broadcast.h
topi/include/topi/detail/constant_utils.h
topi/include/topi/detail/extern.h
topi/include/topi/detail/pad_utils.h
topi/include/topi/detail/ravel_unravel.h
topi/include/topi/detail/tensor_utils.h
topi/include/topi/elemwise.h
topi/include/topi/image/resize.h
topi/include/topi/nn.h
topi/include/topi/nn/bnn.h
topi/include/topi/nn/dilate.h
topi/include/topi/nn/flatten.h
topi/include/topi/nn/local_response_norm.h
topi/include/topi/nn/pooling.h
topi/include/topi/nn/softmax.h
topi/include/topi/nn/upsampling.h
topi/include/topi/reduction.h
topi/include/topi/transform.h
topi/include/topi/vision/reorg.h
topi/python/topi/nn/pad.py
topi/python/topi/util.py
topi/src/topi.cc
topi/tests/python/test_topi_broadcast.py

index d135d30..12acfc3 100644 (file)
@@ -105,7 +105,7 @@ class ConstIntBoundAnalyzer {
    * \param expr The expression of interest.
    * \return the result of the analysis.
    */
-  ConstIntBound operator()(const Expr& expr);
+  ConstIntBound operator()(const PrimExpr& expr);
 
   /*!
    * \brief Update constant int bound information of var.
@@ -136,7 +136,7 @@ class ConstIntBoundAnalyzer {
    *
    * \return an exit function that must be called to cleanup the constraint can be nullptr.
    */
-  std::function<void()> EnterConstraint(const Expr& constraint);
+  std::function<void()> EnterConstraint(const PrimExpr& constraint);
   struct Entry;
   class Impl;
   /*! \brief Internal impl */
@@ -192,7 +192,7 @@ class ModularSetAnalyzer {
    * \param expr The expression of interest.
    * \return the result of the analysis.
    */
-  ModularSet operator()(const Expr& expr);
+  ModularSet operator()(const PrimExpr& expr);
   /*!
    * \brief Update constant int bound information of var.
    *
@@ -215,7 +215,7 @@ class ModularSetAnalyzer {
    *
    * \return an exit function that must be called to cleanup the constraint can be nullptr.
    */
-  std::function<void()> EnterConstraint(const Expr& constraint);
+  std::function<void()> EnterConstraint(const PrimExpr& constraint);
   struct Entry;
   class Impl;
   /*! \brief Internal impl */
@@ -232,7 +232,7 @@ class RewriteSimplifier {
    * \param expr The expression of interest.
    * \return the result of the analysis.
    */
-  Expr operator()(const Expr& expr);
+  PrimExpr operator()(const PrimExpr& expr);
 
   /*!
    * \brief Update binding of var to a new expression.
@@ -242,10 +242,10 @@ class RewriteSimplifier {
    * \param override Whether do we allow override of existing information.
    */
   void Update(const Var& var,
-              const Expr& new_expr,
+              const PrimExpr& new_expr,
               bool override = false);
 
-  std::function<void()> EnterConstraint(const Expr& constraint);
+  std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
  private:
   friend class Analyzer;
@@ -268,7 +268,7 @@ class CanonicalSimplifier {
    * \param expr The expression of interest.
    * \return the result of the analysis.
    */
-  Expr operator()(const Expr& expr);
+  PrimExpr operator()(const PrimExpr& expr);
 
   /*!
    * \brief Update binding of var to a new expression.
@@ -278,7 +278,7 @@ class CanonicalSimplifier {
    * \param override Whether do we allow override of existing information.
    */
   void Update(const Var& var,
-              const Expr& new_expr,
+              const PrimExpr& new_expr,
               bool override = false);
 
  private:
@@ -316,7 +316,7 @@ class ConstraintContext {
    * \param analyzer The analyzer.
    * \param constraint The constraint to be applied.
    */
-  ConstraintContext(Analyzer* analyzer, Expr constraint)
+  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
       : analyzer_(analyzer), constraint_(constraint) {}
   // enter the scope.
   void EnterWithScope();
@@ -325,7 +325,7 @@ class ConstraintContext {
   /*! \brief The analyzer */
   Analyzer* analyzer_;
   /*! \brief The constraint */
-  Expr constraint_;
+  PrimExpr constraint_;
   /*! \brief function to be called in recovery */
   std::function<void()> exit_;
 };
@@ -375,9 +375,9 @@ class IntSet : public ObjectRef {
    */
   Range cover_range(Range max_range) const;
   /*! \return Lower bound of the set */
-  Expr min() const;
+  PrimExpr min() const;
   /*! \return upper bound of the set */
-  Expr max() const;
+  PrimExpr max() const;
   /*! \return Whether the set represent nothing  */
   bool is_nothing() const;
   /*! \return Whether the set represent everything  */
@@ -398,7 +398,7 @@ class IntSet : public ObjectRef {
    * \brief The single point value, call only if is_single_point is true
    * \return The point value.
    */
-  Expr point_value() const;
+  PrimExpr point_value() const;
   /*!
    * \brief Try to match IntSet with range r.
    *
@@ -415,13 +415,13 @@ class IntSet : public ObjectRef {
    * \param point The point in the set.
    * \return construct a single point set
    */
-  static IntSet single_point(Expr point);
+  static IntSet single_point(PrimExpr point);
   /*!
    * \brief construct a integer set from vector expression.
    * \param vec The vector expression, can also be single point.
    * \return The result set containing the indices in the vector.
    */
-  static IntSet vector(Expr vec);
+  static IntSet vector(PrimExpr vec);
   /*!
    * \brief Construct a set representing a range.
    * \param r The range
@@ -434,7 +434,7 @@ class IntSet : public ObjectRef {
    * \param max The maximum value of the interval.
    * \return constructed set.
    */
-  static IntSet interval(Expr min, Expr max);
+  static IntSet interval(PrimExpr min, PrimExpr max);
 };
 
 /*!
@@ -450,7 +450,7 @@ class IntSetAnalyzer {
    * \param dom_map The domain map to indicate which variable to relax.
    * \return the result of the analysis.
    */
-  IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
+  IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
 
  private:
   friend class Analyzer;
@@ -499,7 +499,7 @@ class Analyzer {
    * \param var The variable.
    * \param expr The expression we bind to.
    */
-  void Bind(const VarExpr& var, const Expr& expr);
+  void Bind(const Var& var, const PrimExpr& expr);
   /*!
    * \brief Notify all the sub-analyzers that var
    *        is created and binded to a range.
@@ -509,7 +509,7 @@ class Analyzer {
    * \param var The variable.
    * \param range The range we bind to.
    */
-  void Bind(const VarExpr& var, const Range& range);
+  void Bind(const Var& var, const Range& range);
   /*!
    * \brief Whether can we prove expr >= val.
 
@@ -522,7 +522,7 @@ class Analyzer {
    *
    * \note Analyzer will call into sub-analyzers to get the result.
    */
-  bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
+  bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
   /*!
    * \brief Whether can we prove condition.
    *
@@ -531,7 +531,7 @@ class Analyzer {
    *
    * \note Analyzer will call into sub-analyzers to get the result.
    */
-  bool CanProve(const Expr& cond);
+  bool CanProve(const PrimExpr& cond);
   /*!
    * \brief Simplify expr.
    *
@@ -540,7 +540,7 @@ class Analyzer {
    *
    * \note Analyzer will call into sub-analyzers to get the result.
    */
-  Expr Simplify(const Expr& expr);
+  PrimExpr Simplify(const PrimExpr& expr);
 };
 
 //-----------------------------------------------
@@ -554,7 +554,7 @@ class Analyzer {
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values of e.
  */
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
                const Map<IterVar, IntSet>& dom_map);
 /*!
  * \brief Same as EvalSet, but takes unordered_map
@@ -563,7 +563,7 @@ IntSet EvalSet(Expr e,
  * \param dom_map The domain of each variable.
  * \return An integer set that can cover all the possible values of e.
  */
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
                const std::unordered_map<const VarNode*, IntSet>& dom_map);
 
 /*!
@@ -598,7 +598,7 @@ IntSet EvalSet(Range r,
                const std::unordered_map<const VarNode*, IntSet>& dom_map);
 
 /*! \brief Map from Expr to IntSet */
-using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
+using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
 /*!
  * \brief Find the integer set of every sub-expression, given the
  *  domain of each iteration variables.
@@ -608,7 +608,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
  * \return the map from the expression to its possible value.
  */
 ExprIntSetMap EvalSetForEachSubExpr(
-    Expr e,
+    PrimExpr e,
     const std::unordered_map<const VarNode*, IntSet>& dom_map);
 
 /*!
@@ -640,7 +640,7 @@ IntSet Intersect(const Array<IntSet>& sets);
  *        The deduce bound must implies e for all value in relax_map
  * \return An integer set that always satisfies the condition.
  */
-IntSet DeduceBound(Expr v, Expr cond,
+IntSet DeduceBound(PrimExpr v, PrimExpr cond,
                    const Map<Var, IntSet>& hint_map,
                    const Map<Var, IntSet>& relax_map);
 /*!
@@ -653,7 +653,7 @@ IntSet DeduceBound(Expr v, Expr cond,
  *        The deduce bound mush implies e for all value in relax_map
  * \return An integer set that always satisfies the condition.
  */
-IntSet DeduceBound(Expr v, Expr cond,
+IntSet DeduceBound(PrimExpr v, PrimExpr cond,
                    const std::unordered_map<const VarNode*, IntSet>& hint_map,
                    const std::unordered_map<const VarNode*, IntSet>& relax_map);
 
@@ -676,7 +676,7 @@ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool
  * \param vars List of variables to be used in detection.
  * \return [coeff[i]] if it is possible, empty array if it is not.
  */
-Array<Expr> DetectLinearEquation(const Expr& e,
+Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
                                  const Array<Var>& vars);
 
 /*!
@@ -687,7 +687,7 @@ Array<Expr> DetectLinearEquation(const Expr& e,
  * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
  *          return empty if the e does not match the pattern.
  */
-Array<Expr> DetectClipBound(const Expr& e,
+Array<PrimExpr> DetectClipBound(const PrimExpr& e,
                             const Array<Var>& vars);
 
 // implementation
index 13c8b30..ab9a711 100644 (file)
@@ -486,7 +486,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
   if (val.type_code() == kDLInt) {
     *ptr = static_cast<T>(val.value().v_int64);
   } else {
-    Expr expr = val;
+    PrimExpr expr = val;
     CHECK(expr.defined());
     if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<T>(op->value);
@@ -502,7 +502,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
   if (val.type_code() == kStr) {
     *ptr = val.operator std::string();
   } else {
-    Expr expr = val;
+    PrimExpr expr = val;
     const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
     CHECK(op != nullptr);
     *ptr = op->value;
@@ -517,7 +517,7 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
   if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
     *ptr = val.operator double();
   } else {
-    Expr expr = val;
+    PrimExpr expr = val;
     CHECK(expr.defined());
     if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<double>(op->value);
index 44c7918..284e370 100644 (file)
@@ -66,7 +66,7 @@ class Buffer : public ObjectRef {
    *  If stride is not needed in the slice, it won't be presented
    * \return the result buffer.
    */
-  TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
+  TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
   /*!
    * \brief Get access ptr to the entire buffer.
    * \param access_mask The access mask
@@ -74,22 +74,22 @@ class Buffer : public ObjectRef {
    * \param content_lanes The number of lanes for the (data) type.
    * \param offset The offset of ptr.
    */
-  TVM_DLL Expr access_ptr(int access_mask,
+  TVM_DLL PrimExpr access_ptr(int access_mask,
                           DataType ptr_type = DataType::Handle(),
                           int content_lanes = 1,
-                          Expr offset = make_const(DataType::Int(32), 0)) const;
+                          PrimExpr offset = make_const(DataType::Int(32), 0)) const;
   /*!
    * \brief Create an Expr that does a vector load at begin index.
    * \param begin The beginning index
    * \param dtype The data type to be loaded.
    */
-  TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
+  TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
   /*!
    * \brief Create a Stmt that does a vector store at begin index.
    * \param begin The beginning index
    * \param value The value to be stored.
    */
-  TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const;
+  TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -112,14 +112,14 @@ class BufferNode : public Object {
   /*! \brief data type in the content of the tensor */
   DataType dtype;
   /*! \brief The shape of the buffer */
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   /*!
    * \brief The strides of each dimension
    *  This can be an empty array, indicating array is contiguous
    */
-  Array<Expr> strides;
+  Array<PrimExpr> strides;
   /*! \brief The offset in terms of number of dtype elements (including lanes) */
-  Expr elem_offset;
+  PrimExpr elem_offset;
   // Meta data
   /*! \brief optional name of the buffer */
   std::string name;
@@ -159,9 +159,9 @@ class BufferNode : public Object {
   // A default value will be picked.
   TVM_DLL static Buffer make(Var ptr,
                              DataType dtype,
-                             Array<Expr> shape,
-                             Array<Expr> strides,
-                             Expr elem_offset,
+                             Array<PrimExpr> shape,
+                             Array<PrimExpr> strides,
+                             PrimExpr elem_offset,
                              std::string name,
                              std::string scope,
                              int data_alignment,
@@ -184,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const {
  * \return The created buffer.
  * \sa BufferNode::make for complete constructor.
  */
-TVM_DLL Buffer decl_buffer(Array<Expr> shape,
+TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape,
                            DataType dtype = DataType::Float(32),
                            std::string name = "buffer");
 }  // namespace tvm
index 5078621..8b49fb7 100644 (file)
@@ -52,11 +52,11 @@ class TargetNode : public Object {
   /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
   int thread_warp_size = 1;
   /*! \brief Keys for this target */
-  Array<Expr> keys_array;
+  Array<PrimExpr> keys_array;
   /*! \brief Options for this target */
-  Array<Expr> options_array;
+  Array<PrimExpr> options_array;
   /*! \brief Collection of imported libs */
-  Array<Expr> libs_array;
+  Array<PrimExpr> libs_array;
 
   /*! \return the full device string to pass to codegen::Build */
   TVM_DLL const std::string& str() const;
index 8c7247f..d49320c 100644 (file)
@@ -316,9 +316,9 @@ class BijectiveLayoutNode : public Object {
   /*! \brief Describes how source axes can be mapped to the destination axes,
    *   e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
    */
-  Array<Expr> forward_rule;
+  Array<PrimExpr> forward_rule;
   /*! \brief Describes how destination axes can be mapped to the source axes */
-  Array<Expr> backward_rule;
+  Array<PrimExpr> backward_rule;
 
   /*! \brief The source layout */
   Layout src_layout;
@@ -350,13 +350,13 @@ class BijectiveLayout : public ObjectRef {
   explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
 
   // Given the source shape, infer the destination shape.
-  TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
+  TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
   // Given the destination shape, recover the source shape.
-  TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const;
+  TVM_DLL Array<PrimExpr> BackwardShape(const Array<PrimExpr>& dst_shape) const;
   // Given the destination indices, infer the destination indices.
-  TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const;
+  TVM_DLL Array<PrimExpr> ForwardIndex(const Array<PrimExpr>& index) const;
   // Given the destination indices, recover the source indices.
-  TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const;
+  TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;
 
   /*!
    * \brief access the internal node container
index 64d7547..976af61 100644 (file)
 
 namespace tvm {
 
-/*! \brief Base node of all expressions. */
-class ExprNode : public Object {
+/*!
+ * \brief Base node of all primitive expressions.
+ *
+ *  A primitive expression deals with low-level
+ *  POD data types and handles without
+ *  doing life-cycle management for objects.
+ *
+ *  PrimExpr is used in the low-level code
+ *  optimizations and integer analysis.
+ *
+ * \sa PrimExpr
+ */
+class PrimExprNode : public Object {
  public:
   /*! \brief The data type of the expression. */
   DataType dtype;
 
-  static constexpr const char* _type_key = "Expr";
-  TVM_DECLARE_BASE_OBJECT_INFO(ExprNode, Object);
+  static constexpr const char* _type_key = "PrimExpr";
+  TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, Object);
 };
 
-/*! \brief Container of all expressions. */
-class Expr : public ObjectRef {
+/*!
+ * \brief Container of all primitive expressions.
+ * \sa PrimExprNode
+ */
+class PrimExpr : public ObjectRef {
  public:
-  Expr() {}
-  explicit Expr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+  PrimExpr() {}
+  explicit PrimExpr(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
   /*!
    * \brief construct from integer.
    * \param value The value to be constructed.
    */
-  TVM_DLL Expr(int32_t value);  // NOLINT(*)
+  TVM_DLL PrimExpr(int32_t value);  // NOLINT(*)
   /*!
    * \brief construct from float.
    * \param value The value to be constructed.
    */
-  TVM_DLL Expr(float value);  // NOLINT(*)
+  TVM_DLL PrimExpr(float value);  // NOLINT(*)
   /*!
    * \brief construct from string.
    * \param str The value to be constructed.
    */
-  TVM_DLL Expr(std::string str);  // NOLINT(*)
+  TVM_DLL PrimExpr(std::string str);  // NOLINT(*)
 
   /*! \return the data type of this expression. */
   DataType dtype() const {
-    return static_cast<const ExprNode*>(get())->dtype;
+    return static_cast<const PrimExprNode*>(get())->dtype;
   }
 
-  /*! \brief type indicate the container type */
-  using ContainerType = ExprNode;
+  using ContainerType = PrimExprNode;
 };
 
 /*! \brief Base node of all statements. */
@@ -102,7 +115,7 @@ class Var;
  * - Let
  * - LetStmt
  */
-class VarNode : public ExprNode {
+class VarNode : public PrimExprNode {
  public:
   /*!
    * \brief The hint to the variable name.
@@ -118,13 +131,13 @@ class VarNode : public ExprNode {
   }
 
   static constexpr const char* _type_key = "Variable";
-  TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode);
 };
 
 /*! \brief a named variable in TVM */
-class Var : public Expr {
+class Var : public PrimExpr {
  public:
-  explicit Var(ObjectPtr<Object> n) : Expr(n) {}
+  explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
   TVM_DLL explicit Var(std::string name_hint = "v",
                        DataType t = DataType::Int(32));
   /*!
@@ -153,15 +166,9 @@ class Var : public Expr {
   using ContainerType = VarNode;
 };
 
-// Backward compatibility, will be removed later.
-using VarExpr = Var;
-using BaseExprNode = ExprNode;
-using ExprHash = ObjectHash;
-using ExprEqual = ObjectEqual;
-
 class Integer;
 /*! \brief ExprNode: constant integer. */
-class IntImmNode : public ExprNode {
+class IntImmNode : public PrimExprNode {
  public:
   /*! \brief the Internal value. */
   int64_t value;
@@ -174,7 +181,7 @@ class IntImmNode : public ExprNode {
   TVM_DLL static Integer make(DataType t, int64_t value);
 
   static constexpr const char* _type_key = "IntImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
 };
 
 /*!
@@ -183,17 +190,17 @@ class IntImmNode : public ExprNode {
  * This is used to store and automate type check
  * attributes that must be constant integer.
  */
-class Integer : public Expr {
+class Integer : public PrimExpr {
  public:
-  Integer() : Expr() {}
+  Integer() : PrimExpr() {}
   /*!
    * \brief constructor from node.
    */
-  explicit Integer(ObjectPtr<Object> node) : Expr(node) {}
+  explicit Integer(ObjectPtr<Object> node) : PrimExpr(node) {}
   /*!
    * \brief Construct integer from int value.
    */
-  Integer(int value) : Expr(value) {}  // NOLINT(*)
+  Integer(int value) : PrimExpr(value) {}  // NOLINT(*)
   /*!
    * \brief Assign an expression to integer.
    * \param other another expression.
@@ -225,12 +232,12 @@ class Integer : public Expr {
 class RangeNode : public Object {
  public:
   /*! \brief beginning of the node */
-  Expr min;
+  PrimExpr min;
   /*! \brief the extend of range */
-  Expr extent;
+  PrimExpr extent;
   /*! \brief constructor */
   RangeNode() {}
-  RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}
+  RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {}
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("min", &min);
@@ -249,7 +256,7 @@ class Range : public ObjectRef {
    * \param begin The begin of the range.
    * \param end The end of the range.
    */
-  TVM_DLL Range(Expr begin, Expr end);
+  TVM_DLL Range(PrimExpr begin, PrimExpr end);
   /*!
    * \brief construct a new range with min and extent
    *  The corresponding constructor is removed,
@@ -259,7 +266,7 @@ class Range : public ObjectRef {
    * \param min The minimum range.
    * \param extent The extent of the range.
    */
-  static Range make_by_min_extent(Expr min, Expr extent);
+  static Range make_by_min_extent(PrimExpr min, PrimExpr extent);
   // declare range.
   TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
 };
@@ -357,7 +364,7 @@ class IterVar : public ObjectRef {
   /*!
    * \return the corresponding var in the IterVar.
    */
-  inline operator Expr() const;
+  inline operator PrimExpr() const;
   /*! \brief specify container node */
   using ContainerType = IterVarNode;
 };
@@ -428,7 +435,7 @@ inline const IterVarNode* IterVar::operator->() const {
   return static_cast<const IterVarNode*>(data_.get());
 }
 
-inline IterVar::operator Expr() const {
+inline IterVar::operator PrimExpr() const {
   return (*this)->var;
 }
 
index bf8b1a3..2d8f378 100644 (file)
@@ -44,19 +44,19 @@ namespace tvm {
  */
 template<typename ValueType,
          typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
-inline Expr make_const(DataType t, ValueType value);
+inline PrimExpr make_const(DataType t, ValueType value);
 /*!
  * \brief Make a const zero expr.
  * \param t The target type.
  * \return the result expression.
  */
-inline Expr make_zero(DataType t);
+inline PrimExpr make_zero(DataType t);
 /*!
  * \brief Make a constant true expression.
  * \param lanes The number of lanes in the bool
  * \return The result expression.
  */
-inline Expr const_true(int lanes = 1) {
+inline PrimExpr const_true(int lanes = 1) {
   return make_const(DataType::UInt(1, lanes), 1);
 }
 /*!
@@ -64,7 +64,7 @@ inline Expr const_true(int lanes = 1) {
  * \param lanes The number of lanes in the bool
  * \return The result expression.
  */
-inline Expr const_false(int lanes = 1) {
+inline PrimExpr const_false(int lanes = 1) {
   return make_const(DataType::UInt(1, lanes), 0);
 }
 /*!
@@ -73,7 +73,7 @@ inline Expr const_false(int lanes = 1) {
  * \return the address to the int expression,
  *         return nullptr, if x is not IntImm.
  */
-inline const int64_t* as_const_int(const Expr& x) {
+inline const int64_t* as_const_int(const PrimExpr& x) {
   if (!x.defined()) return nullptr;
   if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
     return &(op->value);
@@ -88,7 +88,7 @@ inline const int64_t* as_const_int(const Expr& x) {
  * \return the address to the int expression,
  *         return nullptr, if x is not UIntImm.
  */
-inline const uint64_t* as_const_uint(const Expr& x) {
+inline const uint64_t* as_const_uint(const PrimExpr& x) {
   if (!x.defined()) return nullptr;
   if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
     return &(op->value);
@@ -103,7 +103,7 @@ inline const uint64_t* as_const_uint(const Expr& x) {
  * \param value the value to be compared against.
  * \return whether x is constant expression.
  */
-inline bool is_const_int(const Expr& x, int64_t value);
+inline bool is_const_int(const PrimExpr& x, int64_t value);
 
 /*!
  * \brief Check whether stmt is nop.
@@ -118,7 +118,7 @@ inline bool is_no_op(const Stmt& stmt);
  * \note This only return true for integer types.
  * \return whether x is constant 1
  */
-inline bool is_one(const Expr& x) {
+inline bool is_one(const PrimExpr& x) {
   return is_const_int(x, 1);
 }
 
@@ -128,7 +128,7 @@ inline bool is_one(const Expr& x) {
  * \return whether x is constant 0
  * \note This only return true for integer types.
  */
-inline bool is_zero(const Expr& x) {
+inline bool is_zero(const PrimExpr& x) {
   return is_const_int(x, 0);
 }
 
@@ -137,21 +137,21 @@ inline bool is_zero(const Expr& x) {
  * \note This only return true for integer types.
  * \return whether x is constant
  */
-inline bool is_const(const Expr& x);
+inline bool is_const(const PrimExpr& x);
 
 /*!
  * Query the maximum possible value of dtype.
  * \param dtype The data type.
  * \return the maximum possible value in this format.
  */
-TVM_DLL Expr max_value(const DataType& dtype);
+TVM_DLL PrimExpr max_value(const DataType& dtype);
 
 /*!
  * Query the minimum possible value of dtype.
  * \param dtype The data type.
  * \return the minimum possible value in this format.
  */
-TVM_DLL Expr min_value(const DataType& dtype);
+TVM_DLL PrimExpr min_value(const DataType& dtype);
 
 /*!
  * \brief Check whether x is a constant power of two
@@ -161,7 +161,7 @@ TVM_DLL Expr min_value(const DataType& dtype);
  * \param shift The output shift if x is power of two.
  * \return whether x is constant power of two
  */
-TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
+TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
 
 /*!
  * \brief cast value to type.
@@ -171,7 +171,7 @@ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
  * \return The result expression.
  * \note This function may return value if the type is the same.
  */
-TVM_DLL Expr cast(const DataType& t, Expr value);
+TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value);
 /*!
  * \brief perform reinterpret cast value to type.
  *
@@ -180,7 +180,7 @@ TVM_DLL Expr cast(const DataType& t, Expr value);
  * \return The result expression.
  * \note This function may return value if the type is the same.
  */
-TVM_DLL Expr reinterpret(const DataType& t, Expr value);
+TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value);
 /*!
  * \brief add operator
  *
@@ -190,7 +190,7 @@ TVM_DLL Expr reinterpret(const DataType& t, Expr value);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator+(Expr a, Expr b);
+TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
 /*!
  * \brief subtraction operator
  *
@@ -200,7 +200,7 @@ TVM_DLL Expr operator+(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator-(Expr a, Expr b);
+TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
 /*!
  * \brief negation.
  *
@@ -209,7 +209,7 @@ TVM_DLL Expr operator-(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator-(Expr a);
+TVM_DLL PrimExpr operator-(PrimExpr a);
 /*!
  * \brief multiplication operator
  *
@@ -219,7 +219,7 @@ TVM_DLL Expr operator-(Expr a);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator*(Expr a, Expr b);
+TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
 /*!
  * \brief division operator
  *
@@ -229,7 +229,7 @@ TVM_DLL Expr operator*(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator/(Expr a, Expr b);
+TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
 /*!
  * \brief left shift operator
  *
@@ -239,7 +239,7 @@ TVM_DLL Expr operator/(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator<<(Expr a, Expr b);
+TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
 /*!
  * \brief right shift operator
  *
@@ -249,7 +249,7 @@ TVM_DLL Expr operator<<(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator>>(Expr a, Expr b);
+TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
 /*!
  * \brief greater
  *
@@ -259,7 +259,7 @@ TVM_DLL Expr operator>>(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator>(Expr a, Expr b);
+TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
 /*!
  * \brief greater_equal
  *
@@ -269,7 +269,7 @@ TVM_DLL Expr operator>(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator>=(Expr a, Expr b);
+TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
 /*!
  * \brief less
  *
@@ -279,7 +279,7 @@ TVM_DLL Expr operator>=(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator<(Expr a, Expr b);
+TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
 /*!
  * \brief less_equal
  *
@@ -289,7 +289,7 @@ TVM_DLL Expr operator<(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator<=(Expr a, Expr b);
+TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
 /*!
  * \brief equal
  *
@@ -299,7 +299,7 @@ TVM_DLL Expr operator<=(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator==(Expr a, Expr b);
+TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
 /*!
  * \brief not_equal
  *
@@ -309,7 +309,7 @@ TVM_DLL Expr operator==(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator!=(Expr a, Expr b);
+TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
 /*!
  * \brief and
  *
@@ -318,7 +318,7 @@ TVM_DLL Expr operator!=(Expr a, Expr b);
  * \return The result expression.
  * \note This operator does eager constant folding.
  */
-TVM_DLL Expr operator&&(Expr a, Expr b);
+TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
 /*!
  * \brief or
  *
@@ -327,7 +327,7 @@ TVM_DLL Expr operator&&(Expr a, Expr b);
  * \return The result expression.
  * \note This operator does eager constant folding.
  */
-TVM_DLL Expr operator||(Expr a, Expr b);
+TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
 /*!
  * \brief not
  *
@@ -335,7 +335,7 @@ TVM_DLL Expr operator||(Expr a, Expr b);
  * \return The result expression.
  * \note This operator does eager constant folding.
  */
-TVM_DLL Expr operator!(Expr a);
+TVM_DLL PrimExpr operator!(PrimExpr a);
 /*!
  * \brief compute division in C semantics.
  *
@@ -349,7 +349,7 @@ TVM_DLL Expr operator!(Expr a);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr div(Expr a, Expr b);
+TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b);
 /*!
  * \brief compute trunc(a / b)
  *
@@ -361,7 +361,7 @@ TVM_DLL Expr div(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr truncdiv(Expr a, Expr b);
+TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b);
 /*!
  * \brief compute the remainder of truncdiv
  *
@@ -373,7 +373,7 @@ TVM_DLL Expr truncdiv(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr truncmod(Expr a, Expr b);
+TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b);
 /*!
  * \brief compute floor(a / b) where a and b are non-negative.
  *
@@ -388,7 +388,7 @@ TVM_DLL Expr truncmod(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr indexdiv(Expr a, Expr b);
+TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b);
 /*!
  * \brief compute the remainder floor(a / b) where a and b are non-negative.
  *
@@ -402,7 +402,7 @@ TVM_DLL Expr indexdiv(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr indexmod(Expr a, Expr b);
+TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b);
 /*!
  * \brief compute floor(a / b)
  *
@@ -412,7 +412,7 @@ TVM_DLL Expr indexmod(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr floordiv(Expr a, Expr b);
+TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b);
 /*!
  * \brief compute the remainder of floordiv
  *
@@ -422,7 +422,7 @@ TVM_DLL Expr floordiv(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr floormod(Expr a, Expr b);
+TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b);
 /*!
  * \brief take maximum of two values
  *
@@ -432,7 +432,7 @@ TVM_DLL Expr floormod(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr max(Expr a, Expr b);
+TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b);
 /*!
  * \brief take minimum of two values
  *
@@ -442,7 +442,7 @@ TVM_DLL Expr max(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr min(Expr a, Expr b);
+TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise and of two values
  *
@@ -452,7 +452,7 @@ TVM_DLL Expr min(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator&(Expr a, Expr b);
+TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise or of two values
  *
@@ -462,7 +462,7 @@ TVM_DLL Expr operator&(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator|(Expr a, Expr b);
+TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise xor of two values
  *
@@ -472,7 +472,7 @@ TVM_DLL Expr operator|(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator^(Expr a, Expr b);
+TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise negation of two values
  *
@@ -481,7 +481,7 @@ TVM_DLL Expr operator^(Expr a, Expr b);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr operator~(Expr a);
+TVM_DLL PrimExpr operator~(PrimExpr a);
 /*!
  * \brief Conditional expression.
  *
@@ -492,95 +492,95 @@ TVM_DLL Expr operator~(Expr a);
  * \note this function does eager constant folding for
  *       index types(int32, int64) when possible.
  */
-TVM_DLL Expr if_then_else(Expr cond, Expr true_value, Expr false_value);
+TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value);
 /*!
  * \brief Mark condition as likely.
  * \param cond The condition
  * \return The marked expression.
  */
-TVM_DLL Expr likely(Expr cond);
+TVM_DLL PrimExpr likely(PrimExpr cond);
 /*!
  * \brief Calculate power(x, y)
  * \param x The left operand.
  * \param y The right operand.
  */
-TVM_DLL Expr pow(Expr x, Expr y);
+TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y);
 /*!
  * \brief Calculate absolute value of x.
  * \param x The input data
  *
  * \return The aboslute value of input data x
  */
-TVM_DLL Expr abs(Expr x);
+TVM_DLL PrimExpr abs(PrimExpr x);
 /*!
  * \brief Check if x is NaN.
  * \param x The input data
  * \return The result expression.
  */
-TVM_DLL Expr isnan(Expr x);
+TVM_DLL PrimExpr isnan(PrimExpr x);
 
 /*!
  * \brief sum of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr sum(PrimExpr source, Array<IterVar> axis);
 
 /*!
  * \brief logical And of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL Expr all(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr all(PrimExpr source, Array<IterVar> axis);
 
 /*!
  * \brief logical Or of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL Expr any(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr any(PrimExpr source, Array<IterVar> axis);
 
 /*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL Expr max(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr max(PrimExpr source, Array<IterVar> axis);
 
 /*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL Expr min(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr min(PrimExpr source, Array<IterVar> axis);
 
 /*!
  * \brief product of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
  */
-TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
+TVM_DLL PrimExpr prod(PrimExpr source, Array<IterVar> axis);
 
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
  * \return The result expression.
  */
-TVM_DLL Expr floor(Expr x);
+TVM_DLL PrimExpr floor(PrimExpr x);
 
 /*!
  * \brief Calculate ceil(x)
  * \param x The input expression.
  * \return The result expression.
  */
-TVM_DLL Expr ceil(Expr x);
+TVM_DLL PrimExpr ceil(PrimExpr x);
 
 /*!
  * \brief Calculate round(x)
  * \param x The input expression.
  * \return The result expression.
  */
-TVM_DLL Expr round(Expr x);
+TVM_DLL PrimExpr round(PrimExpr x);
 
 /*!
  * \brief Calculates std::nearbyint(x)
@@ -588,20 +588,20 @@ TVM_DLL Expr round(Expr x);
  * \return The result expression.
  * This is a faster alternate to round.
  */
-TVM_DLL Expr nearbyint(Expr x);
+TVM_DLL PrimExpr nearbyint(PrimExpr x);
 
 /*!
  * \brief Calculate trunc(x)
  * \param x The input expression.
  * \return The result expression.
  */
-TVM_DLL Expr trunc(Expr x);
+TVM_DLL PrimExpr trunc(PrimExpr x);
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
-  inline Expr OpName(Expr x) {                                          \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)                                             \
+  inline PrimExpr OpName(PrimExpr x) {                                     \
     return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
-  }                                                                     \
+  }                                                                                  \
 
 TVM_DECLARE_INTRIN_UNARY(exp);
 TVM_DECLARE_INTRIN_UNARY(erf);
@@ -616,11 +616,11 @@ TVM_DECLARE_INTRIN_UNARY(sin);
 TVM_DECLARE_INTRIN_UNARY(atan);
 
 // Implementation details after this
-inline bool is_const(const Expr& x) {
+inline bool is_const(const PrimExpr& x) {
   if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
     return true;
   } else if (const auto* op = x.as<ir::BroadcastNode>()) {
-    const Expr& val = op->value;
+    const PrimExpr& val = op->value;
     if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
       return true;
     }
@@ -628,7 +628,7 @@ inline bool is_const(const Expr& x) {
   return false;
 }
 
-inline bool is_positive_const(const Expr& a) {
+inline bool is_positive_const(const PrimExpr& a) {
   if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
     return op->value > 0;
   } else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
@@ -638,7 +638,7 @@ inline bool is_positive_const(const Expr& a) {
   }
 }
 
-inline bool is_negative_const(const Expr& a) {
+inline bool is_negative_const(const PrimExpr& a) {
   if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
     return op->value < 0;
   } else {
@@ -646,13 +646,13 @@ inline bool is_negative_const(const Expr& a) {
   }
 }
 
-inline bool is_const_int(const Expr& x, int64_t value) {
+inline bool is_const_int(const PrimExpr& x, int64_t value) {
   if (const auto* op = x.as<ir::IntImmNode>()) {
     return op->value == value;
   } else if (const auto* op = x.as<ir::UIntImmNode>()) {
     return op->value == static_cast<uint64_t>(value);
   } else if (const auto* op = x.as<ir::BroadcastNode>()) {
-    const Expr& val = op->value;
+    const PrimExpr& val = op->value;
     if (const auto* opv = val.as<ir::IntImmNode>()) {
       return opv->value == value;
     } else if (const auto* opv = val.as<ir::UIntImmNode>()) {
@@ -674,7 +674,7 @@ inline bool is_no_op(const Stmt& stmt) {
 }
 
 template<typename ValueType>
-inline Expr MakeConstScalar(DataType t, ValueType value) {
+inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
   if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
   if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
   if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
@@ -685,11 +685,11 @@ inline Expr MakeConstScalar(DataType t, ValueType value) {
   if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
     return ir::FloatImmNode::make(t, static_cast<double>(value));
   LOG(FATAL) << "cannot make const for type " << t;
-  return Expr();
+  return PrimExpr();
 }
 
 template<typename ValueType, typename>
-inline Expr make_const(DataType t, ValueType value) {
+inline PrimExpr make_const(DataType t, ValueType value) {
   if (t.lanes() == 1) {
     return MakeConstScalar(t, value);
   } else {
@@ -698,7 +698,7 @@ inline Expr make_const(DataType t, ValueType value) {
   }
 }
 
-inline Expr make_zero(DataType t) {
+inline PrimExpr make_zero(DataType t) {
   if (t.is_handle()) {
     return reinterpret(t, make_const(DataType::UInt(64), 0));
   }
@@ -706,43 +706,43 @@ inline Expr make_zero(DataType t) {
 }
 
 // additional const expression overloading
-#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)            \
-  inline Expr Name(Expr& a, Expr b) {                          \
-    a = OpFunc(a, b);                                          \
-    return a;                                                  \
+#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc)             \
+  inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\
+    a = OpFunc(a, b);                                           \
+    return a;                                                   \
   }
 
 #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)              \
-  inline Expr Name(const Expr& a, float b) {                   \
-    return Name(a, Expr(b));                                   \
+  inline PrimExpr Name(const PrimExpr& a, float b) { \
+    return Name(a, PrimExpr(b));                          \
   }                                                            \
-  inline Expr Name(float a, const Expr& b) {                   \
-    return Name(Expr(a), b);                                   \
+  inline PrimExpr Name(float a, const PrimExpr& b) { \
+    return Name(PrimExpr(a), b);                          \
   }                                                            \
-  inline Expr Name(int a, const Expr& b) {                     \
+  inline PrimExpr Name(int a, const PrimExpr& b) {   \
     return Name(make_const(b.dtype(), a), b);                  \
   }                                                            \
-  inline Expr Name(const Expr& a, int b) {                     \
+  inline PrimExpr Name(const PrimExpr& a, int b) {   \
     return Name(a, make_const(a.dtype(), b));                  \
   }                                                            \
-  inline Expr Name(const Expr& a, double b) {                  \
-    return Name(a, make_const(DataType::Float(64), b));                  \
+  inline PrimExpr Name(const PrimExpr& a, double b) {\
+    return Name(a, make_const(DataType::Float(64), b));        \
   }
 
-#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)                  \
-  inline Expr Name(const Expr& a, bool b) {                             \
-    return Name(a, Expr(b));                                            \
-  }                                                                     \
-  inline Expr Name(bool a, const Expr& b) {                             \
-    return Name(Expr(a), b);                                            \
+#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)        \
+  inline PrimExpr Name(const PrimExpr& a, bool b) { \
+    return Name(a, PrimExpr(b));                         \
+  }                                                           \
+  inline PrimExpr Name(bool a, const PrimExpr& b) { \
+    return Name(PrimExpr(a), b);                         \
   }
 
-#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)                      \
-  inline Expr Name(const Expr& a, int b) {                              \
-    return Name(a, make_const(a.dtype(), b));                           \
-  }                                                                     \
-  inline Expr Name(int a, const Expr& b) {                              \
-    return Name(make_const(b.dtype(), a), b);                           \
+#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)            \
+  inline PrimExpr Name(const PrimExpr& a, int b) {  \
+    return Name(a, make_const(a.dtype(), b));                 \
+  }                                                           \
+  inline PrimExpr Name(int a, const PrimExpr& b) {  \
+    return Name(make_const(b.dtype(), a), b);                 \
   }
 
 
@@ -798,19 +798,19 @@ inline void DivAmbiguityError(const TA& a) {
 // The second template argument is necessary to make sure the
 // code compiles lazily by the compiler during invocation.
 template<typename TB>
-inline Expr operator/(const Expr& a, const TB& b) {
+inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
   DivAmbiguityError(a);
   return a;
 }
 
 template<typename TB>
-inline Expr operator/=(const Expr& a, const TB& b) {
+inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
   DivAmbiguityError(a);
   return a;
 }
 
 template<typename TB>
-inline Expr operator%(const Expr& a, const TB& b) {
+inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
   DivAmbiguityError(a);
   return a;
 }
index 11ce09d..8403948 100644 (file)
@@ -40,7 +40,7 @@ using IntImmNode = tvm::IntImmNode;
 using VarNode = tvm::VarNode;
 
 /*! \brief constant unsigned integer. */
-class UIntImmNode : public ExprNode {
+class UIntImmNode : public PrimExprNode {
  public:
   /*! \brief The constant value content. */
   uint64_t value;
@@ -50,14 +50,14 @@ class UIntImmNode : public ExprNode {
     v->Visit("value", &value);
   }
 
-  TVM_DLL static Expr make(DataType t, uint64_t value);
+  TVM_DLL static PrimExpr make(DataType t, uint64_t value);
 
   static constexpr const char* _type_key = "UIntImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, PrimExprNode);
 };
 
 /*! \brief Floating point constants. */
-class FloatImmNode : public ExprNode {
+class FloatImmNode : public PrimExprNode {
  public:
   /*! \brief The constant value content. */
   double value;
@@ -67,14 +67,14 @@ class FloatImmNode : public ExprNode {
     v->Visit("value", &value);
   }
 
-  TVM_DLL static Expr make(DataType t, double value);
+  TVM_DLL static PrimExpr make(DataType t, double value);
 
   static constexpr const char* _type_key = "FloatImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
 };
 
 /*! \brief String constants, only used in asserts. */
-class StringImmNode : public ExprNode {
+class StringImmNode : public PrimExprNode {
  public:
   /*! \brief The constant value content. */
   std::string value;
@@ -84,30 +84,30 @@ class StringImmNode : public ExprNode {
     v->Visit("value", &value);
   }
 
-  TVM_DLL Expr static make(std::string value);
+  TVM_DLL PrimExpr static make(std::string value);
 
   static constexpr const char* _type_key = "StringImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
 };
 
 /*!
  * \brief Cast value from one data type to another.
  * \note The lanes of value should keep fixed.
  */
-class CastNode : public ExprNode {
+class CastNode : public PrimExprNode {
  public:
   /*! \brief Original data type. */
-  Expr value;
+  PrimExpr value;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
     v->Visit("value", &value);
   }
 
-  TVM_DLL static Expr make(DataType t, Expr v);
+  TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
 
   static constexpr const char* _type_key = "Cast";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
 };
 
 /*!
@@ -115,12 +115,12 @@ class CastNode : public ExprNode {
  * \tparam T The type of the child class.
  */
 template<typename T>
-class BinaryOpNode : public ExprNode {
+class BinaryOpNode : public PrimExprNode {
  public:
   /*! \brief The left operand. */
-  Expr a;
+  PrimExpr a;
   /*! \brief The right operand. */
-  Expr b;
+  PrimExpr b;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &(this->dtype));
@@ -128,7 +128,7 @@ class BinaryOpNode : public ExprNode {
     v->Visit("b", &b);
   }
 
-  static Expr make(Expr a, Expr b) {
+  static PrimExpr make(PrimExpr a, PrimExpr b) {
     CHECK(a.defined()) << "ValueError: a is undefined\n";
     CHECK(b.defined()) << "ValueError: b is undefined\n";
     CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n";
@@ -136,10 +136,10 @@ class BinaryOpNode : public ExprNode {
     node->dtype = a.dtype();
     node->a = std::move(a);
     node->b = std::move(b);
-    return Expr(node);
+    return PrimExpr(node);
   }
 
-  TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
 };
 
 /*! \brief a + b */
@@ -207,12 +207,12 @@ class MaxNode : public BinaryOpNode<MaxNode> {
  * \tparam T The type of the child class.
  */
 template<typename T>
-class CmpOpNode : public ExprNode {
+class CmpOpNode : public PrimExprNode {
  public:
   /*! \brief The left operand. */
-  Expr a;
+  PrimExpr a;
   /*! \brief The right operand. */
-  Expr b;
+  PrimExpr b;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &(this->dtype));
@@ -220,7 +220,7 @@ class CmpOpNode : public ExprNode {
     v->Visit("b", &b);
   }
 
-  static Expr make(Expr a, Expr b) {
+  static PrimExpr make(PrimExpr a, PrimExpr b) {
     CHECK(a.defined()) << "ValueError: a is undefined\n";
     CHECK(b.defined()) << "ValueError: b is undefined\n";
     CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n";
@@ -228,10 +228,10 @@ class CmpOpNode : public ExprNode {
     node->dtype = DataType::Bool(a.dtype().lanes());
     node->a = std::move(a);
     node->b = std::move(b);
-    return Expr(node);
+    return PrimExpr(node);
   }
 
-  TVM_DECLARE_FINAL_OBJECT_INFO(T, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
 };
 
 /*! \brief a == b */
@@ -271,12 +271,12 @@ class GENode : public CmpOpNode<GENode> {
 };
 
 /*! \brief a && b */
-class AndNode : public ExprNode {
+class AndNode : public PrimExprNode {
  public:
   /*! \brief The left operand. */
-  Expr a;
+  PrimExpr a;
   /*! \brief The right operand. */
-  Expr b;
+  PrimExpr b;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &(this->dtype));
@@ -284,19 +284,19 @@ class AndNode : public ExprNode {
     v->Visit("b", &b);
   }
 
-  TVM_DLL static Expr make(Expr a, Expr b);
+  TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
 
   static constexpr const char* _type_key = "And";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
 };
 
 /*! \brief a || b */
-class OrNode : public ExprNode {
+class OrNode : public PrimExprNode {
  public:
   /*! \brief The left operand. */
-  Expr a;
+  PrimExpr a;
   /*! \brief The right operand. */
-  Expr b;
+  PrimExpr b;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
@@ -304,27 +304,27 @@ class OrNode : public ExprNode {
     v->Visit("b", &b);
   }
 
-  TVM_DLL static Expr make(Expr a, Expr b);
+  TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
 
   static constexpr const char* _type_key = "Or";
-  TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
 };
 
 /*! \brief !a */
-class NotNode : public ExprNode {
+class NotNode : public PrimExprNode {
  public:
   /*! \brief The input operand. */
-  Expr a;
+  PrimExpr a;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
     v->Visit("a", &a);
   }
 
-  TVM_DLL static Expr make(Expr a);
+  TVM_DLL static PrimExpr make(PrimExpr a);
 
   static constexpr const char* _type_key = "Not";
-  TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
 };
 
 /*!
@@ -334,14 +334,14 @@ class NotNode : public ExprNode {
  *       Do not use it to guard against out of bound access,
  *       please use if_then_else instead.
  */
-class SelectNode : public ExprNode {
+class SelectNode : public PrimExprNode {
  public:
   /*! \brief The condition */
-  Expr condition;
+  PrimExpr condition;
   /*! \brief value to be returned when condition is true. */
-  Expr true_value;
+  PrimExpr true_value;
   /*! \brief value to be returned when condition is false. */
-  Expr false_value;
+  PrimExpr false_value;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
@@ -350,10 +350,10 @@ class SelectNode : public ExprNode {
     v->Visit("false_value", &false_value);
   }
 
-  TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
+  TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
 
   static constexpr const char* _type_key = "Select";
-  TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
 };
 
 /*!
@@ -371,14 +371,14 @@ class SelectNode : public ExprNode {
  *
  * \endcode
  */
-class LoadNode : public ExprNode {
+class LoadNode : public PrimExprNode {
  public:
   /*! \brief The buffer variable. */
   Var buffer_var;
   /*! \brief The index locations to be loaded. */
-  Expr index;
+  PrimExpr index;
   /*! \brief The predicate to mask which lanes would be loaded. */
-  Expr predicate;
+  PrimExpr predicate;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
@@ -387,10 +387,10 @@ class LoadNode : public ExprNode {
     v->Visit("predicate", &predicate);
   }
 
-  TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate);
+  TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
 
   static constexpr const char* _type_key = "Load";
-  TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
 };
 
 /*!
@@ -402,12 +402,12 @@ class LoadNode : public ExprNode {
  *  - ramp(0, 1, 3) = [0, 1, 2]
  *  - ramp(1, 2, 4) = [1, 3, 5, 7]
  */
-class RampNode : public ExprNode {
+class RampNode : public PrimExprNode {
  public:
   /*! \brief The base value. */
-  Expr base;
+  PrimExpr base;
   /*! \brief The stride of each step. */
-  Expr stride;
+  PrimExpr stride;
   /*! \brief Total number of lanes. */
   int lanes;
 
@@ -418,17 +418,17 @@ class RampNode : public ExprNode {
     v->Visit("lanes", &lanes);
   }
 
-  TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
+  TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
 
   static constexpr const char* _type_key = "Ramp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
 };
 
 /*! \brief Create a vector where all the elements are value. */
-class BroadcastNode : public ExprNode {
+class BroadcastNode : public PrimExprNode {
  public:
   /*! \brief The base value. */
-  Expr value;
+  PrimExpr value;
   /*! \brief The number of lanes. */
   int lanes;
 
@@ -438,23 +438,23 @@ class BroadcastNode : public ExprNode {
     v->Visit("lanes", &lanes);
   }
 
-  TVM_DLL static Expr make(Expr value, int lanes);
+  TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
 
   static constexpr const char* _type_key = "Broadcast";
-  TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
 };
 
 /*!
  * \brief Let binding. Bind var to value then evaluate body.
  */
-class LetNode : public ExprNode {
+class LetNode : public PrimExprNode {
  public:
   /*! \brief The variable. */
   Var var;
   /*! \brief The value to be binded. */
-  Expr value;
+  PrimExpr value;
   /*! \brief The result expression. */
-  Expr body;
+  PrimExpr body;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
@@ -463,10 +463,10 @@ class LetNode : public ExprNode {
     v->Visit("body", &body);
   }
 
-  TVM_DLL static Expr make(Var var, Expr value, Expr body);
+  TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
 
   static constexpr const char* _type_key = "Let";
-  TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
 };
 
 // Call node, represent a function call or a multi-dimensional array load.
@@ -494,7 +494,7 @@ class FunctionRef : public ObjectRef {
 /*!
  * \brief Call node.
  */
-class CallNode : public ExprNode {
+class CallNode : public PrimExprNode {
  public:
   /*! \brief Possible types of calls. */
   enum CallType : int {
@@ -514,7 +514,7 @@ class CallNode : public ExprNode {
   /*! \brief The name of the function/intrinsic. */
   std::string name;
   /*! \brief The arguments. */
-  Array<Expr> args;
+  Array<PrimExpr> args;
   /*! \brief Type of calls. */
   CallType call_type;
   /*! \brief The function to be called. */
@@ -531,9 +531,9 @@ class CallNode : public ExprNode {
     v->Visit("value_index", &value_index);
   }
 
-  TVM_DLL static Expr make(DataType dtype,
+  TVM_DLL static PrimExpr make(DataType dtype,
                            std::string name,
-                           Array<Expr> args,
+                           Array<PrimExpr> args,
                            CallType call_type,
                            FunctionRef func = FunctionRef(),
                            int value_index = 0);
@@ -560,7 +560,7 @@ class CallNode : public ExprNode {
   bool is_vectorizable() const;
 
   static constexpr const char* _type_key = "Call";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
 
   // Build-in intrinsics
   static constexpr const char* reinterpret = "reinterpret";
@@ -585,24 +585,24 @@ class CallNode : public ExprNode {
  *  vec = concat(vectors)
  *  result = (vec[indices[0]], vec[indices[1]] ...)
  */
-class ShuffleNode : public ExprNode {
+class ShuffleNode : public PrimExprNode {
  public:
   /*! \brief the input vectors. */
-  Array<Expr> vectors;
+  Array<PrimExpr> vectors;
   /*! \brief The indices of each element. */
-  Array<Expr> indices;
+  Array<PrimExpr> indices;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("vectors", &vectors);
     v->Visit("indices", &indices);
   }
 
-  TVM_DLL static Expr make(Array<Expr> vectors, Array<Expr> indices);
-  TVM_DLL static Expr make_concat(Array<Expr> vectors);
-  TVM_DLL static Expr make_extract_element(Expr vector, int index);
+  TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices);
+  TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
+  TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
 
   static constexpr const char* _type_key = "Shuffle";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
 };
 
 // Reduce operator
@@ -637,20 +637,20 @@ class CommReducerNode : public Object {
   /*! \brief The right argument of reducer */
   Array<Var> rhs;
   /*! \brief The result of reducer */
-  Array<Expr> result;
+  Array<PrimExpr> result;
   /*!
    * \brief The identity element of reducer, which leaves other
    *  elements unchanged when combined with it, with respect to
    *  the binary operation of this reducer uses.
    */
-  Array<Expr> identity_element;
+  Array<PrimExpr> identity_element;
   /*! \brief Function call operator to combine a and b */
-  Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
+  Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
   /*! \brief construct CommReducer from args, result and identity_element */
   TVM_DLL static CommReducer make(Array<Var> lhs,
                                   Array<Var> rhs,
-                                  Array<Expr> result,
-                                  Array<Expr> identity_element);
+                                  Array<PrimExpr> result,
+                                  Array<PrimExpr> identity_element);
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("lhs", &lhs);
@@ -671,27 +671,27 @@ inline const CommReducerNode* CommReducer::operator->() const {
 }
 
 /*! \brief Reduction operator operator */
-class ReduceNode : public ExprNode {
+class ReduceNode : public PrimExprNode {
  public:
   /*! \brief The commutative combiner */
   CommReducer combiner;
   /*! \brief The source operand */
-  Array<Expr> source;
+  Array<PrimExpr> source;
   /*! \brief The reduction axis */
   Array<IterVar> axis;
   /*!
    * \brief Predicate on the reduction
    *  Only add the body to reduction if condition is true.
    */
-  Expr condition;
+  PrimExpr condition;
   /*! \brief the index of this reduce node */
   int value_index;
 
   /*! \brief construct expr from op and rdom */
-  TVM_DLL static Expr make(CommReducer combiner,
-                           Array<Expr> src,
+  TVM_DLL static PrimExpr make(CommReducer combiner,
+                           Array<PrimExpr> src,
                            Array<IterVar> rdom,
-                           Expr condition,
+                           PrimExpr condition,
                            int value_index);
 
   void VisitAttrs(AttrVisitor* v) {
@@ -704,11 +704,11 @@ class ReduceNode : public ExprNode {
   }
 
   static constexpr const char* _type_key = "Reduce";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
 };
 
 /*! \brief Any shape. */
-class AnyNode : public ExprNode {
+class AnyNode : public PrimExprNode {
  public:
   void VisitAttrs(AttrVisitor* v) {}
   /*! \brief Convert to var. */
@@ -716,10 +716,10 @@ class AnyNode : public ExprNode {
     return VarNode::make(DataType::Int(32), "any_dim");
   }
 
-  TVM_DLL static Expr make();
+  TVM_DLL static PrimExpr make();
 
   static constexpr const char* _type_key = "Any";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
 };
 
 // Statements
@@ -731,7 +731,7 @@ class LetStmtNode : public StmtNode {
   /*! \brief The variable. */
   Var var;
   /*! \brief The value to be binded. */
-  Expr value;
+  PrimExpr value;
   /*! \brief The body block. */
   Stmt body;
 
@@ -741,7 +741,7 @@ class LetStmtNode : public StmtNode {
     v->Visit("body", &body);
   }
 
-  TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
+  TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
 
   static constexpr const char* _type_key = "LetStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
@@ -764,7 +764,7 @@ class AttrStmtNode : public StmtNode {
   /*! \brief the type key of the attribute */
   std::string attr_key;
   /*! \brief The attribute value, value is well defined at current scope. */
-  Expr value;
+  PrimExpr value;
   /*! \brief The body statement to be executed */
   Stmt body;
 
@@ -777,7 +777,7 @@ class AttrStmtNode : public StmtNode {
 
   TVM_DLL static Stmt make(ObjectRef node,
                            std::string type_key,
-                           Expr value,
+                           PrimExpr value,
                            Stmt body);
 
   static constexpr const char* _type_key = "AttrStmt";
@@ -790,9 +790,9 @@ class AttrStmtNode : public StmtNode {
 class AssertStmtNode : public StmtNode {
  public:
   /*! \brief Condition to be checked. */
-  Expr condition;
+  PrimExpr condition;
   /*! \brief Error message when assertion failed. */
-  Expr message;
+  PrimExpr message;
   /*!
    * \brief Body which this assertion holds true.
    *  Will be executed after the assertion.
@@ -805,7 +805,7 @@ class AssertStmtNode : public StmtNode {
     v->Visit("body", &body);
   }
 
-  TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
+  TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
 
   static constexpr const char* _type_key = "AssertStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
@@ -857,11 +857,11 @@ class StoreNode : public StmtNode {
   /*! \brief The buffer variable. */
   Var buffer_var;
   /*! \brief The value to be stored. */
-  Expr value;
+  PrimExpr value;
   /*! \brief The index locations to be stored. */
-  Expr index;
+  PrimExpr index;
   /*! \brief The predicate to mask which lanes would be stored. */
-  Expr predicate;
+  PrimExpr predicate;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("buffer_var", &buffer_var);
@@ -871,9 +871,9 @@ class StoreNode : public StmtNode {
   }
 
   TVM_DLL static Stmt make(Var buffer_var,
-                           Expr value,
-                           Expr index,
-                           Expr predicate);
+                           PrimExpr value,
+                           PrimExpr index,
+                           PrimExpr predicate);
 
   static constexpr const char* _type_key = "Store";
   TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
@@ -889,9 +889,9 @@ class ProvideNode : public StmtNode {
   /*! \brief The output value index if func's value is a tuple. */
   int value_index{0};
   /*! \brief The value to be stored. */
-  Expr value;
+  PrimExpr value;
   /*! \brief The index arguments of the function. */
-  Array<Expr> args;
+  Array<PrimExpr> args;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("func", &func);
@@ -902,8 +902,8 @@ class ProvideNode : public StmtNode {
 
   TVM_DLL static Stmt make(FunctionRef func,
                            int value_index,
-                           Expr value,
-                           Array<Expr> args);
+                           PrimExpr value,
+                           Array<PrimExpr> args);
 
   static constexpr const char* _type_key = "Provide";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
@@ -919,14 +919,14 @@ class AllocateNode : public StmtNode {
   /*! \brief The type of the buffer. */
   DataType dtype;
   /*! \brief The extents of the buffer. */
-  Array<Expr> extents;
+  Array<PrimExpr> extents;
   /*! \brief Only allocate buffer when condition is satisfied. */
-  Expr condition;
+  PrimExpr condition;
   /*! \brief The body to be executed. */
   Stmt body;
   // The following two fields are deprecated
   // kept for backward compatibility and will be refactored later.
-  Expr new_expr;
+  PrimExpr new_expr;
   std::string free_function;
 
   void VisitAttrs(AttrVisitor* v) {
@@ -939,10 +939,10 @@ class AllocateNode : public StmtNode {
 
   TVM_DLL static Stmt make(Var buffer_var,
                            DataType dtype,
-                           Array<Expr> extents,
-                           Expr condition,
+                           Array<PrimExpr> extents,
+                           PrimExpr condition,
                            Stmt body,
-                           Expr new_expr = Expr(),
+                           PrimExpr new_expr = PrimExpr(),
                            std::string free_function = std::string());
 
   /*!
@@ -960,7 +960,7 @@ class AllocateNode : public StmtNode {
    * \return The result.
    */
   TVM_DLL static int32_t constant_allocation_size(
-      const Array<Expr>& extents);
+      const Array<PrimExpr>& extents);
 
   static constexpr const char* _type_key = "Allocate";
   TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
@@ -997,7 +997,7 @@ class RealizeNode : public StmtNode {
   /*! \brief Bounds to be realized. */
   Region bounds;
   /*! \brief Only realize if condition holds. */
-  Expr condition;
+  PrimExpr condition;
   /*! \brief The body of realization. */
   Stmt body;
 
@@ -1014,7 +1014,7 @@ class RealizeNode : public StmtNode {
                            int value_index,
                            DataType dtype,
                            Region bounds,
-                           Expr condition,
+                           PrimExpr condition,
                            Stmt body);
 
   static constexpr const char* _type_key = "Realize";
@@ -1136,7 +1136,7 @@ class SeqStmt : public Stmt {
 class IfThenElseNode : public StmtNode {
  public:
   /*! \brief The condition. */
-  Expr condition;
+  PrimExpr condition;
   /*! \brief The branch to be executed when condition is true. */
   Stmt then_case;
   /*! \brief The branch to be executed when condition is false, can be null. */
@@ -1148,7 +1148,7 @@ class IfThenElseNode : public StmtNode {
     v->Visit("else_case", &else_case);
   }
 
-  TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
+  TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
 
   static constexpr const char* _type_key = "IfThenElse";
   TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
@@ -1163,13 +1163,13 @@ class IfThenElseNode : public StmtNode {
 class EvaluateNode : public StmtNode {
  public:
   /*! \brief The expression to be evaluated. */
-  Expr value;
+  PrimExpr value;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("value", &value);
   }
 
-  TVM_DLL static Stmt make(Expr v);
+  TVM_DLL static Stmt make(PrimExpr v);
 
   static constexpr const char* _type_key = "Evaluate";
   TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
@@ -1209,9 +1209,9 @@ class ForNode : public StmtNode {
   /*! \brief The loop variable. */
   Var loop_var;
   /*! \brief The minimum value of iteration. */
-  Expr min;
+  PrimExpr min;
   /*! \brief The extent of the iteration. */
-  Expr extent;
+  PrimExpr extent;
   /*! \brief The type of the for loop. */
   ForType for_type;
   /*!
@@ -1223,8 +1223,8 @@ class ForNode : public StmtNode {
   Stmt body;
 
   TVM_DLL static Stmt make(Var loop_var,
-                           Expr min,
-                           Expr extent,
+                           PrimExpr min,
+                           PrimExpr extent,
                            ForType for_type,
                            DeviceAPI device_api,
                            Stmt body);
@@ -1707,7 +1707,7 @@ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
  * \param dtype The data type
  * \return Expr a expression with dtype.
  */
-inline Expr TypeAnnotation(DataType dtype) {
+inline PrimExpr TypeAnnotation(DataType dtype) {
   return ir::CallNode::make(dtype,
                         "type_annotation", {},
                         ir::CallNode::PureIntrinsic);
index d70c8de..7d57564 100644 (file)
@@ -102,9 +102,9 @@ class StmtFunctor;
       });                                                               \
 
 template<typename R, typename ...Args>
-class ExprFunctor<R(const Expr& n, Args...)> {
+class ExprFunctor<R(const PrimExpr& n, Args...)> {
  private:
-  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
+  using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
   using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
 
  public:
@@ -118,7 +118,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
    * \param args Additional arguments.
    * \return The result of the call
    */
-  R operator()(const Expr& n, Args... args) {
+  R operator()(const PrimExpr& n, Args... args) {
     return VisitExpr(n, std::forward<Args>(args)...);
   }
   /*!
@@ -127,7 +127,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
    * \param args Additional arguments.
    * \return The result of the call
    */
-  virtual R VisitExpr(const Expr& n, Args... args) {
+  virtual R VisitExpr(const PrimExpr& n, Args... args) {
     static FType vtable = InitVTable();
     return vtable(n, this, std::forward<Args>(args)...);
   }
@@ -291,7 +291,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
  * \brief ExprVisitor
  */
 class TVM_DLL ExprVisitor :
-      public ExprFunctor<void(const Expr&)> {
+      public ExprFunctor<void(const PrimExpr&)> {
  public:
   using ExprFunctor::operator();
 
@@ -336,45 +336,45 @@ class TVM_DLL ExprVisitor :
  * \brief ExprMutator that mutates expressions.
  */
 class TVM_DLL ExprMutator :
-      protected ExprFunctor<Expr(const Expr&)> {
+      protected ExprFunctor<PrimExpr(const PrimExpr&)> {
  public:
   using ExprFunctor::operator();
 
  protected:
   using ExprFunctor::VisitExpr;
   // list of functions to override.
-  Expr VisitExpr_(const VarNode* op) override;
-  Expr VisitExpr_(const LoadNode* op) override;
-  Expr VisitExpr_(const LetNode* op) override;
-  Expr VisitExpr_(const CallNode* op) override;
-  Expr VisitExpr_(const AddNode* op) override;
-  Expr VisitExpr_(const SubNode* op) override;
-  Expr VisitExpr_(const MulNode* op) override;
-  Expr VisitExpr_(const DivNode* op) override;
-  Expr VisitExpr_(const ModNode* op) override;
-  Expr VisitExpr_(const FloorDivNode* op) override;
-  Expr VisitExpr_(const FloorModNode* op) override;
-  Expr VisitExpr_(const MinNode* op) override;
-  Expr VisitExpr_(const MaxNode* op) override;
-  Expr VisitExpr_(const EQNode* op) override;
-  Expr VisitExpr_(const NENode* op) override;
-  Expr VisitExpr_(const LTNode* op) override;
-  Expr VisitExpr_(const LENode* op) override;
-  Expr VisitExpr_(const GTNode* op) override;
-  Expr VisitExpr_(const GENode* op) override;
-  Expr VisitExpr_(const AndNode* op) override;
-  Expr VisitExpr_(const OrNode* op) override;
-  Expr VisitExpr_(const ReduceNode* op) override;
-  Expr VisitExpr_(const CastNode* op) override;
-  Expr VisitExpr_(const NotNode* op) override;
-  Expr VisitExpr_(const SelectNode* op) override;
-  Expr VisitExpr_(const RampNode* op) override;
-  Expr VisitExpr_(const BroadcastNode* op) override;
-  Expr VisitExpr_(const ShuffleNode* op) override;
-  Expr VisitExpr_(const IntImmNode* op) override;
-  Expr VisitExpr_(const UIntImmNode* op) override;
-  Expr VisitExpr_(const FloatImmNode* op) override;
-  Expr VisitExpr_(const StringImmNode* op) override;
+  PrimExpr VisitExpr_(const VarNode* op) override;
+  PrimExpr VisitExpr_(const LoadNode* op) override;
+  PrimExpr VisitExpr_(const LetNode* op) override;
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  PrimExpr VisitExpr_(const AddNode* op) override;
+  PrimExpr VisitExpr_(const SubNode* op) override;
+  PrimExpr VisitExpr_(const MulNode* op) override;
+  PrimExpr VisitExpr_(const DivNode* op) override;
+  PrimExpr VisitExpr_(const ModNode* op) override;
+  PrimExpr VisitExpr_(const FloorDivNode* op) override;
+  PrimExpr VisitExpr_(const FloorModNode* op) override;
+  PrimExpr VisitExpr_(const MinNode* op) override;
+  PrimExpr VisitExpr_(const MaxNode* op) override;
+  PrimExpr VisitExpr_(const EQNode* op) override;
+  PrimExpr VisitExpr_(const NENode* op) override;
+  PrimExpr VisitExpr_(const LTNode* op) override;
+  PrimExpr VisitExpr_(const LENode* op) override;
+  PrimExpr VisitExpr_(const GTNode* op) override;
+  PrimExpr VisitExpr_(const GENode* op) override;
+  PrimExpr VisitExpr_(const AndNode* op) override;
+  PrimExpr VisitExpr_(const OrNode* op) override;
+  PrimExpr VisitExpr_(const ReduceNode* op) override;
+  PrimExpr VisitExpr_(const CastNode* op) override;
+  PrimExpr VisitExpr_(const NotNode* op) override;
+  PrimExpr VisitExpr_(const SelectNode* op) override;
+  PrimExpr VisitExpr_(const RampNode* op) override;
+  PrimExpr VisitExpr_(const BroadcastNode* op) override;
+  PrimExpr VisitExpr_(const ShuffleNode* op) override;
+  PrimExpr VisitExpr_(const IntImmNode* op) override;
+  PrimExpr VisitExpr_(const UIntImmNode* op) override;
+  PrimExpr VisitExpr_(const FloatImmNode* op) override;
+  PrimExpr VisitExpr_(const StringImmNode* op) override;
 };
 
 /*!
@@ -394,7 +394,7 @@ class TVM_DLL StmtVisitor :
    *       or have a class sub-class both StmtVisitor and ExprVisitor
    *       and redirect Visit to ExprMutator::VisitExpr(Expr)
    */
-  virtual void VisitExpr(const Expr& e) {}
+  virtual void VisitExpr(const PrimExpr& e) {}
   // statement visitor
   void VisitStmt_(const AttrStmtNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
@@ -486,7 +486,7 @@ class TVM_DLL StmtMutator :
    *       or have a class sub-class both StmtMutator and ExprMutator
    *       and redirect Mutate to ExprMutator::Mutate(Expr)
    */
-  virtual Expr VisitExpr(const Expr& e) {
+  virtual PrimExpr VisitExpr(const PrimExpr& e) {
     return e;
   }
   // statement visitor
@@ -537,7 +537,7 @@ class StmtExprVisitor :
   using StmtVisitor::VisitStmt;
   using ExprVisitor::VisitExpr;
 
-  void VisitExpr(const Expr& e) override {
+  void VisitExpr(const PrimExpr& e) override {
     return ExprVisitor::VisitExpr(e);
   }
 };
@@ -556,7 +556,7 @@ class StmtExprMutator :
   using StmtMutator::VisitExpr;
   using ExprMutator::VisitExpr;
 
-  Expr VisitExpr(const Expr& e) override {
+  PrimExpr VisitExpr(const PrimExpr& e) override {
     return ExprMutator::VisitExpr(e);
   }
 };
@@ -579,7 +579,7 @@ class StmtExprMutator :
 TVM_DLL Stmt IRTransform(Stmt node,
                          const runtime::PackedFunc& preorder,
                          const runtime::PackedFunc& postorder,
-                         const Array<Expr>& only_enable = {});
+                         const Array<PrimExpr>& only_enable = {});
 
 /*!
  * \brief recursively visit the ir in post DFS order node, apply fvisit
index aa1415e..36ca03f 100644 (file)
@@ -45,7 +45,7 @@ namespace ir {
  * \param vrange The range information about the variable.
  * \return Canonicalized statement.
  */
-TVM_DLL Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
+TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());
 
 /*!
  * \brief Simplify the statement.
@@ -70,7 +70,7 @@ Stmt CanonicalSimplify(Stmt stmt,
  * \param vrange The range information about the variable.
  * \return Canonicalized expression.
  */
-TVM_DLL Expr CanonicalSimplify(Expr expr,
+TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
                               Map<Var, Range> vrange = Map<Var, Range>());
 
 /*!
@@ -79,7 +79,7 @@ TVM_DLL Expr CanonicalSimplify(Expr expr,
  * \param rhs The right operand
  * \return The comparison result.
  */
-TVM_DLL bool Equal(const Expr& lhs, const Expr& rhs);
+TVM_DLL bool Equal(const PrimExpr& lhs, const PrimExpr& rhs);
 
 /*!
  * \brief Deep compare lhs and rhs
@@ -100,7 +100,7 @@ bool Equal(const Stmt& lhs, const Stmt& rhs);
  * \param rhs The right operand
  * \return The comparison result.
  */
-int Compare(const Expr& lhs, const Expr& rhs);
+int Compare(const PrimExpr& lhs, const PrimExpr& rhs);
 
 /*!
  * \brief verifies whether the IR stmt or Expr is in SSA form.
@@ -116,7 +116,7 @@ TVM_DLL bool VerifySSA(const Stmt& ir);
  * \brief Whether the expression have side effect.
  * \return whether expression have side effect
  */
-TVM_DLL bool HasSideEffect(const Expr& e);
+TVM_DLL bool HasSideEffect(const PrimExpr& e);
 
 /*!
  * \brief Whether e expression used var.
@@ -124,7 +124,7 @@ TVM_DLL bool HasSideEffect(const Expr& e);
  * \param v The variable.
  * \return Whether e uses v.
  */
-bool ExprUseVar(const Expr& e, const Var& v);
+bool ExprUseVar(const PrimExpr& e, const Var& v);
 
 /*!
  * \brief Whether e expression used any var in variable set..
@@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v);
  * \param vset The variable set.
  * \return Whether e uses vset.
  */
-bool ExprUseVar(const Expr& e, const std::unordered_set<const VarNode*>& vset);
+bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vset);
 
 /*!
  * \brief Convert a IR node to be SSA form.
@@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt);
  * \return The converted form.
  */
 Stmt Substitute(Stmt stmt,
-                const std::unordered_map<const VarNode*, Expr>& value_map);
+                const std::unordered_map<const VarNode*, PrimExpr>& value_map);
 
 /*!
  * \brief Substitute the var specified in key->var to be value.
@@ -156,8 +156,8 @@ Stmt Substitute(Stmt stmt,
  * \param value_map The map of new values.
  * \return The converted expression.
  */
-Expr Substitute(Expr expr,
-                const std::unordered_map<const VarNode*, Expr>& value_map);
+PrimExpr Substitute(PrimExpr expr,
+                const std::unordered_map<const VarNode*, PrimExpr>& value_map);
 
 /*!
  * \brief Substitute the var specified in key->var to be value.
@@ -165,7 +165,7 @@ Expr Substitute(Expr expr,
  * \param value_map The map of new values.
  * \return The converted form.
  */
-Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
+Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
 
 /*!
  * \brief Substitute the var specified in key->var to be value.
@@ -173,7 +173,7 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
  * \param value_map The map of new values.
  * \return The converted expression.
  */
-Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
+PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
 
 /*!
  * \brief inline all calls of f in stmt.
@@ -189,7 +189,7 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
 Stmt Inline(Stmt stmt,
             FunctionRef f,
             Array<Var> args,
-            Expr body);
+            PrimExpr body);
 
 /*!
  * \brief Flatten the multi-dimensional read/write
@@ -485,7 +485,7 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
  * \param axis_map The map from StringImm -> ItrVar
  * \return Transformed function.
  */
-LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map);
+LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
 
 /*!
  * \brief Lower packed function call.
@@ -590,7 +590,7 @@ bool VerifyMemory(LoweredFunc func, int device_type);
  *
  */
 bool VerifyGPUCode(Stmt stmt,
-                   Map<std::string, Expr> constraints);
+                   Map<std::string, PrimExpr> constraints);
 
 
 }  // namespace ir
index 3de6bfd..310b454 100644 (file)
@@ -98,7 +98,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode {
    * \note Expr is used instead Type, because Type cannot be hold by Map.
    *  constant Expr of given type is used.
    */
-  Map<Var, Expr> handle_data_type;
+  Map<Var, PrimExpr> handle_data_type;
   /*! \brief The type of the function */
   LoweredFuncType func_type{kMixedFunc};
   /*! \brief Whether this function is packed function */
index ad8f825..039e26e 100644 (file)
@@ -81,7 +81,7 @@ class OperationNode : public ir::FunctionBaseNode {
    * \param i The output index.
    * \return shape of i-th output.
    */
-  virtual Array<Expr> output_shape(size_t i) const = 0;
+  virtual Array<PrimExpr> output_shape(size_t i) const = 0;
   /*!
    * \brief List all the input Tensors.
    * \return List of input tensors.
@@ -158,14 +158,14 @@ class OperationNode : public ir::FunctionBaseNode {
 class PlaceholderOpNode : public OperationNode {
  public:
   /*! \brief The shape of the input */
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   /*! \brief The data type of the input. */
   DataType dtype;
   // override behavior.
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
   DataType output_dtype(size_t i) const final;
-  Array<Expr> output_shape(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
   Array<Tensor> InputTensors() const final;
   Operation ReplaceInputs(
       const Operation& self,
@@ -196,7 +196,7 @@ class PlaceholderOpNode : public OperationNode {
     v->Visit("dtype", &dtype);
   }
   static Operation make(std::string name,
-                        Array<Expr> shape,
+                        Array<PrimExpr> shape,
                         DataType dtype);
 
   static constexpr const char* _type_key = "PlaceholderOp";
@@ -216,7 +216,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
   Array<IterVar> reduce_axis;
   // override functions
   Array<IterVar> root_iter_vars() const final;
-  Array<Expr> output_shape(size_t idx) const final;
+  Array<PrimExpr> output_shape(size_t idx) const final;
   void GatherBound(
           const Operation& self,
           const std::unordered_map<Tensor, TensorDom>& tensor_dom,
@@ -238,7 +238,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
 class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
  public:
   /*! \brief the compute expression */
-  Array<Expr> body;
+  Array<PrimExpr> body;
   /*! \brief constructor */
   ComputeOpNode() {}
   // override functions
@@ -271,7 +271,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
                         std::string tag,
                         Map<std::string, ObjectRef> attrs,
                         Array<IterVar> axis,
-                        Array<Expr> body);
+                        Array<PrimExpr> body);
 
   static constexpr const char* _type_key = "ComputeOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
@@ -291,7 +291,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
   /*! \brief region of input tensors */
   Array<Region> input_regions;
   /*! \brief scalar expression inputs */
-  Array<Expr> scalar_inputs;
+  Array<PrimExpr> scalar_inputs;
   /*! \brief constructor */
   TensorComputeOpNode() {}
   // override functions
@@ -331,7 +331,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
                         TensorIntrin intrin,
                         Array<Tensor> tensors,
                         Array<Region> regions,
-                        Array<Expr> scalar_inputs);
+                        Array<PrimExpr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorComputeOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
@@ -371,7 +371,7 @@ class ScanOpNode : public OperationNode {
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
   DataType output_dtype(size_t i) const final;
-  Array<Expr> output_shape(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
   Array<Tensor> InputTensors() const final;
   Operation ReplaceInputs(
       const Operation& self,
@@ -438,7 +438,7 @@ class ExternOpNode : public OperationNode {
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
   DataType output_dtype(size_t i) const final;
-  Array<Expr> output_shape(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
   Array<Tensor> InputTensors() const final;
   Operation ReplaceInputs(
       const Operation& self,
@@ -506,7 +506,7 @@ class HybridOpNode : public OperationNode {
   int num_outputs() const final;
   Array<IterVar> root_iter_vars() const final;
   DataType output_dtype(size_t i) const final;
-  Array<Expr> output_shape(size_t i) const final;
+  Array<PrimExpr> output_shape(size_t i) const final;
   Array<Tensor> InputTensors() const final;
   Operation ReplaceInputs(
       const Operation& self,
@@ -550,10 +550,10 @@ class HybridOpNode : public OperationNode {
 };
 
 /*! \brief The compute function to specify the input source of a Tensor */
-using FCompute = std::function<Expr (const Array<Var>& i)>;
+using FCompute = std::function<PrimExpr (const Array<Var>& i)>;
 
 /*! \brief The compute function to specify the inputs source of Tensors */
-using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
+using FBatchCompute = std::function<Array<PrimExpr> (const Array<Var>& i)>;
 
 /*!
  * \brief create a place holder tensor.
@@ -561,7 +561,7 @@ using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
  * \param dtype the data type of the tensor.
  * \param name The name of the Tensor.
  */
-TVM_DLL Tensor placeholder(Array<Expr> shape,
+TVM_DLL Tensor placeholder(Array<PrimExpr> shape,
                            DataType dtype = DataType::Float(32),
                            std::string name = "placeholder");
 
@@ -574,7 +574,7 @@ TVM_DLL Tensor placeholder(Array<Expr> shape,
  * \param tag The optional tag of the tensor.
  * \param attrs Optional additional attributes of the compute.
  */
-TVM_DLL Tensor compute(Array<Expr> shape,
+TVM_DLL Tensor compute(Array<PrimExpr> shape,
                        FCompute fcompute,
                        std::string name = "tensor",
                        std::string tag = "",
@@ -589,7 +589,7 @@ TVM_DLL Tensor compute(Array<Expr> shape,
  * \param tag The optional tag of the tensor.
  * \param attrs Optional additional attributes of the compute.
  */
-TVM_DLL Array<Tensor> compute(Array<Expr> shape,
+TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape,
                               FBatchCompute fcompute,
                               std::string name = "tensor",
                               std::string tag = "",
@@ -616,32 +616,32 @@ TVM_DLL Array<Tensor> scan(Array<Tensor> init,
                            Map<std::string, ObjectRef> attrs = {});
 
 // same as compute, specialized for different fcompute function
-inline Tensor compute(Array<Expr> shape,
-                      std::function<Expr(Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var)> f,
                       std::string name = "tensor",
                       std::string tag = "",
                       Map<std::string, ObjectRef> attrs = {}) {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0]); };
   return compute(shape, fc, name, tag, attrs);
 }
-inline Tensor compute(Array<Expr> shape,
-                      std::function<Expr(Var, Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var, Var)> f,
                       std::string name = "tensor",
                       std::string tag = "",
                       Map<std::string, ObjectRef> attrs = {}) {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1]); };
   return compute(shape, fc, name, tag, attrs);
 }
-inline Tensor compute(Array<Expr> shape,
-                      std::function<Expr(Var, Var, Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var, Var, Var)> f,
                       std::string name = "tensor",
                       std::string tag = "",
                       Map<std::string, ObjectRef> attrs = {}) {
   FCompute fc = [f] (const Array<Var>& i) { return f(i[0], i[1], i[2]); };
   return  compute(shape, fc, name, tag, attrs);
 }
-inline Tensor compute(Array<Expr> shape,
-                      std::function<Expr(Var, Var, Var, Var)> f,
+inline Tensor compute(Array<PrimExpr> shape,
+                      std::function<PrimExpr(Var, Var, Var, Var)> f,
                       std::string name = "tensor",
                       std::string tag = "",
                       Map<std::string, ObjectRef> attrs = {}) {
index b301a18..fa532ea 100644 (file)
@@ -100,15 +100,15 @@ struct ObjectTypeChecker<Map<K, V> > {
 };
 
 // extensions for tvm arg value
-inline TVMPODValue_::operator tvm::Expr() const {
-  if (type_code_ == kNull) return Expr();
+inline TVMPODValue_::operator tvm::PrimExpr() const {
+  if (type_code_ == kNull) return PrimExpr();
   if (type_code_ == kDLInt) {
     CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
     CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
-    return Expr(static_cast<int>(value_.v_int64));
+    return PrimExpr(static_cast<int>(value_.v_int64));
   }
   if (type_code_ == kDLFloat) {
-    return Expr(static_cast<float>(value_.v_float64));
+    return PrimExpr(static_cast<float>(value_.v_float64));
   }
 
   TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
@@ -120,10 +120,10 @@ inline TVMPODValue_::operator tvm::Expr() const {
   if (ptr->IsInstance<TensorNode>()) {
     return Tensor(ObjectPtr<Object>(ptr))();
   }
-  CHECK(ObjectTypeChecker<Expr>::Check(ptr))
-      << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
+  CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr))
+      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
       << " but get " << ptr->GetTypeKey();
-  return Expr(ObjectPtr<Object>(ptr));
+  return PrimExpr(ObjectPtr<Object>(ptr));
 }
 
 inline TVMPODValue_::operator tvm::Integer() const {
@@ -136,7 +136,7 @@ inline TVMPODValue_::operator tvm::Integer() const {
   TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
   Object* ptr = static_cast<Object*>(value_.v_handle);
   CHECK(ObjectTypeChecker<Integer>::Check(ptr))
-      << "Expect type " << ObjectTypeChecker<Expr>::TypeName()
+      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
       << " but get " << ptr->GetTypeKey();
   return Integer(ObjectPtr<Object>(ptr));
 }
index 1c7fc1c..26637d5 100644 (file)
@@ -26,6 +26,7 @@
 
 #include <tvm/attrs.h>
 #include <tvm/relay/base.h>
+#include <tvm/relay/expr.h>
 #include <string>
 
 namespace tvm {
index b4164fb..f2db652 100644 (file)
@@ -57,7 +57,7 @@ namespace relay {
 /*!
  * \brief Symbolic expression for tensor shape.
  */
-using IndexExpr = ::tvm::Expr;
+using IndexExpr = ::tvm::PrimExpr;
 
 using SourceName = tvm::SourceName;
 using Span = tvm::Span;
index 2d1e45f..1b6155f 100644 (file)
@@ -95,9 +95,9 @@ class PassContextNode : public RelayNode {
   int fallback_device{static_cast<int>(kDLCPU)};
 
   /*! \brief The list of required passes. */
-  tvm::Array<tvm::Expr> required_pass;
+  tvm::Array<tvm::PrimExpr> required_pass;
   /*! \brief The list of disabled passes. */
-  tvm::Array<tvm::Expr> disabled_pass;
+  tvm::Array<tvm::PrimExpr> disabled_pass;
 
   PassContextNode() = default;
 
@@ -192,7 +192,7 @@ class PassInfoNode : public RelayNode {
   std::string name;
 
   /*! \brief The passes that are required to perform the current pass. */
-  tvm::Array<tvm::Expr> required;
+  tvm::Array<tvm::PrimExpr> required;
 
   PassInfoNode() = default;
 
@@ -204,7 +204,7 @@ class PassInfoNode : public RelayNode {
 
   TVM_DLL static PassInfo make(int opt_level,
                                std::string name,
-                               tvm::Array<tvm::Expr> required);
+                               tvm::Array<tvm::PrimExpr> required);
 
   static constexpr const char* _type_key = "relay.PassInfo";
   TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode);
@@ -332,7 +332,7 @@ Pass CreateModulePass(
     const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<tvm::Expr>& required);
+    const tvm::Array<tvm::PrimExpr>& required);
 
 /*
  * \brief Create a function pass.
@@ -348,7 +348,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
                                 Function(Function, Module, PassContext)>& pass_func,
                                 int opt_level,
                                 const std::string& name,
-                                const tvm::Array<tvm::Expr>& required);
+                                const tvm::Array<tvm::PrimExpr>& required);
 
 /*! \brief Remove expressions which does not effect the program result.
  *
index 19c9434..afaff9c 100644 (file)
@@ -51,7 +51,7 @@
 namespace tvm {
 // forward declarations
 class Integer;
-class Expr;
+class PrimExpr;
 
 namespace runtime {
 
@@ -495,7 +495,7 @@ class TVMPODValue_ {
   template<typename TObjectRef>
   inline TObjectRef AsObjectRef() const;
   // ObjectRef Specializations
-  inline operator tvm::Expr() const;
+  inline operator tvm::PrimExpr() const;
   inline operator tvm::Integer() const;
 
  protected:
@@ -542,7 +542,7 @@ class TVMArgValue : public TVMPODValue_ {
   using TVMPODValue_::operator Module;
   using TVMPODValue_::IsObjectRef;
   using TVMPODValue_::AsObjectRef;
-  using TVMPODValue_::operator tvm::Expr;
+  using TVMPODValue_::operator tvm::PrimExpr;
   using TVMPODValue_::operator tvm::Integer;
 
   // conversion operator.
@@ -629,7 +629,7 @@ class TVMRetValue : public TVMPODValue_ {
   using TVMPODValue_::operator Module;
   using TVMPODValue_::IsObjectRef;
   using TVMPODValue_::AsObjectRef;
-  using TVMPODValue_::operator tvm::Expr;
+  using TVMPODValue_::operator tvm::PrimExpr;
   using TVMPODValue_::operator tvm::Integer;
 
   TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
index 01caf5a..7e13475 100644 (file)
@@ -112,7 +112,7 @@ class Stage : public ObjectRef {
    * \param predicate The condition to be checked.
    * \return reference to self.
    */
-  TVM_DLL Stage& set_store_predicate(Expr predicate);
+  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
   /*!
    * \brief Specify environment threads that launched around the group's scope.
    *  This can only be used in group stage.
@@ -130,7 +130,7 @@ class Stage : public ObjectRef {
    * \param p_inner The result inner domain.
    * \return reference to self.
    */
-  TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
+  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
   /*!
    * \brief Split the iteration with given number of parts.
    *
@@ -140,7 +140,7 @@ class Stage : public ObjectRef {
    * \param p_inner The result inner domain.
    * \return reference to self.
    */
-  TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
+  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
   /*!
    * \brief Fuse the inner outer domain to the target
    * \param outer The outer domain to be fused.
@@ -185,7 +185,7 @@ class Stage : public ObjectRef {
    * \return reference to self.
    */
   TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
-                     Expr x_factor, Expr y_factor,
+                     PrimExpr x_factor, PrimExpr y_factor,
                      IterVar* p_x_outer, IterVar* p_y_outer,
                      IterVar* p_x_inner, IterVar* p_y_inner);
   /*!
@@ -225,7 +225,7 @@ class Stage : public ObjectRef {
    */
   TVM_DLL Stage& pragma(IterVar var,
                        const std::string& pragma_type,
-                       const Expr& pragma_value = Expr());   // NOLINT(*)
+                       const PrimExpr& pragma_value = PrimExpr());   // NOLINT(*)
   /*!
    * \brief Fetch data in advance.
    * \param domain the tensor to be prefetched
@@ -233,7 +233,7 @@ class Stage : public ObjectRef {
    * \param offset the number of iterations be to fetched in advance
    * \return reference to self
    */
-  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
+  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*)
   /*!
    * \brief Set alignment requirement for specific dimension.
    *
@@ -468,7 +468,7 @@ class StageNode : public Object {
    *  Use this when there can be duplicated threads doing the same store.
    * \note Experimental primitive: used by cross thread-reduction.
    */
-  Expr store_predicate;
+  PrimExpr store_predicate;
   /*! \brief The relation bwteen of IterVars */
   Array<IterVarRelation> relations;
   /*! \brief additional attributes about iter var. */
@@ -598,7 +598,7 @@ class IterVarAttrNode : public Object {
   /*! \brief List of tensor to be prefetched in this loop */
   Array<Tensor> prefetch_data;
   /*! \brief The offset used in each prefetch */
-  Array<Expr> prefetch_offset;
+  Array<PrimExpr> prefetch_offset;
   /*!
    * \brief Tensor intrinsic used in tensorization,
    *   when the axis is marked as Tensorized
@@ -611,11 +611,11 @@ class IterVarAttrNode : public Object {
   /*!
    * \brief Additional pragma keys, array of StringImm
    */
-  Array<Expr> pragma_keys;
+  Array<PrimExpr> pragma_keys;
   /*!
    * \brief Additional values of pragma, if any
    */
-  Array<Expr> pragma_values;
+  Array<PrimExpr> pragma_values;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("iter_type", &iter_type);
@@ -653,9 +653,9 @@ class SplitNode : public IterVarRelationNode {
   /*! \brief The inner domain */
   IterVar inner;
   /*! \brief The split factor */
-  Expr factor;
+  PrimExpr factor;
   /*! \brief Number of parts, only factor or nparts can be given */
-  Expr nparts;
+  PrimExpr nparts;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("parent", &parent);
@@ -668,8 +668,8 @@ class SplitNode : public IterVarRelationNode {
   static IterVarRelation make(IterVar parent,
                               IterVar outer,
                               IterVar inner,
-                              Expr factor,
-                              Expr nparts);
+                              PrimExpr factor,
+                              PrimExpr nparts);
 
   static constexpr const char* _type_key = "Split";
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
index 25fb724..f7b36a0 100644 (file)
@@ -45,7 +45,7 @@ struct MemoryInfoNode : public Object {
    * \brief head address of the buffer, if visible to CPU
    *  This address can be None.
    */
-  Expr head_address;
+  PrimExpr head_address;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("unit_bits", &unit_bits);
index d6e93f5..ecadc35 100644 (file)
@@ -76,8 +76,8 @@ class Tensor : public ObjectRef {
    * \return the result expression representing tensor read.
    */
   template<typename... Args>
-  inline Expr operator()(Args&& ...args) const {
-    Array<Expr> indices{std::forward<Args>(args)...};
+  inline PrimExpr operator()(Args&& ...args) const {
+    Array<PrimExpr> indices{std::forward<Args>(args)...};
     return operator()(indices);
   }
   /*!
@@ -85,13 +85,13 @@ class Tensor : public ObjectRef {
    * \param indices the indices.
    * \return the result expression representing tensor read.
    */
-  TVM_DLL Expr operator()(Array<Expr> indices) const;
+  TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
   /*!
    * \brief Take elements from the tensor
    * \param indices the indices.
    * \return the result expression representing tensor read.
    */
-  TVM_DLL Expr operator()(Array<Var> indices) const;
+  TVM_DLL PrimExpr operator()(Array<Var> indices) const;
   /*!
    * \brief data structure to represent a slice that fixes first k coordinates.
    *  This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
@@ -99,15 +99,15 @@ class Tensor : public ObjectRef {
   class Slice {
    public:
     // construct via tensor and indices
-    Slice(const Tensor& tensor, std::vector<Expr> indices)
+    Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
         : tensor_(tensor), indices_(indices) {}
     /*!
      * \brief get i-th slice from the current slice.
      * \param i the index of the coordinate
      * \return the subsequent slice.
      */
-    inline Slice operator[](Expr i) {
-      std::vector<Expr> other = indices_;
+    inline Slice operator[](PrimExpr i) {
+      std::vector<PrimExpr> other = indices_;
       other.emplace_back(i);
       return Slice(tensor_, other);
     }
@@ -116,20 +116,20 @@ class Tensor : public ObjectRef {
      *  This is only valid when all the coordinates are fully specified.
      * \return the corresponding expression of this slice.
      */
-    inline operator Expr() const {
+    inline operator PrimExpr() const {
       return tensor_(indices_);
     }
 
    private:
     const Tensor& tensor_;
-    std::vector<Expr> indices_;
+    std::vector<PrimExpr> indices_;
   };
   /*!
    * \brief get i-th slice from the current Tensor.
    * \param i the index of the coordinate
    * \return the subsequent slice.
    */
-  inline Slice operator[](Expr i) const {
+  inline Slice operator[](PrimExpr i) const {
     return Slice(*this, {i});
   }
   /*! \brief specify container node */
@@ -161,7 +161,7 @@ class Operation : public ir::FunctionRef {
 class TensorNode : public Object {
  public:
   /*! \brief The shape of the tensor */
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   /*! \brief data type in the content of the tensor */
   DataType dtype;
   /*! \brief the source operation, can be None */
@@ -177,7 +177,7 @@ class TensorNode : public Object {
     v->Visit("op", &op);
     v->Visit("value_index", &value_index);
   }
-  TVM_DLL static Tensor make(Array<Expr> shape,
+  TVM_DLL static Tensor make(Array<PrimExpr> shape,
                              DataType dtype,
                              Operation op,
                              int value_index);
@@ -213,21 +213,21 @@ inline bool Tensor::operator!=(const Tensor& other) const {
 
 // macro to turn every operation of slice to expression
 #define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                              \
-  inline Expr operator Op (const Tensor::Slice& a) {                    \
-    return Op a.operator Expr() ;                                       \
+  inline PrimExpr operator Op (const Tensor::Slice& a) {           \
+    return Op a.operator PrimExpr() ;                              \
   }                                                                     \
 
 #define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                             \
   template<typename T>                                                  \
-  inline Expr operator Op (const Tensor::Slice& a, const T& b) {        \
-    return a.operator Expr() Op b;                                      \
+  inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \
+    return a.operator PrimExpr() Op b;                             \
   }                                                                     \
   template<typename T>                                                  \
-  inline Expr operator Op (const T& a, const Tensor::Slice& b) {        \
-    return a Op b.operator Expr();                                      \
-  }                                                                     \
-  inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
-    return a.operator Expr() Op b.operator Expr();                      \
+  inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) {  \
+    return a Op b.operator PrimExpr();                                \
+  }                                                                        \
+  inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
+    return a.operator PrimExpr() Op b.operator PrimExpr();                  \
   }
 
 DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
index f973909..879e206 100644 (file)
@@ -150,7 +150,7 @@ class TensorIntrinCallNode : public Object {
   Array<IterVar> reduce_axis;
 
   /*! \brief scalar expression inputs */
-  Array<Expr> scalar_inputs;
+  Array<PrimExpr> scalar_inputs;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("intrin", &intrin);
@@ -163,7 +163,7 @@ class TensorIntrinCallNode : public Object {
                                Array<Tensor> tensors,
                                Array<Region> regions,
                                Array<IterVar> reduce_axis,
-                               Array<Expr> scalar_inputs);
+                               Array<PrimExpr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorIntrinCall";
   TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
index ef121bc..4d0e347 100644 (file)
@@ -256,7 +256,7 @@ def placeholder(shape, dtype=None, name="placeholder"):
     tensor: Tensor
         The created tensor
     """
-    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
     dtype = float32 if dtype is None else dtype
     return _api_internal._Placeholder(
         shape, dtype, name)
@@ -293,7 +293,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
-    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
     # for python3
     shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
     ndim = len(shape)
@@ -482,8 +482,8 @@ def extern(shape,
         if tag != "":
             raise ValueError("nested tag is not allowed for now")
         tag = _tag.TagScope.get_current().tag
-    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
-    if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
+    shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
+    if shape == () or isinstance(shape[0], (_expr.PrimExpr, _Integral)):
         shape = [shape]
     if in_buffers is not None:
         in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
@@ -518,7 +518,7 @@ def extern(shape,
         for shp, dt in zip(shape, dtype):
             output_placeholders.append(decl_buffer(shp, dt, name))
     body = fcompute(input_placeholders, output_placeholders)
-    if isinstance(body, _expr.Expr):
+    if isinstance(body, _expr.PrimExpr):
         body = _make.Evaluate(body)
 
     op = _api_internal._ExternOp(name, tag, attrs,
@@ -626,7 +626,7 @@ def decl_buffer(shape,
     If user pass a fully generic symbolic array to the strides,
     then the resulting function becomes fully generic.
     """
-    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
+    shape = (shape,) if isinstance(shape, (_expr.PrimExpr, _Integral)) else shape
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
     if offset_factor != 0 and elem_offset is None:
@@ -827,7 +827,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
             result = fcombine(lhs, rhs)
             id_elem = fidentity(*dtypes)
         else:
-            assert isinstance(expr, _expr.Expr)
+            assert isinstance(expr, _expr.PrimExpr)
             size = 1
             dtype = expr.dtype
             lvar = var(code.co_varnames[0], dtype)
index b82da07..28c703d 100644 (file)
@@ -167,7 +167,7 @@ def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None
     tensor: SparsePlaceholderOp
         The created sparse tensor placeholder
     """
-    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
     nonzeros = 0 if nonzeros is None else nonzeros
     dtype = float32 if dtype is None else dtype
     stype = 'csr' if stype is None else stype
index 733f57a..c6b3d9b 100644 (file)
@@ -242,27 +242,27 @@ class NotEqualOp(NodeGeneric, ExprOp):
         return _make._OpNE(self.a, self.b)
 
 
-class Expr(ExprOp, NodeBase):
+class PrimExpr(ExprOp, NodeBase):
     """Base class of all tvm Expressions"""
     # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__
     # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
     __hash__ = NodeBase.__hash__
 
 
-class ConstExpr(Expr):
+class ConstExpr(PrimExpr):
     pass
 
-class BinaryOpExpr(Expr):
+class BinaryOpExpr(PrimExpr):
     pass
 
-class CmpExpr(Expr):
+class CmpExpr(PrimExpr):
     pass
 
-class LogicalExpr(Expr):
+class LogicalExpr(PrimExpr):
     pass
 
 @register_node("Variable")
-class Var(Expr):
+class Var(PrimExpr):
     """Symbolic variable.
 
     Parameters
@@ -279,7 +279,7 @@ class Var(Expr):
 
 
 @register_node
-class Reduce(Expr):
+class Reduce(PrimExpr):
     """Reduce node.
 
     Parameters
@@ -383,7 +383,7 @@ class StringImm(ConstExpr):
 
 
 @register_node
-class Cast(Expr):
+class Cast(PrimExpr):
     """Cast expression.
 
     Parameters
@@ -703,7 +703,7 @@ class Not(LogicalExpr):
 
 
 @register_node
-class Select(Expr):
+class Select(PrimExpr):
     """Select node.
 
     Note
@@ -731,7 +731,7 @@ class Select(Expr):
 
 
 @register_node
-class Load(Expr):
+class Load(PrimExpr):
     """Load node.
 
     Parameters
@@ -754,7 +754,7 @@ class Load(Expr):
 
 
 @register_node
-class Ramp(Expr):
+class Ramp(PrimExpr):
     """Ramp node.
 
     Parameters
@@ -774,7 +774,7 @@ class Ramp(Expr):
 
 
 @register_node
-class Broadcast(Expr):
+class Broadcast(PrimExpr):
     """Broadcast node.
 
     Parameters
@@ -791,7 +791,7 @@ class Broadcast(Expr):
 
 
 @register_node
-class Shuffle(Expr):
+class Shuffle(PrimExpr):
     """Shuffle node.
 
     Parameters
@@ -808,7 +808,7 @@ class Shuffle(Expr):
 
 
 @register_node
-class Call(Expr):
+class Call(PrimExpr):
     """Call node.
 
     Parameters
@@ -843,7 +843,7 @@ class Call(Expr):
 
 
 @register_node
-class Let(Expr):
+class Let(PrimExpr):
     """Let node.
 
     Parameters
index 4148220..1d5612e 100644 (file)
@@ -91,7 +91,7 @@ def _allocate_tensor(func_id, args):
                      "allocate's first argument should be a tuple of shape!")
     shape = args[0]
     for i in shape:
-        _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
+        _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression")
     if n > 1:
         _internal_assert(isinstance(args[1], str),
                          "The data type should be an str")
@@ -125,7 +125,7 @@ def len(func_id, args):
 
 
 def _cast(func_id, args):
-    _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \
+    _internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), \
                      "Only one expression can be cast")
     return _make.Cast(func_id, args[0])
 
@@ -137,8 +137,8 @@ uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name
 def ceil_div(func_id, args):
     _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!")
     _internal_assert(args.__len__() == 2, "2 arguments expected for division!")
-    _internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
-    _internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
+    _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div")
+    _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div")
     a, b = args[0], args[1]
     return (a + b - 1) // b
 
index 7e5659a..06bcbca 100644 (file)
@@ -325,7 +325,7 @@ class HybridParser(ast.NodeVisitor):
 
         _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
         lhs = node.targets[0]
-        if isinstance(rhs, _expr.Expr):
+        if isinstance(rhs, _expr.PrimExpr):
             rhs = _ir_pass.Simplify(rhs)
         if isinstance(lhs, ast.Name):
             #TODO: support defined intermediate buffer later
index 91fd291..6b577c4 100644 (file)
@@ -132,7 +132,7 @@ class Buffer(NodeBase):
         load : Expr
             The corresponding load expression.
         """
-        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
+        begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
         dtype = dtype if dtype else self.dtype
         return _api_internal._BufferVLoad(self, begin, dtype)
 
@@ -152,7 +152,7 @@ class Buffer(NodeBase):
         store : Stmt
             The corresponding store stmt.
         """
-        begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
+        begin = (begin,) if isinstance(begin, (int, _expr.PrimExpr)) else begin
         return _api_internal._BufferVStore(self, begin, value)
 
 
index 1cadf06..e4a2f4f 100644 (file)
@@ -65,7 +65,7 @@ class Tensor(NodeBase, _expr.ExprOp):
         indices = convert_to_node(indices)
         args = []
         for x in indices:
-            if isinstance(x, _expr.Expr):
+            if isinstance(x, _expr.PrimExpr):
                 args.append(x)
             elif isinstance(x, iter_var_cls):
                 args.append(x.var)
index 2ef7a4b..378cfe5 100644 (file)
@@ -133,9 +133,9 @@ def decl_tensor_intrin(op,
     else:
         body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
         scalar_params = []
-    if isinstance(body, (_expr.Expr, _stmt.Stmt)):
+    if isinstance(body, (_expr.PrimExpr, _stmt.Stmt)):
         body = [body]
-    body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
+    body = [_make.Evaluate(x) if isinstance(x, _expr.PrimExpr) else x for x in body]
     if len(body) < 3:
         body += [None] * (3 - len(body))
     return _api_internal._TensorIntrin(
index a69fd5d..7150d27 100644 (file)
@@ -49,7 +49,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound")
 
 TVM_REGISTER_GLOBAL("arith.DeduceBound")
 .set_body_typed([](
-  Expr v, Expr cond,
+  PrimExpr v, PrimExpr cond,
   const Map<Var, IntSet> hint_map,
   const Map<Var, IntSet> relax_map
 ) {
@@ -121,7 +121,7 @@ TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
             if (args[1].IsObjectRef<Range>()) {
               self->Bind(args[0], args[1].operator Range());
             } else {
-              self->Bind(args[0], args[1].operator Expr());
+              self->Bind(args[0], args[1].operator PrimExpr());
             }
         });
       } else if (name == "enter_constraint_context") {
index ba04239..ca4823b 100644 (file)
@@ -71,7 +71,7 @@ TVM_REGISTER_GLOBAL("make.SeqStmt")
 
 TVM_REGISTER_GLOBAL("make.For")
 .set_body_typed([](
-  VarExpr loop_var, Expr min, Expr extent,
+  Var loop_var, PrimExpr min, PrimExpr extent,
   int for_type, int device_api, Stmt body) {
   return ForNode::make(loop_var,
                    min,
@@ -93,7 +93,7 @@ TVM_REGISTER_GLOBAL("make.Load")
 
 TVM_REGISTER_GLOBAL("make.Store")
 .set_body([](TVMArgs args,  TVMRetValue *ret) {
-    Expr value = args[1];
+    PrimExpr value = args[1];
     if (args.size() == 3) {
       *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
     } else {
@@ -107,7 +107,7 @@ TVM_REGISTER_GLOBAL("make.Realize")
 TVM_REGISTER_GLOBAL("make.Call")
 .set_body_typed([](
   DataType type, std::string name,
-  Array<Expr> args, int call_type,
+  Array<PrimExpr> args, int call_type,
   FunctionRef func, int value_index
 ) {
   return CallNode::make(type,
@@ -173,7 +173,7 @@ REGISTER_MAKE(Evaluate);
 // has default args
 TVM_REGISTER_GLOBAL("make.Allocate")
   .set_body_typed([](
-    VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
+    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
   ){
     return AllocateNode::make(buffer_var, type, extents, condition, body);
   });
@@ -181,7 +181,7 @@ TVM_REGISTER_GLOBAL("make.Allocate")
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                     \
   TVM_REGISTER_GLOBAL("make."#Node)                             \
-  .set_body_typed([](Expr a, Expr b) {                          \
+  .set_body_typed([](PrimExpr a, PrimExpr b) {                  \
     return (Func(a, b));                                        \
   })
 
@@ -191,11 +191,11 @@ TVM_REGISTER_GLOBAL("make.Allocate")
     bool lhs_is_int = args[0].type_code() == kDLInt;                    \
     bool rhs_is_int = args[1].type_code() == kDLInt;                    \
     if (lhs_is_int) {                                                   \
-      *ret = (Func(args[0].operator int(), args[1].operator Expr()));   \
+      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
     } else if (rhs_is_int) {                                            \
-      *ret = (Func(args[0].operator Expr(), args[1].operator int()));   \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
     } else {                                                            \
-      *ret = (Func(args[0].operator Expr(), args[1].operator Expr()));  \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
     }                                                                   \
   })
 
@@ -228,7 +228,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
 REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, operator>>);
 TVM_REGISTER_GLOBAL("make._OpIfThenElse")
-.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) {
+.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
   return if_then_else(cond, true_value, false_value);
 });
 
index 4e635ad..6a8bc58 100644 (file)
@@ -289,7 +289,7 @@ TVM_REGISTER_GLOBAL("_TensorHash")
   });
 
 TVM_REGISTER_GLOBAL("_Placeholder")
-.set_body_typed([](Array<Expr> shape, DataType dtype, std::string name) {
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
   return placeholder(shape, dtype, name);
 });
 
@@ -337,14 +337,14 @@ TVM_REGISTER_GLOBAL("_StageBind")
 .set_body_method(&Stage::bind);
 
 TVM_REGISTER_GLOBAL("_StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, Expr factor) {
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
   IterVar outer, inner;
   stage.split(parent, factor, &outer, &inner);
   return Array<IterVar>({outer, inner});
 });
 
 TVM_REGISTER_GLOBAL("_StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, Expr nparts) {
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
   IterVar outer, inner;
   stage.split_by_nparts(parent, nparts, &outer, &inner);
   return Array<IterVar>({outer, inner});
@@ -373,7 +373,7 @@ TVM_REGISTER_GLOBAL("_StageTile")
 .set_body_typed([](
   Stage stage,
   IterVar x_parent, IterVar y_parent,
-  Expr x_factor, Expr y_factor
+  PrimExpr x_factor, PrimExpr y_factor
 ) {
     IterVar x_outer, y_outer, x_inner, y_inner;
     stage.tile(x_parent, y_parent,
index 16c2b1b..ff30f5e 100644 (file)
@@ -43,9 +43,9 @@ TVM_REGISTER_GLOBAL("ir_pass.Simplify")
       }
     } else {
       if (args.size() > 1) {
-        *ret = Simplify(args[0].operator Expr(), args[1]);
+        *ret = Simplify(args[0].operator PrimExpr(), args[1]);
       } else {
-        *ret = Simplify(args[0].operator Expr());
+        *ret = Simplify(args[0].operator PrimExpr());
       }
     }
   });
@@ -60,9 +60,9 @@ TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify")
       }
     } else {
       if (args.size() > 1) {
-        *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
+        *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]);
       } else {
-        *ret = CanonicalSimplify(args[0].operator Expr());
+        *ret = CanonicalSimplify(args[0].operator PrimExpr());
       }
     }
   });
@@ -70,9 +70,9 @@ TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify")
 TVM_REGISTER_GLOBAL("ir_pass.Substitute")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     if (args[0].IsObjectRef<Stmt>()) {
-      *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, Expr>());
+      *ret = Substitute(args[0].operator Stmt(), args[1].operator Map<Var, PrimExpr>());
     } else {
-      *ret = Substitute(args[0].operator Expr(), args[1].operator Map<Var, Expr>());
+      *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map<Var, PrimExpr>());
     }
   });
 
@@ -81,7 +81,7 @@ TVM_REGISTER_GLOBAL("ir_pass.Equal")
     if (args[0].IsObjectRef<Stmt>()) {
       *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt());
     } else {
-      *ret = Equal(args[0].operator Expr(), args[1].operator Expr());
+      *ret = Equal(args[0].operator PrimExpr(), args[1].operator PrimExpr());
     }
   });
 
@@ -114,7 +114,7 @@ TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
 
 TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
+    *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
   });
 
 TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
index de37111..6f01c7a 100644 (file)
@@ -34,7 +34,7 @@ namespace tvm {
 struct TestAttrs : public AttrsNode<TestAttrs> {
   int axis;
   std::string name;
-  Array<Expr> padding;
+  Array<PrimExpr> padding;
   TypedEnvFunc<int(int)> func;
 
   TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
@@ -47,7 +47,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
         .describe("name");
     TVM_ATTR_FIELD(padding)
         .describe("padding of input")
-        .set_default(Array<Expr>({0, 0}));
+        .set_default(Array<PrimExpr>({0, 0}));
     TVM_ATTR_FIELD(func)
         .describe("some random env function")
         .set_default(TypedEnvFunc<int(int)>(nullptr));
index 68e0b05..7a3baa6 100644 (file)
@@ -35,8 +35,8 @@ Analyzer::Analyzer()
       int_set(this) {
 }
 
-void Analyzer::Bind(const VarExpr& var, const Expr& expr) {
-  Expr new_expr = expr;
+void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
+  PrimExpr new_expr = expr;
   new_expr = this->canonical_simplify(new_expr);
   new_expr = this->rewrite_simplify(new_expr);
 
@@ -46,7 +46,7 @@ void Analyzer::Bind(const VarExpr& var, const Expr& expr) {
   this->canonical_simplify.Update(var, new_expr);
 }
 
-void Analyzer::Bind(const VarExpr& var, const Range& range) {
+void Analyzer::Bind(const Var& var, const Range& range) {
   CHECK(range.defined());
   if (is_one(range->extent)) {
     this->Bind(var, range->min);
@@ -77,7 +77,7 @@ void ConstraintContext::ExitWithScope() {
   exit_();
 }
 
-bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
+bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
   if (const auto* ptr = expr.as<ir::IntImmNode>()) {
     return ptr->value >= lower_bound;
   }
@@ -86,7 +86,7 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
   return false;
 }
 
-bool Analyzer::CanProve(const Expr& expr) {
+bool Analyzer::CanProve(const PrimExpr& expr) {
   if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
     return ptr->value != 0;
   }
@@ -101,7 +101,7 @@ bool Analyzer::CanProve(const Expr& expr) {
   return false;
 }
 
-Expr Analyzer::Simplify(const Expr& expr) {
+PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
   if (is_const(expr)) return expr;
   auto res = this->rewrite_simplify(expr);
   if (is_const(res)) return res;
index 40f86de..1ba0293 100644 (file)
@@ -42,9 +42,9 @@ using namespace ir;
 // from a expression.
 class VariablePathFinder: public ExprVisitor {
  public:
-  explicit VariablePathFinder(Expr target) : target_(target) {}
+  explicit VariablePathFinder(PrimExpr target) : target_(target) {}
 
-  void VisitExpr(const Expr& node) final {
+  void VisitExpr(const PrimExpr& node) final {
     if (visited_.count(node.get()) != 0) return;
     visited_.insert(node.get());
 
@@ -58,13 +58,13 @@ class VariablePathFinder: public ExprVisitor {
 
  private:
   bool found_{false};
-  Expr target_;
+  PrimExpr target_;
   std::unordered_set<const Object*> visited_;
 };
 
 // get the path to the variable,
 // return empty vector to represent failure
-std::vector<const Object*> GetPath(Expr target, Expr expr) {
+std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
   VariablePathFinder v(target);
   v(expr);
   return v.path_;
@@ -77,14 +77,14 @@ class BoundDeducer: public ExprVisitor {
  public:
   friend class BoundDeduceInputChecker;
   friend class Converter;
-  BoundDeducer(Expr target, Expr expr,
+  BoundDeducer(PrimExpr target, PrimExpr expr,
                const std::unordered_map<const VarNode*, IntSet>& hint_map,
                const std::unordered_map<const VarNode*, IntSet>& relax_map)
   : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
 
   void Deduce();
 
-  void VisitExpr(const Expr& e) final {
+  void VisitExpr(const PrimExpr& e) final {
     if (!success_) return;
     if (e.get() == path_[iter_++]) {
       ExprVisitor::VisitExpr(e);
@@ -130,8 +130,8 @@ class BoundDeducer: public ExprVisitor {
 
   void VisitExpr_(const MulNode* op) final {
     bool left = op->a.get() == path_[iter_];
-    Expr operand = left ? op->b : op->a;
-    Expr target_var = left ? op->a : op->b;
+    PrimExpr operand = left ? op->b : op->a;
+    PrimExpr target_var = left ? op->a : op->b;
 
     SignType sign_operand;
     if (operand.dtype().is_uint()) {
@@ -176,7 +176,7 @@ class BoundDeducer: public ExprVisitor {
     this->VisitExpr(left ? op->a : op->b);
   }
 
-  Expr result_;
+  PrimExpr result_;
   CompareOp comp_op{kGreater};
   bool success_{true};
 
@@ -185,8 +185,8 @@ class BoundDeducer: public ExprVisitor {
   void Transform();
   void Relax();
   CompareOp ReverseOp(CompareOp comp_op);
-  Expr target_;
-  Expr expr_;
+  PrimExpr target_;
+  PrimExpr expr_;
   const std::unordered_map<const VarNode*, IntSet>& hint_map_;
   const std::unordered_map<const VarNode*, IntSet>& relax_map_;
   ExprIntSetMap expr_map_;
@@ -204,7 +204,7 @@ class BoundDeduceInputChecker: public ExprVisitor {
     return target_count == 1;
   }
 
-  void VisitExpr(const Expr& e) final {
+  void VisitExpr(const PrimExpr& e) final {
     if (e.same_as(deducer_->target_)) ++target_count;
     ExprVisitor::VisitExpr(e);
   }
@@ -329,13 +329,13 @@ void BoundDeducer::Relax() {
   result_ = (comp_op == kGreater) ? b.max() : b.min();
 }
 
-IntSet DeduceBound(Expr v, Expr e,
+IntSet DeduceBound(PrimExpr v, PrimExpr e,
   const std::unordered_map<const VarNode*, IntSet>& hint_map,
   const std::unordered_map<const VarNode*, IntSet>& relax_map) {
   BoundDeducer d(v, e, hint_map, relax_map);
   d.Deduce();
   if (!d.success_) return IntSet::nothing();
-  Expr min = neg_inf(), max = pos_inf();
+  PrimExpr min = neg_inf(), max = pos_inf();
   if (d.comp_op == kEqual) {
     min = d.result_;
     max = d.result_;
@@ -349,7 +349,7 @@ IntSet DeduceBound(Expr v, Expr e,
 
 // assuming e >= 0, deduce the bound of variable from it.
 // return empty set to represent deduce failure.
-IntSet DeduceBound(Expr v, Expr e,
+IntSet DeduceBound(PrimExpr v, PrimExpr e,
                    const Map<Var, IntSet>& hint_map,
                    const Map<Var, IntSet>& relax_map) {
   std::unordered_map<const VarNode*, IntSet> hmap;
index e33b0c5..5f721d7 100644 (file)
@@ -40,7 +40,7 @@ class SplitExpr;
  * \brief Base class of all temporary expression introduced
  *        for canonicalization.
  */
-class CanonicalExprNode : public BaseExprNode {
+class CanonicalExprNode : public PrimExprNode {
  public:
   virtual ~CanonicalExprNode() {}
   /*!
@@ -48,14 +48,14 @@ class CanonicalExprNode : public BaseExprNode {
    * \note Can mutate the internal data structure.
    * \return The normal expression.
    */
-  virtual Expr Normalize() const = 0;
+  virtual PrimExpr Normalize() const = 0;
 
   // overrides
   void VisitAttrs(tvm::AttrVisitor* v) {
   }
 
   static constexpr const char* _type_key = "arith.CanonicalExpr";
-  TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, BaseExprNode);
+  TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
 };
 
 enum DivMode {
@@ -65,7 +65,7 @@ enum DivMode {
   kFloorDiv
 };
 
-inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   if (mode == kTruncDiv) {
     return truncmod(a, b);
   } else {
@@ -74,7 +74,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
   }
 }
 
-inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
   if (mode == kTruncDiv) {
     return truncdiv(a, b);
   } else {
@@ -94,7 +94,7 @@ inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
 class SplitExprNode : public CanonicalExprNode {
  public:
   /*! \brief The base index expression. */
-  Expr index;
+  PrimExpr index;
   /*! \brief The division factor ratio. */
   int64_t lower_factor{1};
   /*!
@@ -112,8 +112,8 @@ class SplitExprNode : public CanonicalExprNode {
     CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0);
   }
 
-  Expr NormalizeWithScale(int64_t sscale) const {
-    Expr res = this->index;
+  PrimExpr NormalizeWithScale(int64_t sscale) const {
+    PrimExpr res = this->index;
     DataType dtype = this->dtype;
     if (this->scale == 0) {
       return make_const(dtype, 0);
@@ -132,7 +132,7 @@ class SplitExprNode : public CanonicalExprNode {
     return res;
   }
 
-  Expr Normalize() const final {
+  PrimExpr Normalize() const final {
     return NormalizeWithScale(1);
   }
 
@@ -149,9 +149,9 @@ class SplitExprNode : public CanonicalExprNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode);
 };
 
-class SplitExpr : public Expr {
+class SplitExpr : public PrimExpr {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, Expr, SplitExprNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
 };
 
@@ -190,7 +190,7 @@ class SumExprNode : public CanonicalExprNode {
    * \brief Return the normal Expr that is equivalent to self.
    * \return The normal expression.
    */
-  Expr Normalize() const final {
+  PrimExpr Normalize() const final {
     // quick path 1.
     if (this->args.size() == 0) {
       return make_const(this->dtype, this->base);
@@ -382,11 +382,11 @@ class SumExprNode : public CanonicalExprNode {
     std::stable_sort(args.begin(), args.end(), fcompare);
     return args;
   }
-  static Expr Normalize_(DataType dtype,
+  static PrimExpr Normalize_(DataType dtype,
                          const std::vector<SplitExpr>& args,
                          int64_t base) {
     // Positive scales first
-    Expr res = make_const(dtype, 0);
+    PrimExpr res = make_const(dtype, 0);
     for (size_t i = 0; i < args.size(); ++i) {
       if (args[i]->scale > 0) {
         res = res + args[i]->Normalize();
@@ -408,9 +408,9 @@ class SumExprNode : public CanonicalExprNode {
   }
 };
 
-class SumExpr : public Expr {
+class SumExpr : public PrimExpr {
  public:
-  TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, Expr, SumExprNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
 };
 
@@ -433,31 +433,31 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
       : Rewriter(parent) {}
 
 
-  Expr CanonicalSimplify(Expr expr) {
+  PrimExpr CanonicalSimplify(PrimExpr expr) {
     expr = operator()(expr);
     return expr;
   }
 
   // override the original mutate function.
-  Expr VisitExpr(const Expr& input_expr) final {
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
     auto expr = Rewriter::VisitExpr(input_expr);
     return Normalize(expr);
   }
 
   // Normal mutation without normalization.
-  Expr CanonicalMutate(Expr expr) {
+  PrimExpr CanonicalMutate(PrimExpr expr) {
     return Rewriter::VisitExpr(expr);
   }
 
   using Rewriter::VisitExpr_;
-  Expr VisitExpr_(const AddNode* op) final;
-  Expr VisitExpr_(const SubNode* op) final;
-  Expr VisitExpr_(const MulNode* op) final;
-  Expr VisitExpr_(const DivNode* op) final;
-  Expr VisitExpr_(const ModNode* op) final;
-  Expr VisitExpr_(const FloorDivNode* op) final;
-  Expr VisitExpr_(const FloorModNode* op) final;
-  Expr VisitExpr_(const ReduceNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const ModNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+  PrimExpr VisitExpr_(const ReduceNode* op) final;
 
  private:
   /*!
@@ -492,7 +492,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
    * \param expr The input expression.
    * \return Normalized expr.
    */
-  Expr Normalize(Expr expr) {
+  PrimExpr Normalize(PrimExpr expr) {
     if (const auto* op = expr.as<CanonicalExprNode>()) {
       return op->Normalize();
     } else {
@@ -504,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
    * \param expr The input expr.
    * \return The transformed SplitExpr.
    */
-  SplitExpr ToSplitExpr(Expr expr) {
+  SplitExpr ToSplitExpr(PrimExpr expr) {
     if (const auto* op = expr.as<SplitExprNode>()) {
       return GetRef<SplitExpr>(op);
     }
@@ -547,7 +547,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
    * \param expr The input expr.
    * \return The transformed SumExpr.
    */
-  SumExpr ToSumExpr(Expr expr) {
+  SumExpr ToSumExpr(PrimExpr expr) {
     if (const auto* op = expr.as<SumExprNode>()) {
       return GetRef<SumExpr>(op);
     }
@@ -562,20 +562,20 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
     }
   }
   // Simplify the combiner used in reduce.
-  Expr SimplifyReduceCombiner(const ReduceNode* op);
+  PrimExpr SimplifyReduceCombiner(const ReduceNode* op);
 };
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const AddNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
   // normalize
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<AddNode>(a, b);
+  PrimExpr const_res = TryConstFold<AddNode>(a, b);
   if (const_res.defined()) return const_res;
 
   // canonical form simplification.
@@ -591,17 +591,17 @@ VisitExpr_(const AddNode* op) {
   return std::move(ret);
 }
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const SubNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
   // normalize
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<SubNode>(a, b);
+  PrimExpr const_res = TryConstFold<SubNode>(a, b);
   if (const_res.defined()) return const_res;
 
   // canonical form simplification.
@@ -618,17 +618,17 @@ VisitExpr_(const SubNode* op) {
 }
 
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const MulNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
   // normalize
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<MulNode>(a, b);
+  PrimExpr const_res = TryConstFold<MulNode>(a, b);
   if (const_res.defined()) return const_res;
 
   // x * c
@@ -651,7 +651,7 @@ VisitExpr_(const MulNode* op) {
   a = Normalize(a);
   b = Normalize(b);
   if (op->a.same_as(a) && op->b.same_as(b)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return MulNode::make(a, b);
   }
@@ -725,17 +725,17 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
   return lhs;
 }
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const DivNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
 
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<DivNode>(a, b);
+  PrimExpr const_res = TryConstFold<DivNode>(a, b);
   if (const_res.defined()) return const_res;
   PVar<Integer> c1;
   // x / c1
@@ -755,7 +755,7 @@ VisitExpr_(const DivNode* op) {
       if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
           analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
         lhs.CopyOnWrite()->DivideBy(cval);
-        Expr temp = Normalize(extra);
+        PrimExpr temp = Normalize(extra);
         if (const auto* pconst = temp.as<IntImmNode>()) {
           lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
         } else {
@@ -780,22 +780,22 @@ VisitExpr_(const DivNode* op) {
   a = Normalize(a);
   b = Normalize(b);
   if (op->a.same_as(a) && op->b.same_as(b)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return DivNode::make(a, b);
   }
 }
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const FloorDivNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<FloorDivNode>(a, b);
+  PrimExpr const_res = TryConstFold<FloorDivNode>(a, b);
   if (const_res.defined()) return const_res;
   PVar<Integer> c1;
   // x / c1
@@ -812,7 +812,7 @@ VisitExpr_(const FloorDivNode* op) {
       }
       // continue simplification.
       lhs.CopyOnWrite()->DivideBy(cval);
-      Expr temp = Normalize(extra);
+      PrimExpr temp = Normalize(extra);
       if (const auto* pconst = temp.as<IntImmNode>()) {
         lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
       } else {
@@ -836,7 +836,7 @@ VisitExpr_(const FloorDivNode* op) {
   a = Normalize(a);
   b = Normalize(b);
   if (op->a.same_as(a) && op->b.same_as(b)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return FloorDivNode::make(a, b);
   }
@@ -892,17 +892,17 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
   return lhs;
 }
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const ModNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
   // normalize
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<ModNode>(a, b);
+  PrimExpr const_res = TryConstFold<ModNode>(a, b);
   if (const_res.defined()) return const_res;
 
   PVar<Integer> c1;
@@ -918,7 +918,7 @@ VisitExpr_(const ModNode* op) {
       // both lhs and extra are non-negative
       if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
           analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
-        Expr temp = Normalize(extra);
+        PrimExpr temp = Normalize(extra);
         if (temp.as<IntImmNode>()) {
           return truncmod(temp, c1.Eval());
         } else {
@@ -956,23 +956,23 @@ VisitExpr_(const ModNode* op) {
   a = Normalize(a);
   b = Normalize(b);
   if (op->a.same_as(a) && op->b.same_as(b)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return ModNode::make(a, b);
   }
 }
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const FloorModNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
   // normalize
-  Expr a = this->CanonicalMutate(op->a);
-  Expr b = this->CanonicalMutate(op->b);
+  PrimExpr a = this->CanonicalMutate(op->a);
+  PrimExpr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<FloorModNode>(a, b);
+  PrimExpr const_res = TryConstFold<FloorModNode>(a, b);
   if (const_res.defined()) return const_res;
 
   PVar<Integer> c1;
@@ -982,7 +982,7 @@ VisitExpr_(const FloorModNode* op) {
     if (const auto* psum = a.as<SumExprNode>()) {
       SumExpr lhs, extra;
       SeparateDivisibleParts(psum, cval, &lhs, &extra);
-      Expr temp = Normalize(extra);
+      PrimExpr temp = Normalize(extra);
       if (temp.as<IntImmNode>()) {
         return floormod(temp, c1.Eval());
       } else {
@@ -1016,19 +1016,19 @@ VisitExpr_(const FloorModNode* op) {
   a = Normalize(a);
   b = Normalize(b);
   if (op->a.same_as(a) && op->b.same_as(b)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return FloorModNode::make(a, b);
   }
 }
 
 // Simplify reduce expression.
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 SimplifyReduceCombiner(const ReduceNode* op) {
   // First simplify the results
-  Array<Expr> simplified_result;
+  Array<PrimExpr> simplified_result;
   for (const auto& res : op->combiner->result) {
-    Expr new_res = this->VisitExpr(res);
+    PrimExpr new_res = this->VisitExpr(res);
     simplified_result.push_back(new_res);
   }
 
@@ -1066,11 +1066,11 @@ SimplifyReduceCombiner(const ReduceNode* op) {
   }
 
   int new_value_index = op->value_index;
-  Array<Expr> new_result;
-  Array<Expr> new_identity;
+  Array<PrimExpr> new_result;
+  Array<PrimExpr> new_identity;
   Array<Var> new_lhs;
   Array<Var> new_rhs;
-  Array<Expr> new_source;
+  Array<PrimExpr> new_source;
 
   // new stuff is old stuff which is used
   for (size_t i = 0; i < used.size(); ++i) {
@@ -1093,10 +1093,10 @@ SimplifyReduceCombiner(const ReduceNode* op) {
       new_combiner, new_source, op->axis, op->condition, new_value_index);
 }
 
-Expr CanonicalSimplifier::Impl::
+PrimExpr CanonicalSimplifier::Impl::
 VisitExpr_(const ReduceNode* op) {
   // Recursively call simplification when necessary.
-  Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
+  PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op);
   op = ret.as<ReduceNode>();
   // already been simplified by const reduction axis removal
   if (op == nullptr) return ret;
@@ -1115,12 +1115,12 @@ VisitExpr_(const ReduceNode* op) {
   return ret;
 }
 
-Expr CanonicalSimplifier::operator()(const Expr& expr) {
+PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
   return impl_->CanonicalSimplify(expr);
 }
 
 void CanonicalSimplifier::Update(const Var& var,
-                                 const Expr& info,
+                                 const PrimExpr& info,
                                  bool override) {
   impl_->Update(var, info, override);
 }
index aca26e8..d78838f 100644 (file)
@@ -39,7 +39,7 @@ namespace arith {
  * \return The result.
  */
 template<typename OP>
-inline Expr Compute(Expr lhs, Expr rhs) {
+inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) {
   return OP::make(lhs, rhs);
 }
 
@@ -52,10 +52,10 @@ inline Expr Compute(Expr lhs, Expr rhs) {
  * \return The result.
  */
 template<typename Op>
-inline Expr ComputeReduce(
-    const Array<Expr>& values, Expr empty_value);
+inline PrimExpr ComputeReduce(
+    const Array<PrimExpr>& values, PrimExpr empty_value);
 
-inline bool GetConst(Expr e, int64_t* out) {
+inline bool GetConst(PrimExpr e, int64_t* out) {
   if (e.dtype().is_vector()) return false;
   const int64_t* v = as_const_int(e);
   if (v) {
@@ -66,7 +66,7 @@ inline bool GetConst(Expr e, int64_t* out) {
 }
 
 // get a small constant int
-inline bool GetConstInt(Expr e, int* out) {
+inline bool GetConstInt(PrimExpr e, int* out) {
   int64_t v1 = 0;
   if (GetConst(e, &v1)) {
     if (v1 > static_cast<int64_t>(
@@ -77,47 +77,47 @@ inline bool GetConstInt(Expr e, int* out) {
 }
 
 template<>
-inline Expr Compute<ir::AddNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::AddNode>(PrimExpr a, PrimExpr b) {
   return a + b;
 }
 
 template<>
-inline Expr Compute<ir::SubNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::SubNode>(PrimExpr a, PrimExpr b) {
   return a - b;
 }
 
 template<>
-inline Expr Compute<ir::MulNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::MulNode>(PrimExpr a, PrimExpr b) {
   return a * b;
 }
 
 template<>
-inline Expr Compute<ir::DivNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::DivNode>(PrimExpr a, PrimExpr b) {
   return truncdiv(a, b);
 }
 
 template<>
-inline Expr Compute<ir::ModNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::ModNode>(PrimExpr a, PrimExpr b) {
   return truncmod(a, b);
 }
 
 template<>
-inline Expr Compute<ir::MaxNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::MaxNode>(PrimExpr a, PrimExpr b) {
   return max(a, b);
 }
 
 template<>
-inline Expr Compute<ir::MinNode>(Expr a, Expr b) {
+inline PrimExpr Compute<ir::MinNode>(PrimExpr a, PrimExpr b) {
   return min(a, b);
 }
 
 template<typename Op>
-inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
+inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
   if (values.size() == 0U) {
     CHECK(empty_value.defined());
     return empty_value;
   }
-  Expr res = values[0];
+  PrimExpr res = values[0];
   for (size_t i = 1; i < values.size(); ++i) {
     res = Compute<Op>(res, values[i]);
   }
index db98a7e..55c156d 100644 (file)
@@ -44,8 +44,8 @@ namespace arith {
  * \return nullptr if constant fold fails, otherwise return folded result.
  */
 template<typename Op>
-inline Expr TryConstFold(Expr a, Expr b) {
-  return Expr();
+inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) {
+  return PrimExpr();
 }
 
 /*!
@@ -58,7 +58,7 @@ inline Expr TryConstFold(Expr a, Expr b) {
  * \return nullptr if constant fold fails, otherwise return folded result.
  */
 template<typename Op>
-inline Expr TryConstFold(Expr a);
+inline PrimExpr TryConstFold(PrimExpr a);
 
 /*!
  * \brief Check whether type is used to represent index.
@@ -100,7 +100,7 @@ inline bool IsIndexType(const DataType& type) {
 
 // specialization of constant folders.
 template<>
-inline Expr TryConstFold<ir::AddNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value);
@@ -110,11 +110,11 @@ inline Expr TryConstFold<ir::AddNode>(Expr a, Expr b) {
       if (fa && fa->value == 0) return b;
       if (fb && fb->value == 0) return a;
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::SubNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value);
@@ -122,11 +122,11 @@ inline Expr TryConstFold<ir::SubNode>(Expr a, Expr b) {
       if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
       if (fb && fb->value == 0) return a;
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::MulNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value);
@@ -148,11 +148,11 @@ inline Expr TryConstFold<ir::MulNode>(Expr a, Expr b) {
         if (fb->value == 0) return b;
       }
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::DivNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
@@ -177,11 +177,11 @@ inline Expr TryConstFold<ir::DivNode>(Expr a, Expr b) {
         CHECK_NE(fb->value, 0) << "Divide by zero";
       }
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::ModNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::ModNode>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
@@ -195,11 +195,11 @@ inline Expr TryConstFold<ir::ModNode>(Expr a, Expr b) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
       }
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::FloorDivNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
@@ -222,11 +222,11 @@ inline Expr TryConstFold<ir::FloorDivNode>(Expr a, Expr b) {
         CHECK_NE(fb->value, 0) << "Divide by zero";
       }
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::FloorModNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::FloorModNode>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
@@ -240,87 +240,87 @@ inline Expr TryConstFold<ir::FloorModNode>(Expr a, Expr b) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
       }
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::MinNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value));
       if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::MaxNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value));
       if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::GTNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::GTNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
       if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::GENode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::GENode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
       if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::LTNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::LTNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
       if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::LENode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::LENode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
       if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::EQNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::EQNode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
       if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::NENode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::NENode>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
       if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
       if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
     });
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::AndNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::AndNode>(PrimExpr a, PrimExpr b) {
   using ir::UIntImmNode;
   const UIntImmNode* pa = a.as<UIntImmNode>();
   const UIntImmNode* pb = b.as<UIntImmNode>();
@@ -328,11 +328,11 @@ inline Expr TryConstFold<ir::AndNode>(Expr a, Expr b) {
   if (pa && !pa->value) return a;
   if (pb && pb->value) return a;
   if (pb && !pb->value) return b;
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::OrNode>(Expr a, Expr b) {
+inline PrimExpr TryConstFold<ir::OrNode>(PrimExpr a, PrimExpr b) {
   using ir::UIntImmNode;
   const UIntImmNode* pa = a.as<UIntImmNode>();
   const UIntImmNode* pb = b.as<UIntImmNode>();
@@ -340,25 +340,25 @@ inline Expr TryConstFold<ir::OrNode>(Expr a, Expr b) {
   if (pa && !pa->value) return b;
   if (pb && pb->value) return b;
   if (pb && !pb->value) return a;
-  return Expr();
+  return PrimExpr();
 }
 
 template<>
-inline Expr TryConstFold<ir::NotNode>(Expr a) {
+inline PrimExpr TryConstFold<ir::NotNode>(PrimExpr a) {
   using ir::UIntImmNode;
   const UIntImmNode* pa = a.as<UIntImmNode>();
   if (pa) {
     return UIntImmNode::make(DataType::UInt(1), !(pa->value));
   }
-  return Expr();
+  return PrimExpr();
 }
 
 /*! \brief Helper namespace for symbolic value limits */
 struct SymbolicLimits {
   /*! \brief positive infinity */
-  static Expr pos_inf_;
+  static PrimExpr pos_inf_;
   /*! \brief negative infinity */
-  static Expr neg_inf_;
+  static PrimExpr neg_inf_;
 };
 
 /*!
@@ -369,7 +369,7 @@ struct SymbolicLimits {
  *
  * \return positive infinity.
  */
-inline Expr pos_inf() {
+inline PrimExpr pos_inf() {
   return SymbolicLimits::pos_inf_;
 }
 
@@ -379,7 +379,7 @@ inline Expr pos_inf() {
  *
  * \return The check result.
  */
-inline bool is_pos_inf(const Expr& value) {
+inline bool is_pos_inf(const PrimExpr& value) {
   return value.same_as(SymbolicLimits::pos_inf_);
 }
 
@@ -391,7 +391,7 @@ inline bool is_pos_inf(const Expr& value) {
  *
  * \return negative infinity.
  */
-inline Expr neg_inf() {
+inline PrimExpr neg_inf() {
   return SymbolicLimits::neg_inf_;
 }
 
@@ -401,7 +401,7 @@ inline Expr neg_inf() {
  *
  * \return The check result.
  */
-inline bool is_neg_inf(const Expr& value) {
+inline bool is_neg_inf(const PrimExpr& value) {
   return value.same_as(SymbolicLimits::neg_inf_);
 }
 
index d3f885a..a041e40 100644 (file)
@@ -76,17 +76,17 @@ struct ConstIntBoundAnalyzer::Entry {
 };
 
 class ConstIntBoundAnalyzer::Impl :
-      public ExprFunctor<ConstIntBoundAnalyzer::Entry(const Expr&)> {
+      public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
  public:
   /*! \brief additional bound info about expr \in bound */
   struct BoundInfo {
     /*! \brief The expr */
-    Expr expr;
+    PrimExpr expr;
     /*! \brief The additional bound */
     Entry bound;
 
     BoundInfo() {}
-    BoundInfo(Expr expr, Entry bound)
+    BoundInfo(PrimExpr expr, Entry bound)
         : expr(expr), bound(bound) {
     }
   };
@@ -125,10 +125,10 @@ class ConstIntBoundAnalyzer::Impl :
   // Override visitor behaviors
   Entry VisitExprDefault_(const Object* op) final {
     return Everything(
-        static_cast<const ExprNode*>(op)->dtype);
+        static_cast<const PrimExprNode*>(op)->dtype);
   }
 
-  Entry VisitExpr(const Expr& expr) final {
+  Entry VisitExpr(const PrimExpr& expr) final {
     Entry res = ExprFunctor::VisitExpr(expr);
     // a linear search over additional info
     // assume we won't have a lot of conditions
@@ -315,7 +315,7 @@ class ConstIntBoundAnalyzer::Impl :
     }
   }
 
-  std::function<void()> EnterConstraint(const Expr& constraint) {
+  std::function<void()> EnterConstraint(const PrimExpr& constraint) {
     std::vector<BoundInfo> info = DetectBoundInfo(constraint);
     if (info.size() == 0) return nullptr;
     size_t old_size = additional_info_.size();
@@ -330,7 +330,7 @@ class ConstIntBoundAnalyzer::Impl :
 
  private:
   // internal variable map
-  std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
+  std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
   // additional bound info
   std::vector<BoundInfo> additional_info_;
   // constants: the limit value means umlimited
@@ -494,8 +494,8 @@ class ConstIntBoundAnalyzer::Impl :
    * \param cond The constraint condition.
    * \return List of detected bounds.
    */
-  static std::vector<BoundInfo> DetectBoundInfo(const Expr& cond) {
-    PVar<Expr> x, y;
+  static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
+    PVar<PrimExpr> x, y;
     PVar<Integer> c;
     // NOTE: canonical form always use <= or <
     if ((c <= x).Match(cond)) {
@@ -520,7 +520,7 @@ class ConstIntBoundAnalyzer::Impl :
   }
 };
 
-ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
+ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
   Entry ret = impl_->VisitExpr(expr);
   return ConstIntBound(ret.min_value, ret.max_value);
 }
@@ -535,7 +535,7 @@ void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
   impl_->Bind(var, range);
 }
 
-std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) {
+std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
   return impl_->EnterConstraint(constraint);
 }
 
index 7785801..3de5557 100644 (file)
@@ -33,22 +33,22 @@ using namespace ir;
 
 // Linear equation, the components can be undefined.
 struct LinearEqEntry {
-  Expr base;
-  Expr coeff;
+  PrimExpr base;
+  PrimExpr coeff;
 };
 
 struct IntervalEntry {
-  Expr min_value;
-  Expr max_value;
+  PrimExpr min_value;
+  PrimExpr max_value;
 };
 
 class LinearEqDetector
-    : public ExprFunctor<LinearEqEntry(const Expr&, const Expr &)> {
+    : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr &)> {
  public:
   explicit LinearEqDetector(Var var)
       : var_(var) {}
 
-  bool Detect(const Expr& e, LinearEqEntry* ret) {
+  bool Detect(const PrimExpr& e, LinearEqEntry* ret) {
     *ret = VisitExpr(e, e);
     if (fail_) return false;
     if (!ret->base.defined()) {
@@ -60,7 +60,7 @@ class LinearEqDetector
     return true;
   }
 
-  LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const AddNode* op, const PrimExpr& e) final {
     if (fail_) return LinearEqEntry();
     LinearEqEntry a = VisitExpr(op->a, op->a);
     LinearEqEntry b = VisitExpr(op->b, op->b);
@@ -70,7 +70,7 @@ class LinearEqDetector
     return ret;
   }
 
-  LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const SubNode* op, const PrimExpr& e) final {
     if (fail_) return LinearEqEntry();
     LinearEqEntry a = VisitExpr(op->a, op->a);
     LinearEqEntry b = VisitExpr(op->b, op->b);
@@ -80,7 +80,7 @@ class LinearEqDetector
     return ret;
   }
 
-  LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const MulNode* op, const PrimExpr& e) final {
     if (fail_) return LinearEqEntry();
     LinearEqEntry a = VisitExpr(op->a, op->a);
     LinearEqEntry b = VisitExpr(op->b, op->b);
@@ -96,7 +96,7 @@ class LinearEqDetector
     ret.coeff = MulCombine(a.base, b.coeff);
     return ret;
   }
-  LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final {
     LinearEqEntry ret;
     if (op == var_.get()) {
       ret.coeff = make_const(op->dtype, 1);
@@ -105,7 +105,7 @@ class LinearEqDetector
     }
     return ret;
   }
-  LinearEqEntry VisitExprDefault_(const Object* op, const Expr& e) final {
+  LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
     if (fail_) return LinearEqEntry();
     if (ExprUseVar(e, var_)) {
       fail_ = true;
@@ -121,32 +121,33 @@ class LinearEqDetector
   Var var_;
   bool fail_{false};
   // Combine by add
-  Expr AddCombine(Expr a, Expr b) {
+  PrimExpr AddCombine(PrimExpr a, PrimExpr b) {
     if (!a.defined()) return b;
     if (!b.defined()) return a;
     return a + b;
   }
-  Expr SubCombine(Expr a, Expr b) {
+  PrimExpr SubCombine(PrimExpr a, PrimExpr b) {
     // Check b first in case they are both undefined
     if (!b.defined()) return a;
     if (!a.defined()) return -b;
     return a - b;
   }
-  Expr MulCombine(Expr a, Expr b) {
+  PrimExpr MulCombine(PrimExpr a, PrimExpr b) {
     if (!a.defined()) return a;
     if (!b.defined()) return b;
     return a * b;
   }
 };
 
-Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
-  Expr base = e;
-  Array<Expr> coeff;
+Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
+                                          const Array<Var>& vars) {
+  PrimExpr base = e;
+  Array<PrimExpr> coeff;
 
   for (Var v : vars) {
     LinearEqEntry ret;
     if (!LinearEqDetector(v).Detect(base, &ret)) {
-      return Array<Expr>();
+      return Array<PrimExpr>();
     }
     coeff.push_back(ret.coeff);
     base = std::move(ret.base);
@@ -157,7 +158,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
     vset.insert(vars[i - 1].get());
     // The previous coeff contains the variable
     if (ExprUseVar(coeff[i - 2], vset)) {
-      return Array<Expr>();
+      return Array<PrimExpr>();
     }
   }
   coeff.push_back(base);
@@ -166,7 +167,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
 
 // Detect clip condition as min max value
 bool DetectClipBound(
-    const Expr& cond,
+    const PrimExpr& cond,
     std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
   int flag = 0;
   Var var;
@@ -187,7 +188,7 @@ bool DetectClipBound(
   PostOrderVisit(cond, fvisit);
   if (flag != 1) return false;
   // canonical form: exp >= 0
-  Expr canonical;
+  PrimExpr canonical;
   if (const LTNode* op = cond.as<LTNode>()) {
     if (!op->a.dtype().is_int()) return false;
     canonical = op->b - op->a - make_const(op->a.dtype(), 1);
@@ -230,7 +231,7 @@ bool DetectClipBound(
 
 
 template<typename OP>
-void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
+void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) {
   if (const OP* op = e.as<OP>()) {
     SplitCommExpr<OP>(op->a, ret);
     SplitCommExpr<OP>(op->b, ret);
@@ -241,17 +242,17 @@ void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
 
 // Detect the lower and upper bound from the expression.
 // e must be connected by and.
-Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
-  std::vector<Expr> splits;
+Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
+  std::vector<PrimExpr> splits;
   SplitCommExpr<ir::AndNode>(e, &splits);
   std::unordered_map<const VarNode*, IntervalEntry> rmap;
   for (Var v : vars) {
     rmap[v.get()] = IntervalEntry();
   }
-  for (Expr cond : splits) {
-    if (!DetectClipBound(cond, &rmap)) return Array<Expr>();
+  for (PrimExpr cond : splits) {
+    if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
   }
-  Array<Expr> ret;
+  Array<PrimExpr> ret;
   for (Var v : vars) {
     IntervalEntry e = rmap[v.get()];
     if (e.min_value.defined()) {
index 1821c16..3889cd2 100644 (file)
@@ -99,7 +99,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
   }
 
  private:
-  void Touch(const Array<Expr>& args) {
+  void Touch(const Array<PrimExpr>& args) {
     if (args.size() > bounds_.size()) {
       bounds_.resize(args.size());
     }
index c60c825..ceaa976 100644 (file)
 namespace tvm {
 namespace arith {
 
-Expr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
-Expr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
+PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
+PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
 
-IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
+IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) {
   auto node = make_object<IntervalSetNode>();
   node->min_value = std::move(min_value);
   node->max_value = std::move(max_value);
   data_ = std::move(node);
 }
 
-IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
+IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) {
   return IntervalSet(min_value, max_value);
 }
 
@@ -54,8 +54,8 @@ TVM_REGISTER_GLOBAL("arith._make_IntervalSet")
 
 
 IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
-  Expr max_value = min(a->max_value, b->max_value);
-  Expr min_value = max(a->min_value, b->min_value);
+  PrimExpr max_value = min(a->max_value, b->max_value);
+  PrimExpr min_value = max(a->min_value, b->min_value);
   if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) &&
       (min_value.dtype().is_int() || min_value.dtype().is_uint()) &&
       analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
@@ -66,8 +66,8 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
 }
 
 IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
-  Expr max_value = max(a->max_value, b->max_value);
-  Expr min_value = min(a->min_value, b->min_value);
+  PrimExpr max_value = max(a->max_value, b->max_value);
+  PrimExpr min_value = min(a->min_value, b->min_value);
   return IntervalSet(min_value, max_value);
 }
 
@@ -102,7 +102,7 @@ inline IntervalSet Combine(Analyzer* analyzer,
                            IntervalSet a,
                            IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
-    Expr res = TryConstFold<Op>(a->min_value, b->min_value);
+    PrimExpr res = TryConstFold<Op>(a->min_value, b->min_value);
     if (!res.defined()) res = Op::make(a->min_value, b->min_value);
     return IntervalSet::SinglePoint(res);
   }
@@ -126,10 +126,10 @@ inline IntervalSet Combine<ir::AddNode>(Analyzer* analyer,
   }
   if (a->IsEmpty()) return a;
   if (b->IsEmpty()) return b;
-  Expr min_value =
+  PrimExpr min_value =
       a->HasLowerBound() && b->HasLowerBound() ?
       a->min_value + b->min_value : neg_inf();
-  Expr max_value =
+  PrimExpr max_value =
       a->HasUpperBound() && b->HasUpperBound() ?
       a->max_value + b->max_value : pos_inf();
   return IntervalSet(min_value, max_value);
@@ -144,10 +144,10 @@ inline IntervalSet Combine<ir::SubNode>(Analyzer* analyer,
   }
   if (a->IsEmpty()) return a;
   if (b->IsEmpty()) return b;
-  Expr min_value =
+  PrimExpr min_value =
       a->HasLowerBound() && b->HasUpperBound() ?
       a->min_value - b->max_value : neg_inf();
-  Expr max_value =
+  PrimExpr max_value =
       a->HasUpperBound() && b->HasLowerBound() ?
       a->max_value - b->min_value : pos_inf();
   return IntervalSet(min_value, max_value);
@@ -170,18 +170,18 @@ inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer,
     if (is_zero(b->min_value)) return b;
     if (is_one(b->min_value)) return a;
     if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
-      Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
-      Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
+      PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
+      PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
-      Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
-      Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
+      PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
+      PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (a->HasUpperBound() && a->HasLowerBound()) {
       using ir::SelectNode;
-      Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
-      Expr e1 = a->min_value * b->min_value;
-      Expr e2 = a->max_value * b->min_value;
+      PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
+      PrimExpr e1 = a->min_value * b->min_value;
+      PrimExpr e2 = a->max_value * b->min_value;
       return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
     }
   }
@@ -205,18 +205,18 @@ inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer,
     if (is_one(b->min_value)) return a;
     // no relaxation is needed in here due to set is inclusive
     if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
-      Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
-      Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
+      PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
+      PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
-      Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
-      Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
+      PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
+      PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (a->HasUpperBound() && a->HasLowerBound()) {
       using ir::SelectNode;
-      Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
-      Expr e1 = a->min_value / b->min_value;
-      Expr e2 = a->max_value / b->min_value;
+      PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
+      PrimExpr e1 = a->min_value / b->min_value;
+      PrimExpr e2 = a->max_value / b->min_value;
       return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
     }
   }
@@ -235,7 +235,7 @@ inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer,
   if (b->IsEmpty()) return b;
 
   if (b->IsSinglePoint()) {
-    const Expr& divisor = b->min_value;
+    const PrimExpr& divisor = b->min_value;
     if (is_zero(divisor)) {
       LOG(FATAL) << "Modular by zero in CombineInterval Mod";
     }
@@ -246,7 +246,7 @@ inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer,
     if (analyzer->CanProveGreaterEqual(divisor, 0)) {
       return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
     } else {
-      Expr bound = abs(divisor) - 1;
+      PrimExpr bound = abs(divisor) - 1;
       return IntervalSet(-bound, bound);
     }
   }
@@ -271,18 +271,18 @@ inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer,
     if (is_one(b->min_value)) return a;
     // no relaxation is needed in here due to set is inclusive
     if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
-      Expr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
-      Expr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
+      PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
+      PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
-      Expr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
-      Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
+      PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
+      PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (a->HasUpperBound() && a->HasLowerBound()) {
       using ir::SelectNode;
-      Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
-      Expr e1 = floordiv(a->min_value, b->min_value);
-      Expr e2 = floordiv(a->max_value, b->min_value);
+      PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
+      PrimExpr e1 = floordiv(a->min_value, b->min_value);
+      PrimExpr e2 = floordiv(a->max_value, b->min_value);
       return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
     }
   }
@@ -301,14 +301,14 @@ inline IntervalSet Combine<ir::FloorModNode>(Analyzer* analyzer,
   if (b->IsEmpty()) return b;
 
   if (b->IsSinglePoint()) {
-    const Expr& divisor = b->min_value;
+    const PrimExpr& divisor = b->min_value;
     if (is_zero(divisor)) {
       LOG(FATAL) << "Modular by zero in CombineInterval Mod";
     }
     if (analyzer->CanProveGreaterEqual(divisor, 0)) {
       return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
     } else {
-      Expr bound = abs(divisor) - 1;
+      PrimExpr bound = abs(divisor) - 1;
       return IntervalSet(-bound, bound);
     }
   }
@@ -356,7 +356,7 @@ using namespace ir;
 // Simplified version of int set evaluator that operates on IntervalSet
 // We might use better set analysis in the future to replace the intervalset.
 class IntervalSetEvaluator :
-      public ExprFunctor<IntervalSet(const Expr&)> {
+      public ExprFunctor<IntervalSet(const PrimExpr&)> {
  public:
   IntervalSetEvaluator(Analyzer* analyzer,
                        const Map<Var, IntSet>& dom_map,
@@ -366,7 +366,7 @@ class IntervalSetEvaluator :
         eval_vec_(eval_vec) {
   }
 
-  IntervalSet Eval(const Expr& val) {
+  IntervalSet Eval(const PrimExpr& val) {
     return this->VisitExpr(val);
   }
   // evaluate and relax the set
@@ -381,11 +381,11 @@ class IntervalSetEvaluator :
   }
 
   IntervalSet VisitExpr_(const IntImmNode* op) final {
-    return IntervalSet::SinglePoint(GetRef<Expr>(op));
+    return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
   }
 
   IntervalSet VisitExpr_(const UIntImmNode* op) final {
-    return IntervalSet::SinglePoint(GetRef<Expr>(op));
+    return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
   }
 
   IntervalSet VisitExpr_(const VarNode* op) final {
@@ -492,7 +492,7 @@ class IntervalSetEvaluator :
             IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
       }
     }
-    DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<Expr>(op);
+    DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op);
     return IntervalSet::Everything();
   }
 
@@ -515,7 +515,7 @@ class IntervalSetEvaluator :
  private:
   // whether set is exactly single point that equals value.
   bool MatchPoint(const IntervalSet& set,
-                  const Expr& value) const {
+                  const PrimExpr& value) const {
     return set->min_value.same_as(value) && set->max_value.same_as(value);
   }
 
@@ -524,7 +524,7 @@ class IntervalSetEvaluator :
     IntervalSet a = this->Eval(op->a);
     IntervalSet b = this->Eval(op->b);
     if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
-      return IntervalSet::SinglePoint(GetRef<Expr>(op));
+      return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
     }
     return Combine<T>(analyzer_, a, b);
   }
@@ -543,7 +543,7 @@ class IntSetAnalyzer::Impl {
       : analyzer_(analyzer) {
   }
 
-  IntSet Eval(const Expr& expr, const Map<Var, IntSet>& dom_map) const {
+  IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
     return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
   }
 
@@ -559,7 +559,7 @@ IntSetAnalyzer::~IntSetAnalyzer() {
   delete impl_;
 }
 
-IntSet IntSetAnalyzer::operator()(const Expr& expr,
+IntSet IntSetAnalyzer::operator()(const PrimExpr& expr,
                                   const Map<Var, IntSet>& dom_map) {
   return impl_->Eval(expr, dom_map);
 }
@@ -577,13 +577,13 @@ Range IntSet::cover_range(Range max_range) const {
   return max_range;
 }
 
-Expr IntSet::min() const {
+PrimExpr IntSet::min() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   CHECK(s_int);
   return s_int->min_value;
 }
 
-Expr IntSet::max() const {
+PrimExpr IntSet::max() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   CHECK(s_int);
   return s_int->max_value;
@@ -641,7 +641,7 @@ SignType IntSet::sign_type() const {
     return kUnknown;
   }
 }
-Expr IntSet::point_value() const {
+PrimExpr IntSet::point_value() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   CHECK(s_int && s_int->IsSinglePoint());
   return s_int->min_value;
@@ -655,11 +655,11 @@ IntSet IntSet::everything() {
   return IntervalSet::Everything();
 }
 
-IntSet IntSet::single_point(Expr x) {
+IntSet IntSet::single_point(PrimExpr x) {
   return IntervalSet::SinglePoint(x);
 }
 
-IntSet IntSet::interval(Expr min, Expr max) {
+IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
   if (min.same_as(max)) {
     return IntSet::single_point(min);
   }
@@ -667,7 +667,7 @@ IntSet IntSet::interval(Expr min, Expr max) {
 }
 
 // Range related code
-inline bool ProveEqual(Expr lhs, Expr rhs) {
+inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) {
   return is_zero(ir::Simplify(lhs - rhs));
 }
 
@@ -728,24 +728,24 @@ Map<Var, IntSet> ConvertDomMap(
   return dmap;
 }
 
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
                const Map<Var, IntSet>& dom_map) {
   Analyzer ana;
   return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
 }
 
-IntSet IntSet::vector(Expr x) {
+IntSet IntSet::vector(PrimExpr x) {
   Analyzer ana;
   Map<Var, IntSet> dmap;
   return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
 }
 
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
                const Map<IterVar, IntSet>& dom_map) {
   return EvalSet(e, ConvertDomMap(dom_map));
 }
 
-IntSet EvalSet(Expr e,
+IntSet EvalSet(PrimExpr e,
                const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   return EvalSet(e, ConvertDomMap(dom_map));
 }
@@ -755,7 +755,7 @@ IntSet EvalSet(Range r,
   Analyzer ana;
   IntervalSetEvaluator m(&ana, dom_map);
   // Simplifying first can give tighter bounds if r->min and r->extent share variables
-  Expr sum = r->min + r->extent - 1;
+  PrimExpr sum = r->min + r->extent - 1;
   auto res  = m.Eval(IntervalSet(r->min,  Simplify(sum)));
   return std::move(res);
 }
@@ -771,9 +771,9 @@ IntSet EvalSet(IntSet s,
   auto dmap = ConvertDomMap(dom_map);
   IntervalSetEvaluator m(&ana, dmap);
   const IntervalSetNode* s_int = s.as<IntervalSetNode>();
-  Expr vmax = s_int->HasUpperBound() ?
+  PrimExpr vmax = s_int->HasUpperBound() ?
       m.Eval(s_int->max_value).max() : s_int->max_value;
-  Expr vmin = s_int->HasLowerBound() ?
+  PrimExpr vmin = s_int->HasLowerBound() ?
       m.Eval(s_int->min_value).min() : s_int->min_value;
   return IntervalSet(vmin, vmax);
 }
@@ -785,7 +785,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
       const Map<Var, IntSet>& dom_map)
       : IntervalSetEvaluator(analyzer, dom_map) {}
 
-  IntervalSet VisitExpr(const Expr& n) final {
+  IntervalSet VisitExpr(const PrimExpr& n) final {
     IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
     expr_map[n] = ret;
     return ret;
@@ -795,7 +795,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
 };
 
 ExprIntSetMap EvalSetForEachSubExpr(
-    Expr e,
+    PrimExpr e,
     const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   Analyzer ana;
   auto dmap = ConvertDomMap(dom_map);
index 2e07212..b28f1cb 100644 (file)
@@ -42,9 +42,9 @@ namespace arith {
 class IntervalSetNode : public IntSetNode {
  public:
   /*! \brief Minimum value in the interval. */
-  Expr min_value;
+  PrimExpr min_value;
   /*! \brief Maximum value in the interval. */
-  Expr max_value;
+  PrimExpr max_value;
 
   // visitor overload.
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -90,14 +90,14 @@ class IntervalSet : public IntSet {
    * \param max_value The maximum value in the interval.
    * \return The created set.
    */
-  TVM_DLL IntervalSet(Expr min_value, Expr max_value);
+  TVM_DLL IntervalSet(PrimExpr min_value, PrimExpr max_value);
 
   /*!
    * \brief Create an IntervalSet that represents a single point.
    * \param value The value to be represented.
    * \return The result set.
    */
-  static IntervalSet SinglePoint(Expr value) {
+  static IntervalSet SinglePoint(PrimExpr value) {
     return IntervalSet(value, value);
   }
   /*!
index 961c476..1345e7e 100644 (file)
@@ -38,7 +38,7 @@ VisitStmt_(const ForNode* op) {
 
 Stmt IRMutatorWithAnalyzer::
 VisitStmt_(const LetStmtNode* op) {
-  Expr value = this->VisitExpr(op->value);
+  PrimExpr value = this->VisitExpr(op->value);
   if (!ir::HasSideEffect(value)) {
     analyzer_->Bind(op->var, value);
   }
@@ -58,7 +58,7 @@ VisitStmt_(const LetStmtNode* op) {
 
 Stmt IRMutatorWithAnalyzer::
 VisitStmt_(const IfThenElseNode* op) {
-  Expr condition = this->VisitExpr(op->condition);
+  PrimExpr condition = this->VisitExpr(op->condition);
   Stmt then_case, else_case;
   {
     With<ConstraintContext> ctx(analyzer_, condition);
@@ -107,8 +107,8 @@ VisitStmt_(const AttrStmtNode* op) {
 
 Stmt IRMutatorWithAnalyzer::
 VisitStmt_(const AssertStmtNode* op) {
-  Expr condition = this->VisitExpr(op->condition);
-  Expr message = this->VisitExpr(op->message);
+  PrimExpr condition = this->VisitExpr(op->condition);
+  PrimExpr message = this->VisitExpr(op->message);
   With<ConstraintContext> ctx(analyzer_, condition);
   Stmt body = this->VisitStmt(op->body);
 
@@ -125,12 +125,12 @@ VisitStmt_(const AssertStmtNode* op) {
   }
 }
 
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
 VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
   if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
-    Expr cond = this->VisitExpr(op->args[0]);
-    Expr true_value, false_value;
+    PrimExpr cond = this->VisitExpr(op->args[0]);
+    PrimExpr true_value, false_value;
     {
       With<ConstraintContext> constraint(analyzer_, cond);
       true_value = this->VisitExpr(op->args[1]);
@@ -149,7 +149,7 @@ VisitExpr_(const CallNode* op) {
     if (cond.same_as(op->args[0]) &&
         true_value.same_as(op->args[1]) &&
         false_value.same_as(op->args[2])) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       return CallNode::make(op->dtype, op->name,
                         {cond, true_value, false_value},
@@ -159,27 +159,27 @@ VisitExpr_(const CallNode* op) {
   return StmtExprMutator::VisitExpr_(op);
 }
 
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
 VisitExpr_(const LetNode* op) {
-  Expr value = this->VisitExpr(op->value);
+  PrimExpr value = this->VisitExpr(op->value);
   if (!ir::HasSideEffect(value)) {
     analyzer_->Bind(op->var, value);
   }
   // We keep the let-binding here
   // as sub-class may or maynot choose to replace it.
-  Expr body = this->VisitExpr(op->body);
+  PrimExpr body = this->VisitExpr(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return LetNode::make(op->var, value, body);
   }
 }
 
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
 VisitExpr_(const SelectNode* op) {
-  Expr cond = this->VisitExpr(op->condition);
-  Expr true_value, false_value;
+  PrimExpr cond = this->VisitExpr(op->condition);
+  PrimExpr true_value, false_value;
   {
     With<ConstraintContext> constraint(analyzer_, cond);
     true_value = VisitExpr(op->true_value);
@@ -199,13 +199,13 @@ VisitExpr_(const SelectNode* op) {
   if (cond.same_as(op->condition) &&
       true_value.same_as(op->true_value) &&
       false_value.same_as(op->false_value)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return SelectNode::make(cond, true_value, false_value);
   }
 }
 
-Expr IRMutatorWithAnalyzer::
+PrimExpr IRMutatorWithAnalyzer::
 VisitExpr_(const ReduceNode* op) {
   // Setup the domain information before simplification.
   for (const IterVar& iv : op->axis) {
index 1e96c0a..a2297cb 100644 (file)
@@ -54,10 +54,10 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
   Stmt VisitStmt_(const ir::IfThenElseNode* op) override;
   Stmt VisitStmt_(const ir::AttrStmtNode* op) override;
   Stmt VisitStmt_(const ir::AssertStmtNode* op) override;
-  Expr VisitExpr_(const ir::LetNode* op) override;
-  Expr VisitExpr_(const ir::SelectNode* op) override;
-  Expr VisitExpr_(const ir::CallNode* op) override;
-  Expr VisitExpr_(const ir::ReduceNode* op) override;
+  PrimExpr VisitExpr_(const ir::LetNode* op) override;
+  PrimExpr VisitExpr_(const ir::SelectNode* op) override;
+  PrimExpr VisitExpr_(const ir::CallNode* op) override;
+  PrimExpr VisitExpr_(const ir::ReduceNode* op) override;
 
  protected:
   /*! \brief internal analyzer field. */
index 07ec186..08be59b 100644 (file)
@@ -34,7 +34,7 @@ namespace ir {
 
 class IRVisitorWithAnalyzer final : public StmtExprVisitor {
  public:
-  Expr Simplify(const Expr& expr) {
+  PrimExpr Simplify(const PrimExpr& expr) {
     return analyzer_.Simplify(expr);
   }
 
index 8e2e065..01dd2e8 100644 (file)
@@ -85,7 +85,7 @@ struct ModularSetAnalyzer::Entry {
 };
 
 class ModularSetAnalyzer::Impl :
-      public ExprFunctor<ModularSetAnalyzer::Entry(const Expr&)> {
+      public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
  public:
   explicit Impl(Analyzer* parent)
       : parent_(parent) {}
@@ -107,7 +107,7 @@ class ModularSetAnalyzer::Impl :
   }
 
   // Detect useful constraints and use them in the analysis scope.
-  std::function<void()> EnterConstraint(const Expr& constraint) {
+  std::function<void()> EnterConstraint(const PrimExpr& constraint) {
     PVar<Var> var;
     PVar<Integer> coeff, base;
     // pattern match interesting constraints
@@ -168,7 +168,7 @@ class ModularSetAnalyzer::Impl :
     return Entry(coeff, a.base * b.base);
   }
 
-  Entry DivByConst(const Expr& lhs,
+  Entry DivByConst(const PrimExpr& lhs,
                    int64_t val,
                    bool round_down) {
     Entry a = VisitExpr(lhs);
@@ -255,7 +255,7 @@ class ModularSetAnalyzer::Impl :
   /*! \brief pointer to parent. */
   Analyzer* parent_{nullptr};
   // internal variable map
-  std::unordered_map<Var, Entry, ExprHash, ExprEqual> var_map_;
+  std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
   /*!
    * \brief Update var by intersecting entry with var's current set.
    * \param var The variable.
@@ -398,7 +398,7 @@ class ModularSetAnalyzer::Impl :
   }
 };
 
-ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
+ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
   Entry ret = impl_->VisitExpr(expr);
   return ModularSet(ret.coeff, ret.base);
 }
@@ -409,7 +409,7 @@ void ModularSetAnalyzer::Update(const Var& var,
   impl_->Update(var, info, override);
 }
 
-std::function<void()> ModularSetAnalyzer::EnterConstraint(const Expr& constraint) {
+std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
   return impl_->EnterConstraint(constraint);
 }
 
index e964abb..733dcf4 100644 (file)
@@ -131,9 +131,9 @@ class PEqualChecker {
 };
 
 template<>
-class PEqualChecker<Expr> {
+class PEqualChecker<PrimExpr> {
  public:
-  bool operator()(const Expr& lhs, const Expr& rhs) const {
+  bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
     if (lhs.same_as(rhs)) return true;
     return ir::Equal(lhs, rhs);
   }
@@ -260,10 +260,10 @@ class PBinaryExpr :
     }
   }
 
-  Expr Eval() const {
-    Expr lhs = a_.Eval();
-    Expr rhs = b_.Eval();
-    Expr ret = TryConstFold<NodeType>(lhs, rhs);
+  PrimExpr Eval() const {
+    PrimExpr lhs = a_.Eval();
+    PrimExpr rhs = b_.Eval();
+    PrimExpr ret = TryConstFold<NodeType>(lhs, rhs);
     if (ret.defined()) return ret;
     return NodeType::make(lhs, rhs);
   }
@@ -290,7 +290,7 @@ class PConstWithTypeLike :
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     return make_const(ref_.Eval().dtype(), value_);
   }
 
@@ -373,7 +373,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     return ir::NotNode::make(value_.Eval());
   }
 
@@ -421,7 +421,7 @@ class PSelectExpr :
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     return ir::SelectNode::make(
         condition_.Eval(), true_value_.Eval(), false_value_.Eval());
   }
@@ -482,7 +482,7 @@ class PCastExpr :
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     return ir::CastNode::make(dtype_.Eval(), value_.Eval());
   }
 
@@ -541,7 +541,7 @@ class PRampExpr :
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
   }
 
@@ -602,7 +602,7 @@ class PBroadcastExpr :
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
   }
 
@@ -675,7 +675,7 @@ struct PCallExprMatchFunctor {
 };
 
 struct PCallExprEvalArgsFunctor {
-  Array<Expr> args_;
+  Array<PrimExpr> args_;
 
   template<typename T>
   void operator()(size_t i, const T& pattern) {
@@ -716,7 +716,7 @@ class PCallExpr :
     }
   }
 
-  Expr Eval() const {
+  PrimExpr Eval() const {
     detail::PCallExprEvalArgsFunctor feval_args;
     detail::tuple_for_each(feval_args, args_);
     return Op::Eval(feval_args.args_);
@@ -729,7 +729,7 @@ class PCallExpr :
 // arithemetic intrinsics
 #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)          \
   struct OpName {                                                       \
-    static Expr Eval(Array<Expr> args) {                                \
+    static PrimExpr Eval(Array<PrimExpr> args) {                                \
       return ir::CallNode::make(args[0].dtype(), kName, args,           \
                                 ir::CallNode::PureIntrinsic);           \
     }                                                                   \
@@ -750,7 +750,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
 // unary intrinsics
 #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)           \
   struct OpName {                                                       \
-    static Expr Eval(Array<Expr> args) {                                \
+    static PrimExpr Eval(Array<PrimExpr> args) {                                \
       return ir::CallNode::make(args[0].dtype(), kName, args,           \
                                 ir::CallNode::PureIntrinsic);           \
     }                                                                   \
@@ -766,7 +766,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
 
 // if_then_else
 struct PIfThenElseOp {
-  static Expr Eval(Array<Expr> args) {
+  static PrimExpr Eval(Array<PrimExpr> args) {
     return ir::CallNode::make(
         args[1].dtype(), kName, args,
         ir::CallNode::PureIntrinsic);
index 2421e10..94d951d 100644 (file)
@@ -67,8 +67,8 @@ using namespace ir;
 
 // try to prove x equals val
 RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
-TryCompare(const Expr& x, int64_t val) {
-  Expr diff = this->VisitExpr(x);
+TryCompare(const PrimExpr& x, int64_t val) {
+  PrimExpr diff = this->VisitExpr(x);
   if (const auto* ptr = diff.as<IntImmNode>()) {
     if (ptr->value == val) {
       return kEQ;
@@ -101,7 +101,7 @@ TryCompare(const Expr& x, int64_t val) {
 }
 
 void RewriteSimplifier::Impl::
-Update(const Var& var, const Expr& info, bool override) {
+Update(const Var& var, const PrimExpr& info, bool override) {
   if (!override) {
     auto it = var_map_.find(var);
     if (it != var_map_.end()) {
@@ -115,14 +115,14 @@ Update(const Var& var, const Expr& info, bool override) {
   var_map_[var] = info;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const AddNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<AddNode>();
-  Expr const_res = TryConstFold<AddNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<AddNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1, b2, s1, s2;
+  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
   PVar<Integer> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
@@ -219,7 +219,7 @@ VisitExpr_(const AddNode* op) {
   return ret;
 }
 
-std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& constraint) {
+std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
   size_t old_literal_size = literal_constraints_.size();
   literal_constraints_.push_back(constraint);
   size_t new_literal_size = literal_constraints_.size();
@@ -230,14 +230,14 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& const
   return frecover;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const SubNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<SubNode>();
-  Expr const_res = TryConstFold<SubNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<SubNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1, b2, s1, s2;
+  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
   PVar<Integer> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
@@ -429,14 +429,14 @@ VisitExpr_(const SubNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const MulNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<MulNode>();
-  Expr const_res = TryConstFold<MulNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<MulNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1, b2, s1, s2;
+  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   // Pattern var for lanes in broadcast and ramp
@@ -468,14 +468,14 @@ VisitExpr_(const MulNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const DivNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<DivNode>();
-  Expr const_res = TryConstFold<DivNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<DivNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1;
+  PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
   PVar<Integer> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
@@ -690,15 +690,15 @@ VisitExpr_(const DivNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const ModNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<ModNode>();
-  Expr const_res = TryConstFold<ModNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<ModNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1;
+  PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   // Pattern var for lanes in broadcast and ramp
@@ -763,7 +763,7 @@ VisitExpr_(const ModNode* op) {
     // NOTE: trunc div required
     TVM_TRY_RECURSIVE_REWRITE_IF(
         truncmod(x, c1),
-        truncmod(x, PConst<Expr>(make_const(op->dtype, -c1.Eval()->value))),
+        truncmod(x, PConst<PrimExpr>(make_const(op->dtype, -c1.Eval()->value))),
         c1.Eval()->value < 0);
 
     // try modular analysis
@@ -780,14 +780,14 @@ VisitExpr_(const ModNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const FloorDivNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<FloorDivNode>();
-  Expr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1;
+  PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
   PVar<Integer> c1, c2, c3;
   // Pattern var for lanes in broadcast and ramp
@@ -924,15 +924,15 @@ VisitExpr_(const FloorDivNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const FloorModNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<FloorModNode>();
-  Expr const_res = TryConstFold<FloorModNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<FloorModNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, b1;
+  PVar<PrimExpr> x, y, z, b1;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   // Pattern var for lanes in broadcast and ramp
@@ -994,15 +994,15 @@ VisitExpr_(const FloorModNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const MinNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<MinNode>();
-  Expr const_res = TryConstFold<MinNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<MinNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, s1, s2;
+  PVar<PrimExpr> x, y, z, s1, s2;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   PVar<int> lanes;
@@ -1179,15 +1179,15 @@ VisitExpr_(const MinNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const MaxNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<MaxNode>();
-  Expr const_res = TryConstFold<MaxNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<MaxNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, s1, s2;
+  PVar<PrimExpr> x, y, z, s1, s2;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   PVar<int> lanes;
@@ -1352,15 +1352,15 @@ VisitExpr_(const MaxNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const EQNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<EQNode>();
-  Expr const_res = TryConstFold<EQNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<EQNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y;
+  PVar<PrimExpr> x, y;
   // Pattern var match IntImm
   PVar<Integer> c1;
   PVar<int> lanes;
@@ -1386,35 +1386,35 @@ VisitExpr_(const EQNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const NENode* op) {
   return this->VisitExpr(NotNode::make(op->a == op->b));
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const LENode* op) {
   return this->VisitExpr(NotNode::make(op->b < op->a));
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const GTNode* op) {
   return this->VisitExpr(op->b < op->a);
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const GENode* op) {
   return this->VisitExpr(NotNode::make(op->a < op->b));
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const LTNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<LTNode>();
-  Expr const_res = TryConstFold<LTNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<LTNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y, z, s1, s2;
+  PVar<PrimExpr> x, y, z, s1, s2;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   PVar<int> lanes;
@@ -1562,14 +1562,14 @@ VisitExpr_(const LTNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const NotNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<NotNode>();
-  Expr const_res = TryConstFold<NotNode>(op->a);
+  PrimExpr const_res = TryConstFold<NotNode>(op->a);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
-  PVar<Expr> x, y;
+  PVar<PrimExpr> x, y;
   PVar<int> lanes;
   if (op->dtype.lanes() != 1) {
     TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
@@ -1587,15 +1587,15 @@ VisitExpr_(const NotNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const AndNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<AndNode>();
-  Expr const_res = TryConstFold<AndNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<AndNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y;
+  PVar<PrimExpr> x, y;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   PVar<int> lanes;
@@ -1605,7 +1605,7 @@ VisitExpr_(const AndNode* op) {
                     broadcast(x && y, lanes));
   }
 
-  auto cfalse = PConst<Expr>(make_const(op->dtype, false));
+  auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
   TVM_TRY_REWRITE(x == y && x != y, cfalse);
   TVM_TRY_REWRITE(x != y && x == y, cfalse);
   TVM_TRY_REWRITE(x && !x, cfalse);
@@ -1636,15 +1636,15 @@ VisitExpr_(const AndNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const OrNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<OrNode>();
-  Expr const_res = TryConstFold<OrNode>(op->a, op->b);
+  PrimExpr const_res = TryConstFold<OrNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
-  PVar<Expr> x, y;
+  PVar<PrimExpr> x, y;
   // Pattern var match IntImm
   PVar<Integer> c1, c2;
   PVar<int> lanes;
@@ -1654,7 +1654,7 @@ VisitExpr_(const OrNode* op) {
                     broadcast(x || y, lanes));
   }
 
-  auto ctrue = PConst<Expr>(make_const(op->dtype, true));
+  auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));
 
   TVM_TRY_REWRITE(x == y || x != y, ctrue);
   TVM_TRY_REWRITE(x != y || x == y, ctrue);
@@ -1686,21 +1686,21 @@ VisitExpr_(const OrNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const SelectNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<SelectNode>();
   if (op == nullptr) return ret;
   // Pattern var to match any expression
-  PVar<Expr> x, y;
+  PVar<PrimExpr> x, y;
   TVM_TRY_REWRITE(select(x, y, y), y);
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<CallNode>();
   if (op == nullptr) return ret;
   if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) {
@@ -1727,47 +1727,47 @@ VisitExpr_(const CallNode* op) {
   return ret;
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const VarNode* op) {
   Var var = GetRef<Var>(op);
   auto it = var_map_.find(var);
   if (it != var_map_.end()) {
     return it->second;
   }
-  return GetRef<Expr>(op);
+  return GetRef<PrimExpr>(op);
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const CastNode* op) {
-  Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
   op = ret.as<CastNode>();
   return cast(op->dtype, op->value);
 }
 
-Expr RewriteSimplifier::Impl::
+PrimExpr RewriteSimplifier::Impl::
 VisitExpr_(const LetNode* op) {
-  Expr value = this->VisitExpr(op->value);
+  PrimExpr value = this->VisitExpr(op->value);
   if (!ir::HasSideEffect(value)) {
     // it is fine to discard the let binding
     // because the value will always be inlined in the simplifier.
     analyzer_->Bind(op->var, value);
     return this->VisitExpr(op->body);
   }
-  Expr body = this->VisitExpr(op->body);
+  PrimExpr body = this->VisitExpr(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return LetNode::make(op->var, value, body);
   }
 }
 
-Expr RewriteSimplifier::operator()(const Expr& expr) {
+PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) {
   // Run simplification in post order
-  Expr res = expr;
+  PrimExpr res = expr;
   int max_iter = 2;
   for (int i = 0; i < max_iter; ++i) {
-    Expr new_expr = impl_->operator()(res);
+    PrimExpr new_expr = impl_->operator()(res);
     if (new_expr.same_as(res)) return res;
     res = new_expr;
   }
@@ -1775,12 +1775,12 @@ Expr RewriteSimplifier::operator()(const Expr& expr) {
 }
 
 void RewriteSimplifier::Update(const Var& var,
-                               const Expr& info,
+                               const PrimExpr& info,
                                bool override) {
   impl_->Update(var, info, override);
 }
 
-std::function<void()> RewriteSimplifier::EnterConstraint(const Expr& constraint) {
+std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) {
   return impl_->EnterConstraint(constraint);
 }
 
index f2659a9..6b4193c 100644 (file)
@@ -49,32 +49,32 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
   explicit Impl(Analyzer* parent)
       : IRMutatorWithAnalyzer(parent) {}
 
-  void Update(const Var& var, const Expr& info, bool override_info);
-  Expr VisitExpr_(const AddNode* op) override;
-  Expr VisitExpr_(const SubNode* op) override;
-  Expr VisitExpr_(const MulNode* op) override;
-  Expr VisitExpr_(const DivNode* op) override;
-  Expr VisitExpr_(const ModNode* op) override;
-  Expr VisitExpr_(const FloorDivNode* op) override;
-  Expr VisitExpr_(const FloorModNode* op) override;
-  Expr VisitExpr_(const MinNode* op) override;
-  Expr VisitExpr_(const MaxNode* op) override;
-  Expr VisitExpr_(const EQNode* op) override;
-  Expr VisitExpr_(const NENode* op) override;
-  Expr VisitExpr_(const LTNode* op) override;
-  Expr VisitExpr_(const LENode* op) override;
-  Expr VisitExpr_(const GTNode* op) override;
-  Expr VisitExpr_(const GENode* op) override;
-  Expr VisitExpr_(const AndNode* op) override;
-  Expr VisitExpr_(const OrNode* op) override;
-  Expr VisitExpr_(const NotNode* op) override;
-  Expr VisitExpr_(const SelectNode* op) override;
-  Expr VisitExpr_(const CallNode* op) override;
-  Expr VisitExpr_(const VarNode* op) override;
-  Expr VisitExpr_(const CastNode* op) override;
-  Expr VisitExpr_(const LetNode* op) override;
-
-  std::function<void()> EnterConstraint(const Expr& constraint);
+  void Update(const Var& var, const PrimExpr& info, bool override_info);
+  PrimExpr VisitExpr_(const AddNode* op) override;
+  PrimExpr VisitExpr_(const SubNode* op) override;
+  PrimExpr VisitExpr_(const MulNode* op) override;
+  PrimExpr VisitExpr_(const DivNode* op) override;
+  PrimExpr VisitExpr_(const ModNode* op) override;
+  PrimExpr VisitExpr_(const FloorDivNode* op) override;
+  PrimExpr VisitExpr_(const FloorModNode* op) override;
+  PrimExpr VisitExpr_(const MinNode* op) override;
+  PrimExpr VisitExpr_(const MaxNode* op) override;
+  PrimExpr VisitExpr_(const EQNode* op) override;
+  PrimExpr VisitExpr_(const NENode* op) override;
+  PrimExpr VisitExpr_(const LTNode* op) override;
+  PrimExpr VisitExpr_(const LENode* op) override;
+  PrimExpr VisitExpr_(const GTNode* op) override;
+  PrimExpr VisitExpr_(const GENode* op) override;
+  PrimExpr VisitExpr_(const AndNode* op) override;
+  PrimExpr VisitExpr_(const OrNode* op) override;
+  PrimExpr VisitExpr_(const NotNode* op) override;
+  PrimExpr VisitExpr_(const SelectNode* op) override;
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  PrimExpr VisitExpr_(const VarNode* op) override;
+  PrimExpr VisitExpr_(const CastNode* op) override;
+  PrimExpr VisitExpr_(const LetNode* op) override;
+
+  std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
  protected:
   /*! \brief internal structure for comparison. */
@@ -90,9 +90,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
   // counter to record recursive rewrite depth.
   int recur_depth_{0};
   // internal variable map
-  std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
+  std::unordered_map<Var, PrimExpr, ObjectHash, ObjectEqual> var_map_;
 
-  std::vector<Expr> literal_constraints_;
+  std::vector<PrimExpr> literal_constraints_;
 
   // maximum number of recursion allowed during a single pass.
   static const constexpr int kMaxRecurDepth = 5;
@@ -103,15 +103,15 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
    * \param val The constant value.
    * \return comparison result.
    */
-  CompareResult TryCompare(const Expr& x, int64_t val);
+  CompareResult TryCompare(const PrimExpr& x, int64_t val);
 
  private:
   // Whether x >= val
-  bool CanProveGreaterEqual(const Expr& x, int64_t val) {
+  bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
     return analyzer_->CanProveGreaterEqual(x, val);
   }
   // Whether x == val
-  bool CanProveEqual(const Expr& x, int64_t val) {
+  bool CanProveEqual(const PrimExpr& x, int64_t val) {
     // TODO(tqchen) refer back to super-analyzer.
     return TryCompare(x, val) == kEQ;
   }
@@ -119,10 +119,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
   // Recursive rewrite x
   // we limit maximum depth of recursive rewrite allowed to
   // avoid infinite loop
-  Expr RecursiveRewrite(const Expr& x) {
+  PrimExpr RecursiveRewrite(const PrimExpr& x) {
     if (recur_depth_ >= kMaxRecurDepth) return x;
     ++recur_depth_;
-    Expr res = this->VisitExpr(x);
+    PrimExpr res = this->VisitExpr(x);
     --recur_depth_;
     return res;
   }
index 73b5dce..dcc7e5d 100644 (file)
@@ -42,7 +42,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   using Parent::VisitStmt;
   using Parent::VisitStmt_;
 
-  Expr VisitExpr(const Expr& expr) final {
+  PrimExpr VisitExpr(const PrimExpr& expr) final {
     return analyzer_->Simplify(expr);
   }
 
@@ -58,7 +58,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   }
 
   Stmt VisitStmt_(const LetStmtNode* op) {
-    Expr value = this->VisitExpr(op->value);
+    PrimExpr value = this->VisitExpr(op->value);
     if (!ir::HasSideEffect(value)) {
       // it is fine to discard the let binding
       // because the call to simplify will always inline the var.
@@ -103,7 +103,7 @@ Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
   return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt));
 }
 
-Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
+PrimExpr CanonicalSimplify(PrimExpr expr, Map<Var, Range> vrange) {
   arith::Analyzer analyzer;
   for (auto kv : vrange) {
     analyzer.Bind(kv.first, kv.second);
@@ -111,7 +111,7 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
   return analyzer.canonical_simplify(expr);
 }
 
-Expr Simplify(Expr expr, Map<Var, Range> vrange) {
+PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
   arith::Analyzer analyzer;
   for (auto kv : vrange) {
     analyzer.Bind(kv.first, kv.second);
index 11452d3..a83d248 100644 (file)
@@ -60,7 +60,7 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
 void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::thread_extent ||
       op->attr_key == attr::virtual_thread) {
-    VarExpr var = op->node.as<tvm::IterVarNode>()->var;
+    Var var = op->node.as<tvm::IterVarNode>()->var;
     const auto *extent = op->value.as<IntImmNode>();
     CHECK(extent);
 
index 9f65fb4..b2ea80f 100644 (file)
@@ -69,7 +69,7 @@ class FeatureVisitor : public StmtExprVisitor {
  * \param ann_type The type for the for loop
  * \return skip Whether skip this node
  */
-  virtual bool EnterItervar_(tvm::VarExpr var, int64_t length, AnnotationType ann_type) = 0;
+  virtual bool EnterItervar_(tvm::Var var, int64_t length, AnnotationType ann_type) = 0;
   /*! \brief Exit a for loop subtree */
   virtual void ExitItervar_() = 0;
   /*!
@@ -77,7 +77,7 @@ class FeatureVisitor : public StmtExprVisitor {
    * \param buffer_var The buffer to access.
    * \param index Index expression
    */
-  virtual void EnterMem_(tvm::VarExpr buffer_var, tvm::Expr index) = 0;
+  virtual void EnterMem_(tvm::Var buffer_var, tvm::PrimExpr index) = 0;
   /*! \brief Exit a memory access node */
   virtual void ExitMem_() = 0;
 };
index 0ee4b11..cf138ed 100644 (file)
@@ -46,7 +46,7 @@ int ParallelLevel(AnnotationType ann) {
 // get touch pattern from index expression
 class IndexParser: public ExprVisitor {
  public:
-  void Parse(Expr expr) {
+  void Parse(PrimExpr expr) {
     pattern_map.clear();
     this->VisitExpr(expr);
   }
@@ -76,7 +76,7 @@ class IndexParser: public ExprVisitor {
 };
 
 // extract iter vars and their touch pattern from ir
-bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type) {
+bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_type) {
   // do not insert duplicated occurrences of virtual thread
   if (ann_type == kVirtualThread && itervar_map.count(var) != 0) {
     skip_stack_size_.push_back(itervar_stack_.size());
@@ -90,7 +90,7 @@ bool TouchExtractor::EnterItervar_(VarExpr var, int64_t length, AnnotationType a
       // these happens when we create tvm.thread_axis("threadIdx.x") once and
       // bind it twice. Here we treat them as two axes
       // so we create a snapshot for the old one and freeze it
-      VarExpr old = VarExpr(var.get()->name_hint);
+      Var old = Var(var.get()->name_hint);
       itervar_map.insert({old, itervar_map[var]});
       itervar_map.erase(var);
     }
@@ -110,7 +110,7 @@ void TouchExtractor::ExitItervar_() {
     skip_stack_size_.pop_back();
     return;
   }
-  VarExpr var = itervar_stack_.back();
+  Var var = itervar_stack_.back();
 
   // update count and reuse ratio for upper iter vars (includes self)
   for (auto kv : itervar_map[var].touch_feature) {
@@ -169,7 +169,7 @@ void TouchExtractor::ExitItervar_() {
   }
 }
 
-void TouchExtractor::EnterMem_(VarExpr buffer_var, Expr index) {
+void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) {
   std::string name = buffer_var.get()->name_hint;
   TouchedBuffer buf = name + "_" + std::to_string(buffer_counter_[name]++);
 
@@ -219,17 +219,17 @@ void TouchExtractor::ExitMem_() {
  * \note If you want to flatten these features as the input of your model,
  * You can use the faster one GetItervarFeatureFlatten below.
  */
-void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *ret_feature) {
+void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > > *ret_feature) {
   // extract
   TouchExtractor touch_analyzer;
   touch_analyzer.Analyze(stmt);
 
   // sort according to order
-  std::vector<VarExpr> vars;
+  std::vector<Var> vars;
   for (auto kv : touch_analyzer.itervar_map) {
     vars.push_back(kv.first);
   }
-  std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
+  std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
     return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
   });
 
@@ -250,11 +250,11 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
 
   // serialize for front end
   for (auto var : vars) {
-    Array<Array<Expr> > feature_row;
+    Array<Array<PrimExpr> > feature_row;
     ItervarFeature &fea = touch_analyzer.itervar_map[var];
-    feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
+    feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
 
-    Array<Expr> attr{std::string("_attr_"),
+    Array<PrimExpr> attr{std::string("_attr_"),
                      FloatImmNode::make(DataType::Float(32), trans(fea.length)),
                      IntImmNode::make(DataType::Int(32), fea.nest_level),
                      FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
@@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
     feature_row.push_back(attr);
 
     // arithmetic
-    feature_row.push_back(Array<Expr>{std::string("_arith_"),
+    feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
             FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
             FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
             FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
@@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
     for (auto k : bufs) {
       TouchPattern &v = fea.touch_feature[k];
       feature_row.push_back(
-          Array<Expr>{k,
+          Array<PrimExpr>{k,
                 FloatImmNode::make(DataType::Float(32), trans(v.stride)),
                 FloatImmNode::make(DataType::Float(32), trans(v.mod)),
                 FloatImmNode::make(DataType::Float(32), trans(v.count)),
@@ -311,11 +311,11 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float> *ret_
   touch_analyzer.Analyze(stmt);
 
   // sort according to order
-  std::vector<VarExpr> vars;
+  std::vector<Var> vars;
   for (auto kv : touch_analyzer.itervar_map) {
     vars.push_back(kv.first);
   }
-  std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
+  std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
     return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
   });
 
@@ -383,11 +383,11 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float> *r
   touch_ext.Analyze(stmt);
 
   // sort according to order
-  std::vector<VarExpr> vars;
+  std::vector<Var> vars;
   for (auto kv : touch_ext.itervar_map) {
     vars.push_back(kv.first);
   }
-  std::sort(vars.begin(), vars.end(), [&](const VarExpr &lhs, const VarExpr &rhs) -> bool {
+  std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool {
     return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order;
   });
 
@@ -490,7 +490,7 @@ TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   Stmt stmt = args[0];
   bool take_log = args[1];
-  Array<Array<Array<Expr > > > ret_feature;
+  Array<Array<Array<PrimExpr > > > ret_feature;
 
   GetItervarFeature(stmt, take_log, &ret_feature);
 
index 5265aad..3af368d 100644 (file)
@@ -56,7 +56,7 @@ struct TouchPattern {
 
 // all the feature of an iter var
 struct ItervarFeature {
-  ItervarFeature(VarExpr var,
+  ItervarFeature(Var var,
                  int64_t extent,
                  int nest,
                  AnnotationType ann_type,
@@ -122,18 +122,18 @@ class TouchExtractor : public FeatureVisitor {
     FeatureVisitor::VisitExpr_(op);
   }
 
-  std::unordered_map<VarExpr, ItervarFeature, tvm::ExprHash, tvm::ExprEqual> itervar_map;
+  std::unordered_map<Var, ItervarFeature, tvm::ObjectHash, tvm::ObjectEqual> itervar_map;
 
  private:
-  bool EnterItervar_(VarExpr var, int64_t length, AnnotationType ann_type);
+  bool EnterItervar_(Var var, int64_t length, AnnotationType ann_type);
   void ExitItervar_();
-  void EnterMem_(VarExpr buffer_var, Expr index);
+  void EnterMem_(Var buffer_var, PrimExpr index);
   void ExitMem_();
 
   int64_t topdown_product_{1};
   std::map<std::string, size_t> buffer_counter_;
   size_t itervar_counter_{0};
-  std::deque<VarExpr> itervar_stack_;  // use deque instead of stack for indexing
+  std::deque<Var> itervar_stack_;  // use deque instead of stack for indexing
   std::deque<size_t> skip_stack_size_;
 
   using FeatureVisitor::VisitExpr_;
index 77b1c9d..9f79342 100644 (file)
@@ -338,7 +338,7 @@ Target DefaultTargetHost(Target target) {
   }
 }
 
-Buffer BufferWithOffsetAlignment(Array<Expr> shape,
+Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape,
                                  DataType dtype,
                                  std::string name,
                                  int data_alignment,
@@ -356,14 +356,14 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
   }
   BufferType buffer_type = has_any ? kAutoBroadcast : kDefault;
 
-  Expr elem_offset;
+  PrimExpr elem_offset;
   if (offset_factor != 0) {
     elem_offset = Var(name + "_elem_offset", shape[0].dtype());
   } else {
-    elem_offset = Expr();
+    elem_offset = PrimExpr();
   }
 
-  return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
+  return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
     data_alignment, offset_factor, buffer_type);
 }
 
@@ -855,7 +855,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
   GenericFunc generic_func = args[0];
   // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
   PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
-  Array<Expr> tags = args[2];
+  Array<PrimExpr> tags = args[2];
   bool allow_override = args[3];
 
   std::vector<std::string> tags_vector;
index ea9d7ba..777ad62 100644 (file)
@@ -121,7 +121,7 @@ std::string CodeGenC::Finish() {
   return decl_stream.str() + stream.str();
 }
 
-void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
   if (print_ssa_form_) {
     std::ostringstream temp;
     VisitExpr(n, temp);
@@ -146,7 +146,7 @@ void CodeGenC::PrintSSAAssign(
 
 // Print a reference expression to a buffer.
 std::string CodeGenC::GetBufferRef(
-    DataType t, const VarNode* buffer, Expr index) {
+    DataType t, const VarNode* buffer, PrimExpr index) {
   std::ostringstream os;
   std::string vid = GetVarID(buffer);
   std::string scope;
@@ -213,7 +213,7 @@ std::string CodeGenC::GetBufferRef(
 
 // Print a reference expression to a buffer.
 std::string CodeGenC::GetStructRef(
-    DataType t, const Expr& buffer, const Expr& index, int kind) {
+    DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) {
   if (kind < intrinsic::kArrKindBound_) {
     std::ostringstream os;
     os << "(((TVMArray*)";
@@ -296,12 +296,12 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
 }
 
 std::string CodeGenC::GetVecLoad(
-    DataType t, const VarNode* buffer, Expr base) {
+    DataType t, const VarNode* buffer, PrimExpr base) {
   return GetBufferRef(t, buffer, base);
 }
 
 void CodeGenC::PrintVecStore(const VarNode* buffer,
-                             DataType t, Expr base,
+                             DataType t, PrimExpr base,
                              const std::string& value) {
   std::string ref = GetBufferRef(t, buffer, base);
   this->PrintIndent();
@@ -594,7 +594,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
 
 void CodeGenC::PrintVecBinaryOp(
     const std::string& op, DataType t,
-    Expr lhs, Expr rhs, std::ostream& os) {  // NOLINT(*)
+    PrimExpr lhs, PrimExpr rhs, std::ostream& os) {  // NOLINT(*)
   if (isalpha(op[0])) {
     os << op << "(";
     this->PrintExpr(lhs, os);
@@ -619,7 +619,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   } else {
     CHECK(is_one(op->predicate))
         << "predicated load is not supported";
-    Expr base;
+    PrimExpr base;
     if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
       std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
       os << ref;
@@ -673,7 +673,7 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
   } else {
     CHECK(is_one(op->predicate))
         << "Predicated store is not supported";
-    Expr base;
+    PrimExpr base;
     if (GetRamp1Base(op->index, t.lanes(), &base)) {
       std::string value = this->PrintExpr(op->value);
       this->PrintVecStore(op->buffer_var.get(), t, base, value);
index 593bbcd..cb092c5 100644 (file)
@@ -49,7 +49,7 @@ using namespace ir;
  * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`.
  */
 class CodeGenC :
-      public ExprFunctor<void(const Expr&, std::ostream&)>,
+      public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
       public StmtFunctor<void(const Stmt&)>,
       public CodeGenSourceBase {
  public:
@@ -80,12 +80,12 @@ class CodeGenC :
    * \param n The expression to be printed.
    * \param os The output stream
    */
-  void PrintExpr(const Expr& n, std::ostream& os);
+  void PrintExpr(const PrimExpr& n, std::ostream& os);
   /*!
    * \brief Same as PrintExpr, but simply returns result string
    * \param n The expression to be printed.
    */
-  std::string PrintExpr(const Expr& n) {
+  std::string PrintExpr(const PrimExpr& n) {
     std::ostringstream os;
     PrintExpr(n, os);
     return os.str();
@@ -158,12 +158,12 @@ class CodeGenC :
   // Binary vector op.
   virtual void PrintVecBinaryOp(
       const std::string&op, DataType op_type,
-      Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
+      PrimExpr lhs, PrimExpr rhs, std::ostream& os);  // NOLINT(*)
   // print vector load
-  virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base);
+  virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base);
   // print vector store
   virtual void PrintVecStore(const VarNode* buffer,
-                             DataType t, Expr base,
+                             DataType t, PrimExpr base,
                              const std::string& value);  // NOLINT(*)
   // print load of single element
   virtual void PrintVecElemLoad(
@@ -177,10 +177,10 @@ class CodeGenC :
  protected:
   // Print reference to struct location
   std::string GetStructRef(
-      DataType t, const Expr& buffer, const Expr& index, int kind);
+      DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
   // print reference to a buffer as type t in index.
   virtual std::string GetBufferRef(
-      DataType t, const VarNode* buffer, Expr index);
+      DataType t, const VarNode* buffer, PrimExpr index);
   /*!
    * \brief If buffer is allocated as type t.
    * \param buf_var The buffer variable.
index 53a008d..d06e9aa 100644 (file)
@@ -200,7 +200,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
 
 void CodeGenCUDA::PrintVecBinaryOp(
     const std::string&op, DataType t,
-    Expr lhs, Expr rhs, std::ostream& os) {  // NOLINT(*)
+    PrimExpr lhs, PrimExpr rhs, std::ostream& os) {  // NOLINT(*)
   // unpacking operations.
   int lanes = t.lanes();
 
index fc2e6ae..23fbf7f 100644 (file)
@@ -48,7 +48,7 @@ class CodeGenCUDA final : public CodeGenC {
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintVecBinaryOp(
       const std::string&op, DataType t,
-      Expr lhs, Expr rhs, std::ostream& os) final;  // NOLINT(*)
+      PrimExpr lhs, PrimExpr rhs, std::ostream& os) final;  // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
   void PrintVecElemLoad(
       const std::string& vec, DataType t, int i, std::ostream& os) final;  // NOLINT(*)
index 8914db8..ef90cfc 100644 (file)
@@ -145,7 +145,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
 }
 
 void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
-                                 Expr base, std::ostream& os) {  // NOLINT(*)
+                                 PrimExpr base, std::ostream& os) {  // NOLINT(*)
   if (!HandleTypeMatch(buffer, t.element_of())) {
     os << '(';
     auto it = alloc_storage_scope_.find(buffer);
@@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
   PrintExpr(base, os);
 }
 std::string CodeGenOpenCL::GetVecLoad(
-    DataType t, const VarNode* buffer, Expr base) {
+    DataType t, const VarNode* buffer, PrimExpr base) {
   std::ostringstream os;
   os << "vload" << t.lanes() << "(0, ";
   PrintVecAddr(buffer, t, base, os);
@@ -169,7 +169,7 @@ std::string CodeGenOpenCL::GetVecLoad(
 }
 
 void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
-                                  DataType t, Expr base,
+                                  DataType t, PrimExpr base,
                                   const std::string& value) {
   this->PrintIndent();
   stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
index a606a3a..07b28fd 100644 (file)
@@ -45,13 +45,13 @@ class CodeGenOpenCL final : public CodeGenC {
   void PrintStorageSync(const CallNode* op) final;  // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
   std::string GetVecLoad(DataType t, const VarNode* buffer,
-                         Expr base) final;
+                         PrimExpr base) final;
   void PrintVecStore(const VarNode* buffer,
-                     DataType t, Expr base,
+                     DataType t, PrimExpr base,
                      const std::string& value) final;  // NOLINT(*)
   // the address of load/store
   void PrintVecAddr(const VarNode* buffer, DataType t,
-                    Expr base, std::ostream& os);  // NOLINT(*)
+                    PrimExpr base, std::ostream& os);  // NOLINT(*)
   std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
 
   // overload visitor
index 5666de3..7967c18 100644 (file)
@@ -194,7 +194,7 @@ void CodeGenOpenGL::VisitStmt_(const StoreNode* op) {
 }
 
 // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
-std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
+std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, PrimExpr index) {
   std::ostringstream os;
   os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
   PrintExpr(index, os);
@@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
 // Print a reference expression to a buffer.
 // Format: texelFetch(buffer, index, 0).r
 std::string CodeGenOpenGL::GetBufferRef(
-    DataType t, const VarNode* buffer, Expr index) {
+    DataType t, const VarNode* buffer, PrimExpr index) {
   CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
   CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
 
index bb69365..cd1ec83 100644 (file)
@@ -44,8 +44,8 @@ class CodeGenOpenGL final : public CodeGenC {
   void InitFuncState(LoweredFunc f) final;
   void BindThreadIndex(const IterVar& iv) final;
   void VisitStmt_(const StoreNode* op) final;
-  std::string TexelFetch(const VarNode* buffer, Expr index);
-  std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final;
+  std::string TexelFetch(const VarNode* buffer, PrimExpr index);
+  std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final;
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
 
   // Codegen for immediate values
index e7231a1..a1bd9f0 100644 (file)
@@ -140,7 +140,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
   std::string whole_code = cg.Finish();
 
   // Generate source code for compilation.
-  Array<Array<Expr> > kernel_info;
+  Array<Array<PrimExpr> > kernel_info;
   for (LoweredFunc f : funcs) {
     CodeGenVivadoHLS cg;
     cg.Init(output_ssa);
@@ -149,7 +149,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
     if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
       code = (*f)(code).operator std::string();
     }
-    kernel_info.push_back(Array<Expr>({f->name, code}));
+    kernel_info.push_back(Array<PrimExpr>({f->name, code}));
   }
 
   std::string xclbin;
index 571ec52..0609989 100644 (file)
@@ -53,7 +53,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
 .set_body([](const TVMArgs& args, TVMRetValue* rv){
-    Expr e = args[0];
+    PrimExpr e = args[0];
     const CallNode* call = e.as<CallNode>();
     CHECK(call != nullptr);
 
@@ -66,7 +66,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
 .set_body([](const TVMArgs& args, TVMRetValue* rv){
-    Expr e = args[0];
+    PrimExpr e = args[0];
     const CallNode* call = e.as<CallNode>();
     CHECK(call != nullptr);
 
index a0665bf..56ba225 100644 (file)
@@ -60,7 +60,7 @@ struct Direct {
 // Call pure extern function.
 template<typename T>
 inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
-  Expr e = args[0];
+  PrimExpr e = args[0];
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   std::string name = T()(call->dtype, call->name);
index 397f9d3..fb7abc3 100644 (file)
@@ -231,7 +231,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
 
   const auto *find_rocm_bitcodes =
       tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path");
-  Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
+  Array<PrimExpr> bitcode_files = (*find_rocm_bitcodes)();
 
   for (auto &bitcode : bitcode_files) {
     std::string path = bitcode.as<StringImmNode>()->value;
index fdc1b42..6879fd5 100644 (file)
@@ -42,7 +42,7 @@ class CodeGenARM final : public CodeGenCPU {
   llvm::Value* CreateIntrinsic(const CallNode* op) override;
 
  private:
-  Expr ARMPopcount(const CallNode* op);
+  PrimExpr ARMPopcount(const CallNode* op);
 };
 
 llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
@@ -50,16 +50,16 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
         op->args[0].as<UIntImmNode>()->value);
     if (id == ::llvm::Intrinsic::ctpop) {
-      Expr e = ARMPopcount(op);
+      PrimExpr e = ARMPopcount(op);
       return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
     }
   }
   return CodeGenCPU::CreateIntrinsic(op);
 }
 
-Expr CodeGenARM::ARMPopcount(const CallNode *call) {
+PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) {
   using namespace ir;
-  const Expr& e = call->args[2];
+  const PrimExpr& e = call->args[2];
   ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
   ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu;
 
@@ -67,7 +67,7 @@ Expr CodeGenARM::ARMPopcount(const CallNode *call) {
   int total_size =  call->dtype.bits() * call->dtype.lanes();
   if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
      (total_size != 128 && total_size != 64)) {
-    Array<Expr> vcnt_args;
+    Array<PrimExpr> vcnt_args;
     vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
     vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
     vcnt_args.push_back(e);
@@ -88,41 +88,41 @@ Expr CodeGenARM::ARMPopcount(const CallNode *call) {
       uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32);
 
   // Interpret input as vector of 8bit values
-  Expr input8 = reinterpret(uint8_type, e);
+  PrimExpr input8 = reinterpret(uint8_type, e);
   // Popcount 8bit->8bit
   const CallNode* c0 = input8.as<CallNode>();
   CHECK(c0 != nullptr);
-  Array<Expr> vcnt8_args;
+  Array<PrimExpr> vcnt8_args;
   vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
   vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt8_args.push_back(input8);
-  Expr vcnt8 = ir::CallNode::make(
+  PrimExpr vcnt8 = ir::CallNode::make(
     uint8_type,  "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
 
   // Accumulation 8->16bit
-  Array<Expr> vcnt16_args;
+  Array<PrimExpr> vcnt16_args;
   vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
   vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt16_args.push_back(vcnt8);
-  Expr vcnt16 = ir::CallNode::make(
+  PrimExpr vcnt16 = ir::CallNode::make(
     uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
   if (call->dtype.bits() == 16) {
     return vcnt16;
   }
 
   // Accumulation 16->32bit
-  Array<Expr> vcnt32_args;
+  Array<PrimExpr> vcnt32_args;
   vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
   vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt32_args.push_back(vcnt16);
-  Expr vcnt32 = ir::CallNode::make(
+  PrimExpr vcnt32 = ir::CallNode::make(
     uint32_type,  "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
   if (call->dtype.bits() == 32) {
     return vcnt32;
   }
 
   // Accumulation 32->64bit
-  Array<Expr> vcnt64_args;
+  Array<PrimExpr> vcnt64_args;
   vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
   vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt64_args.push_back(vcnt32);
index 0622269..39d8c7f 100644 (file)
@@ -669,7 +669,7 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
 }
 
 llvm::BasicBlock *
-CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
+CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
                            llvm::Value **ret_tcode, const DataType &r_type,
                            const int64_t begin, const int64_t end) {
   using llvm::BasicBlock;
@@ -923,8 +923,8 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
       CHECK(parallel_env_.num_task.defined());
       CHECK(parallel_env_.penv != nullptr);
       DataType t = op->extent.dtype();
-      Expr num_task = cast(t, parallel_env_.num_task);
-      Expr task_id = cast(t, parallel_env_.task_id);
+      PrimExpr num_task = cast(t, parallel_env_.num_task);
+      PrimExpr task_id = cast(t, parallel_env_.task_id);
       CHECK(!parallel_env_.in_parallel_loop)
           << "Nested parallel loop is not supported by threadpool, try fuse them instead";
       parallel_env_.in_parallel_loop = true;
@@ -935,9 +935,9 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
                         op->loop_var,
                         op->body);
       } else {
-        Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
-        Expr begin = MinNode::make(task_id * step, op->extent);
-        Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
+        PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
+        PrimExpr begin = MinNode::make(task_id * step, op->extent);
+        PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
         CreateSerialFor(MakeValue(begin),
                         MakeValue(end),
                         ConstInt32(1),
index 46f3f96..6ee0085 100644 (file)
@@ -77,8 +77,8 @@ class CodeGenCPU : public CodeGenLLVM {
  private:
   // the parallel group information
   struct ParallelEnv {
-    VarExpr task_id;
-    VarExpr num_task;
+    Var task_id;
+    Var num_task;
     bool stride_pattern{false};
     bool in_parallel_loop{false};
     int parallel_loop_count{0};
@@ -101,7 +101,7 @@ class CodeGenCPU : public CodeGenLLVM {
                          const Array<Var>& fields,
                          std::unordered_map<const VarNode*, llvm::Value*>* vmap);
   // Make packed call.
-  llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
+  llvm::BasicBlock *MakeCallPacked(const Array<PrimExpr> &args,
                                    llvm::Value **rvalue,
                                    llvm::Value **ret_tcode, const DataType &r_type,
                                    const int64_t begin, const int64_t end);
index e2ba19a..c04a023 100644 (file)
@@ -334,7 +334,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const {
 //
 void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
                                const VarNode* buffer,
-                               Expr index,
+                               PrimExpr index,
                                DataType type) {
   if (alias_var_set_.count(buffer) != 0) {
     // Mark all possibly aliased pointer as same type.
@@ -389,7 +389,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
 
 void CodeGenLLVM::GetAlignment(DataType t,
                                const VarNode* buf_var,
-                               const Expr& index,
+                               const PrimExpr& index,
                                int* p_alignment,
                                int* p_native_bits) {
   int max_align_bits = t.bits();
@@ -526,7 +526,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
 void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
                                   llvm::Value* end,
                                   llvm::Value* stride,
-                                  const VarExpr& loop_var,
+                                  const Var& loop_var,
                                   const Stmt& body) {
   using llvm::BasicBlock;
   BasicBlock* pre_block = builder_->GetInsertBlock();
@@ -711,7 +711,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
         addrspace = llvm::dyn_cast<llvm::PointerType>(
           ptr->getType())->getAddressSpace();
     } else {
-        Expr index = r->base / make_const(DataType::Int(32), r->lanes);
+        PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes);
         ptr = CreateBufferVecPtr(
           l->dtype, MakeValue(l->buffer_var), MakeValue(index));
         addrspace = llvm::dyn_cast<llvm::PointerType>(
@@ -776,11 +776,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   }
 }
 
-void CodeGenLLVM::Scalarize(const Expr& e,
+void CodeGenLLVM::Scalarize(const PrimExpr& e,
                             std::function<void(int i, llvm::Value* v)> f) {
   if (const RampNode* ramp = e.as<RampNode>()) {
     for (int i = 0; i < ramp->dtype.lanes(); ++i) {
-      Expr offset = ramp->base + (ramp->stride * i);
+      PrimExpr offset = ramp->base + (ramp->stride * i);
       f(i, MakeValue(offset));
     }
   } else {
@@ -988,7 +988,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
     llvm::LoadInst* load = builder_->CreateAlignedLoad(
         ptr, basic_align, is_volatile);
     ret = builder_->CreateInsertElement(ret, load, ConstInt32(i));
-    AddAliasInfo(load, op->buffer_var.get(), Expr(), t);
+    AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t);
   };
   this->Scalarize(op->index, f);
   return ret;
@@ -1084,7 +1084,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
     llvm::StoreInst* store = builder_->CreateAlignedStore(
         builder_->CreateExtractElement(value, i),
         ptr, basic_align, is_volatile);
-    AddAliasInfo(store, op->buffer_var.get(), Expr(), op->value.dtype());
+    AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype());
   };
   this->Scalarize(op->index, f);
 }
index 67ca7c1..34c3ee7 100644 (file)
@@ -47,7 +47,7 @@ using namespace ir;
  * \brief A base class to generate a LLVM.
  */
 class CodeGenLLVM :
-      public ExprFunctor<llvm::Value* (const Expr&)>,
+      public ExprFunctor<llvm::Value* (const PrimExpr&)>,
       public StmtFunctor<void(const Stmt&)> {
  public:
   /*!
@@ -95,7 +95,7 @@ class CodeGenLLVM :
    * \param e The expression to be created value for.
    * \return created value.
    */
-  llvm::Value* MakeValue(const Expr& e) {
+  llvm::Value* MakeValue(const PrimExpr& e) {
     return VisitExpr(e);
   }
   // Short hande code to get a constant int 32
@@ -184,7 +184,7 @@ class CodeGenLLVM :
   virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
   // Scalarize by iterating elements of e.
   // f is a callback that takes index and v.
-  virtual void Scalarize(const Expr& e,
+  virtual void Scalarize(const PrimExpr& e,
                          std::function<void(int i, llvm::Value* v)> f);
   // Initialize target
   virtual void InitTarget(llvm::TargetMachine* tm);
@@ -211,7 +211,7 @@ class CodeGenLLVM :
   void InitFuncState();
   // Get alignment given index.
   void GetAlignment(
-      DataType t, const VarNode* buf_var, const Expr& index,
+      DataType t, const VarNode* buf_var, const PrimExpr& index,
       int* p_alignment, int* p_native_bits);
   // Get constant string
   llvm::Value* GetConstString(const std::string& str);
@@ -243,9 +243,9 @@ class CodeGenLLVM :
   void CreateSerialFor(llvm::Value* begin,
                        llvm::Value* end,
                        llvm::Value* stride,
-                       const VarExpr& loop_var, const Stmt& body);
+                       const Var& loop_var, const Stmt& body);
   // add alias information.
-  void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type);
+  void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type);
   // The IRBuilder.
   using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
   // The current function
index 10774ec..b05185b 100644 (file)
@@ -63,21 +63,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
-  Expr e = targs[0];
+  PrimExpr e = targs[0];
   const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
-  const Expr& x = call->args[0];
-  Expr one = make_const(x.dtype(), 1);
-  Expr two = make_const(x.dtype(), 2);
-  Expr neg_two = make_const(x.dtype(), -2);
+  const PrimExpr& x = call->args[0];
+  PrimExpr one = make_const(x.dtype(), 1);
+  PrimExpr two = make_const(x.dtype(), 2);
+  PrimExpr neg_two = make_const(x.dtype(), -2);
 
-  Expr exp_neg2x = ir::CallNode::make(
+  PrimExpr exp_neg2x = ir::CallNode::make(
       x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic);
-  Expr exp_pos2x = ir::CallNode::make(
+  PrimExpr exp_pos2x = ir::CallNode::make(
       x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic);
 
-  Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
-  Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
+  PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
+  PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
   *rv = ir::SelectNode::make(
       x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
 });
index a870385..b3ab557 100644 (file)
@@ -38,15 +38,15 @@ namespace codegen {
 // num_signature means number of arguments used to query signature
 template<unsigned id, int num_signature>
 inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
-  Expr e = targs[0];
+  PrimExpr e = targs[0];
   const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
-  Array<Expr> cargs;
+  Array<PrimExpr> cargs;
   // intrin id.
   cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
   cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
 
-  for (Expr arg : call->args) {
+  for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
   *rv = ir::CallNode::make(
@@ -55,14 +55,14 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
 
 template<unsigned id, int num_signature>
 inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
-  Expr e = targs[0];
+  PrimExpr e = targs[0];
   const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
-  Array<Expr> cargs;
+  Array<PrimExpr> cargs;
   // intrin id.
   cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
   cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
-  for (Expr arg : call->args) {
+  for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
   *rv = ir::CallNode::make(
index 00824bb..fcd8a1a 100644 (file)
@@ -33,7 +33,7 @@ namespace tvm {
 namespace codegen {
 
 inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
-  Expr e = args[0];
+  PrimExpr e = args[0];
   using namespace ir;
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
index 09de88f..41035af 100644 (file)
@@ -33,7 +33,7 @@ namespace tvm {
 namespace codegen {
 
 inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
-  Expr e = args[0];
+  PrimExpr e = args[0];
   using namespace ir;
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
index 254e436..a749424 100644 (file)
@@ -90,7 +90,7 @@ void CodeGenSPIRV::InitFuncState() {
 }
 
 spirv::Value CodeGenSPIRV::GetThreadIndex(
-    const IterVar& iv, const Expr& extent) {
+    const IterVar& iv, const PrimExpr& extent) {
   runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
   spirv::Value v;
   if (ts.rank == 1) {
@@ -403,7 +403,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
           CHECK((me->coeff % ramp->lanes) == 0 &&
                 (me->base % ramp->lanes)  == 0)
               << "Only aligned vector access is allowed in SPIRV";
-          Expr vec_index = ir::Simplify(
+          PrimExpr vec_index = ir::Simplify(
               ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
           spirv::Value ptr = builder_->StructArrayAccess(
               ptr_type, buffer, MakeValue(vec_index));
@@ -417,11 +417,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
   return spirv::Value();
 }
 
-void CodeGenSPIRV::Scalarize(const Expr& e,
+void CodeGenSPIRV::Scalarize(const PrimExpr& e,
                              std::function<void(int i, spirv::Value v)> f) {
   if (const RampNode* ramp = e.as<RampNode>()) {
     for (int i = 0; i < ramp->dtype.lanes(); ++i) {
-      Expr offset = ramp->base + ramp->stride * i;
+      PrimExpr offset = ramp->base + ramp->stride * i;
       f(i, MakeValue(offset));
     }
   } else {
@@ -481,7 +481,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
           CHECK((me->coeff % ramp->lanes) == 0 &&
                 (me->base % ramp->lanes)  == 0)
               << "Only aligned vector access is allowed in SPIRV";
-          Expr vec_index = ir::Simplify(
+          PrimExpr vec_index = ir::Simplify(
               ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
           spirv::Value ptr = builder_->StructArrayAccess(
               ptr_type, buffer, MakeValue(vec_index));
index b72cd5b..3804bda 100644 (file)
@@ -44,7 +44,7 @@ using namespace ir;
  * \brief Code generator into SPIRV
  */
 class CodeGenSPIRV:
-      public ExprFunctor<spirv::Value(const Expr&)>,
+      public ExprFunctor<spirv::Value(const PrimExpr&)>,
       public StmtFunctor<void(const Stmt&)> {
  public:
   /*!
@@ -58,7 +58,7 @@ class CodeGenSPIRV:
    * \param e The expression to be created value for.
    * \return created value.
    */
-  spirv::Value MakeValue(const Expr& e) {
+  spirv::Value MakeValue(const PrimExpr& e) {
     return VisitExpr(e);
   }
   // override codegen
@@ -128,9 +128,9 @@ class CodeGenSPIRV:
   // Reset the state so it works for a new function.
   void InitFuncState();
   // Get the thread index
-  spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
+  spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);
   spirv::Value CreateStorageSync(const CallNode* op);
-  void Scalarize(const Expr& e,
+  void Scalarize(const PrimExpr& e,
                  std::function<void(int i, spirv::Value v)> f);
   // The builder
   std::unique_ptr<spirv::IRBuilder> builder_;
index 69d2014..d41d96d 100644 (file)
@@ -34,14 +34,14 @@ using namespace runtime;
 // num_signature means number of arguments used to query signature
 template<unsigned id>
 inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
-  Expr e = targs[0];
+  PrimExpr e = targs[0];
   const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
-  Array<Expr> cargs;
+  Array<PrimExpr> cargs;
   // intrin id.
   cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
 
-  for (Expr arg : call->args) {
+  for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
   *rv = ir::CallNode::make(
index 3da083b..eccff6c 100644 (file)
@@ -244,8 +244,8 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
 }
 
 void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
-                                const Expr& a,
-                                const Expr& b) {
+                                const PrimExpr& a,
+                                const PrimExpr& b) {
   this->Push(a);
   this->Push(b);
   DataType t = a.dtype();
index 1f00ecc..07989b2 100644 (file)
@@ -46,7 +46,7 @@ using runtime::StackVM;
  *  into device function when only device JIT is available.
  */
 class CodeGenStackVM
-    : public ExprFunctor<void(const Expr&)>,
+    : public ExprFunctor<void(const PrimExpr&)>,
       public StmtFunctor<void(const Stmt&)> {
  public:
  /*!
@@ -60,7 +60,7 @@ class CodeGenStackVM
   /*! \brief Push stmt to generate new code */
   void Push(const Stmt& n);
   /*! \brief Push expr to generate new code */
-  void Push(const Expr& n) {
+  void Push(const PrimExpr& n) {
     VisitExpr(n);
   }
   /*!
@@ -105,8 +105,8 @@ class CodeGenStackVM
   int GetVarID(const VarNode* v) const;
   // Push binary operator
   void PushBinary(StackVM::OpCode op_int64,
-                  const Expr& a,
-                  const Expr& b);
+                  const PrimExpr& a,
+                  const PrimExpr& b);
   // push cast;
   void PushCast(DataType dst, DataType src);
   // overloadable functions
index 01696b2..89a1ece 100644 (file)
@@ -46,7 +46,7 @@ using namespace ir;
  * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``.
  */
 class CodeGenHybrid :
-      public ExprFunctor<void(const Expr&, std::ostream&)>,
+      public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
       public StmtFunctor<void(const Stmt&)> {
  public:
   /*!
@@ -77,14 +77,14 @@ class CodeGenHybrid :
    * \param n The expression to be printed.
    * \param os The output stream
    */
-  void PrintExpr(const Expr &n, std::ostream &os) {
+  void PrintExpr(const PrimExpr &n, std::ostream &os) {
     this->VisitExpr(n, os);
   }
   /*!
    * \brief Same as PrintExpr, but simply returns result string
    * \param n The expression to be printed.
    */
-  std::string PrintExpr(const Expr &n) {
+  std::string PrintExpr(const PrimExpr &n) {
     std::ostringstream os;
     PrintExpr(n, os);
     return os.str();
index 6264e0f..1d3e767 100644 (file)
@@ -44,9 +44,9 @@ void DictAttrsNode::InitByPackedArgs(
     if (val.IsObjectRef<ObjectRef>()) {
       dict.Set(key, val.operator ObjectRef());
     } else if (val.type_code() == kStr) {
-      dict.Set(key, Expr(val.operator std::string()));
+      dict.Set(key, PrimExpr(val.operator std::string()));
     } else {
-      dict.Set(key, val.operator Expr());
+      dict.Set(key, val.operator PrimExpr());
     }
   }
 }
index d96033d..925e3db 100644 (file)
@@ -34,22 +34,22 @@ namespace tvm {
 using IndexMod = ir::FloorModNode;
 using IndexDiv = ir::FloorDivNode;
 
-Array<Expr> SimplifyArray(Array<Expr> array) {
+Array<PrimExpr> SimplifyArray(Array<PrimExpr> array) {
   for (size_t i = 0; i < array.size(); ++i) {
     array.Set(i, ir::Simplify(array[i]));
   }
   return array;
 }
 
-Buffer decl_buffer(Array<Expr> shape,
+Buffer decl_buffer(Array<PrimExpr> shape,
                    DataType dtype,
                    std::string name) {
   return BufferNode::make(
       Var(name, DataType::Handle()),
       dtype,
       shape,
-      Array<Expr>(),
-      Expr(),
+      Array<PrimExpr>(),
+      PrimExpr(),
       name,
       "",
       0, 0,
@@ -57,13 +57,13 @@ Buffer decl_buffer(Array<Expr> shape,
 }
 
 // Split the given expression w.r.t the add operator
-inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
+inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr &expr) {
   using namespace ir;
-  std::vector<const Expr*> ret;
-  std::stack<const Expr*> split_buffer;
+  std::vector<const PrimExpr*> ret;
+  std::stack<const PrimExpr*> split_buffer;
   split_buffer.push(&expr);
   while (!split_buffer.empty()) {
-    const Expr* top_ele = split_buffer.top();
+    const PrimExpr* top_ele = split_buffer.top();
     split_buffer.pop();
     auto expr_add_match = top_ele->as<AddNode>();
     if (expr_add_match) {
@@ -84,14 +84,14 @@ inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
 // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
 // Currently the we will not search the add/mult combinations exhaustively
 //   as it will take too much computation.
-inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
-                                              const Expr &mod_l_expr,
-                                              const Expr &mod_r_expr) {
+inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr &mult_expr,
+                                              const PrimExpr &mod_l_expr,
+                                              const PrimExpr &mod_r_expr) {
   using namespace ir;
   const MulNode* mult_ptr = mult_expr.as<MulNode>();
-  if (!mult_ptr) return std::make_pair(false, Expr());
-  Expr mult_outer = mult_ptr->b;
-  const Expr* inner = &(mult_ptr->a);
+  if (!mult_ptr) return std::make_pair(false, PrimExpr());
+  PrimExpr mult_outer = mult_ptr->b;
+  const PrimExpr* inner = &(mult_ptr->a);
   // 1. Calculate the outer multiplier
   while (true) {
     mult_ptr = inner->as<MulNode>();
@@ -108,32 +108,32 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
   //   If Mult is found, we will expand the inner multiplication factor
   //   If Div is found, we will go on testing whether lhs matches the lhs of mod expr
   //      and returns the optimization result.
-  const Expr* search_ptr = inner;
-  Expr mult_inner;  // The inner multiplication factor
-  Expr no_opt_sum;  // Sum of the exprs that cannot be optimized
+  const PrimExpr* search_ptr = inner;
+  PrimExpr mult_inner;  // The inner multiplication factor
+  PrimExpr no_opt_sum;  // Sum of the exprs that cannot be optimized
   while (true) {
     auto inner_div_ptr = search_ptr->as<IndexDiv>();
     auto inner_mult_ptr = search_ptr->as<MulNode>();
     auto inner_add_ptr = search_ptr->as<AddNode>();
     if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
-      return std::make_pair(false, Expr());
+      return std::make_pair(false, PrimExpr());
     } else if (inner_div_ptr) {
-      Expr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
+      PrimExpr overall_mult = mult_inner.get() ? mult_inner * mult_outer : mult_outer;
       if (Equal(overall_mult, inner_div_ptr->b)
           && Equal(overall_mult, mod_r_expr)
           && Equal(inner_div_ptr->a, mod_l_expr)) {
         // Found!
-        Expr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
+        PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
         return std::make_pair(true, ret);
       } else {
-        return std::make_pair(false, Expr());
+        return std::make_pair(false, PrimExpr());
       }
     } else if (inner_mult_ptr) {
       mult_inner = mult_inner.get() ? inner_mult_ptr->b * mult_inner : inner_mult_ptr->b;
       search_ptr = &(inner_mult_ptr->a);
     } else if (inner_add_ptr) {
       if (mult_inner.get()) {
-        return std::make_pair(false, Expr());
+        return std::make_pair(false, PrimExpr());
       }
       no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a;
       search_ptr = &(inner_add_ptr->b);
@@ -142,23 +142,23 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
       break;
     }
   }
-  return std::make_pair(false, Expr());
+  return std::make_pair(false, PrimExpr());
 }
 
 // Insert the elements into the corresponding mult_exprs and mod_exprs.
 // If the element is found to match Mul, it will be pushed to the mult_exprs.
 // If the element it found to match Mod, it will be pused to the mod_exprs.
 // Otherwise, the elements will be added to the no_opt_sum variable
-inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
-                                      std::list<Expr>* mult_exprs,
-                                      std::list<std::pair<Expr, Expr> >* mod_exprs,
-                                      Expr* no_opt_sum,
+inline void MergeMulModInsertElements(const std::vector<const PrimExpr*>& eles,
+                                      std::list<PrimExpr>* mult_exprs,
+                                      std::list<std::pair<PrimExpr, PrimExpr> >* mod_exprs,
+                                      PrimExpr* no_opt_sum,
                                       bool* has_mult,
                                       bool* has_mod) {
   using namespace ir;
   *has_mult = false;
   *has_mod = false;
-  for (const Expr* ele : eles) {
+  for (const PrimExpr* ele : eles) {
     auto mod_ptr = ele->as<IndexMod>();
     auto mult_ptr = ele->as<MulNode>();
     if (mod_ptr) {
@@ -180,30 +180,30 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
 // The search will be performed repeatively until no pattern is found.
 // Return: a pair with (false, Expr()) if cannot be optimized.
 //         a pair with (true, optimized_expr) if can be optimized
-inline Expr MergeMulMod(const Expr &base) {
+inline PrimExpr MergeMulMod(const PrimExpr &base) {
   using namespace ir;
   // 1. Prepare the lists.
   // We store two lists, a list that contain all the elements that match Mul and
   //                     a list that contain all the elements that match Mod.
   // The elements in the Mod will be used to match against the elements in Mul.
   // The result will then be split and pushed back to these two lists.
-  Expr simplified_base = Simplify(base);
-  std::vector<const Expr*> eles = ExprSplitAddition(simplified_base);
-  std::list<Expr> mult_exprs;
-  std::list<std::pair<Expr, Expr> > mod_exprs;
-  Expr no_opt_sum;
+  PrimExpr simplified_base = Simplify(base);
+  std::vector<const PrimExpr*> eles = ExprSplitAddition(simplified_base);
+  std::list<PrimExpr> mult_exprs;
+  std::list<std::pair<PrimExpr, PrimExpr> > mod_exprs;
+  PrimExpr no_opt_sum;
   bool has_mult;
   bool has_mod;
   MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs,
                             &no_opt_sum, &has_mult, &has_mod);
   bool find_opt = false;
-  std::list<std::pair<Expr, Expr> >::iterator search_mod_it = mod_exprs.begin();
+  std::list<std::pair<PrimExpr, PrimExpr> >::iterator search_mod_it = mod_exprs.begin();
   // 2. Exhaustive Search
   while (search_mod_it != mod_exprs.end()) {
-    std::list<Expr>::iterator mult_it = mult_exprs.begin();
+    std::list<PrimExpr>::iterator mult_it = mult_exprs.begin();
     bool inner_find_opt = false;
     while (mult_it != mult_exprs.end()) {
-      std::pair<bool, Expr> ret = MergeMulModInner(*mult_it,
+      std::pair<bool, PrimExpr> ret = MergeMulModInner(*mult_it,
                                                    search_mod_it->first,
                                                    search_mod_it->second);
       if (ret.first) {
@@ -212,7 +212,7 @@ inline Expr MergeMulMod(const Expr &base) {
         ++search_mod_it;
         mod_exprs.erase(temp_mod_it);
         mult_exprs.erase(mult_it);
-        std::vector<const Expr*> ret_eles = ExprSplitAddition(ret.second);
+        std::vector<const PrimExpr*> ret_eles = ExprSplitAddition(ret.second);
         MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs,
                                   &no_opt_sum, &has_mult, &has_mod);
         if (has_mult) {
@@ -233,10 +233,10 @@ inline Expr MergeMulMod(const Expr &base) {
   if (!find_opt) {
     return simplified_base;
   }
-  for (std::list<Expr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
+  for (std::list<PrimExpr>::iterator it = mult_exprs.begin(); it != mult_exprs.end(); ++it) {
     no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it;
   }
-  for (std::list<std::pair<Expr, Expr> >::iterator it = mod_exprs.begin();
+  for (std::list<std::pair<PrimExpr, PrimExpr> >::iterator it = mod_exprs.begin();
                                                    it != mod_exprs.end(); ++it) {
     no_opt_sum = no_opt_sum.get() ?
         no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second);
@@ -247,8 +247,8 @@ inline Expr MergeMulMod(const Expr &base) {
 // The buffer offset in convention of number of elements of
 // original data ignoring number of lanes.
 // We also perform optimization to simplify the indexing expression.
-inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
-  Expr base = n->elem_offset;
+inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
+  PrimExpr base = n->elem_offset;
   if (n->strides.size() == 0) {
     // Scalar case
     if (n->shape.size() == 0 && index.size() == 1) {
@@ -258,7 +258,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
     } else {
       CHECK_EQ(n->shape.size(), index.size());
       if (index.size() > 0) {
-        Expr offset = index[0];
+        PrimExpr offset = index[0];
         for (size_t i = 1; i < index.size(); ++i) {
           offset = MergeMulMod(offset * n->shape[i] + index[i]);
         }
@@ -279,8 +279,8 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   return base;
 }
 
-inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype) {
-  Expr offset = ElemOffset(n, index);
+inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
+  PrimExpr offset = ElemOffset(n, index);
   if (n->dtype.lanes() != 1) {
     offset = offset * make_const(offset.dtype(), dtype.lanes());
   }
@@ -291,7 +291,7 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype)
   }
 }
 
-Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
+PrimExpr Buffer::vload(Array<PrimExpr> begin, DataType dtype) const {
   // specially handle bool, stored asDataType::Int(8)
   const BufferNode* n = operator->();
   CHECK(dtype.element_of() == n->dtype.element_of() &&
@@ -311,7 +311,7 @@ Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
   }
 }
 
-Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
+Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
   // specially handle bool, stored asDataType::Int(8)
   const BufferNode* n = operator->();
   DataType dtype = value.dtype();
@@ -333,9 +333,9 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
 Buffer Buffer::MakeStrideView() const {
   if ((*this)->strides.size() != 0) return *this;
   if ((*this)->shape.size() == 0) return *this;
-  std::vector<Expr> temp;
+  std::vector<PrimExpr> temp;
   auto n = make_object<BufferNode>(*operator->());
-  Expr acc = make_const(n->DefaultIndexType(), 1);
+  PrimExpr acc = make_const(n->DefaultIndexType(), 1);
   for (size_t i = n->shape.size(); i != 0 ; --i) {
     temp.push_back(acc);
     acc = acc * n->shape[i - 1];
@@ -346,11 +346,11 @@ Buffer Buffer::MakeStrideView() const {
   return Buffer(n);
 }
 
-Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
+Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const {
   const BufferNode* n = operator->();
   begins = SimplifyArray(begins);
-  Expr elem_offset = ir::Simplify(ElemOffset(n, begins));
-  Array<Expr> strides = n->strides;
+  PrimExpr elem_offset = ir::Simplify(ElemOffset(n, begins));
+  Array<PrimExpr> strides = n->strides;
   if (strides.size() == 0) {
     bool can_relax = true;
     bool need_stride = false;
@@ -381,19 +381,22 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
                           n->buffer_type);
 }
 
-Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, Expr offset) const {
+PrimExpr Buffer::access_ptr(int access_mask,
+                            DataType ptr_type,
+                            int content_lanes,
+                            PrimExpr offset) const {
   const BufferNode* self = operator->();
-  Expr e_dtype;
-  Expr extent;
+  PrimExpr e_dtype;
+  PrimExpr extent;
   if (self->shape.size() == 0) {
     extent = make_const(self->DefaultIndexType(), 1);
   } else if (self->strides.size() == self->shape.size()) {
     int highest_dim = 0;
     extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
   } else {
-    extent = arith::ComputeReduce<ir::MulNode>(self->shape, Expr()) - offset;
+    extent = arith::ComputeReduce<ir::MulNode>(self->shape, PrimExpr()) - offset;
   }
-  Expr elem_offset = self->elem_offset + offset;
+  PrimExpr elem_offset = self->elem_offset + offset;
   if (content_lanes > 1) {
     e_dtype = ir::TypeAnnotation(self->dtype.with_lanes(content_lanes));
     extent = extent / make_const(self->elem_offset.dtype(), content_lanes);
@@ -402,7 +405,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
   } else {
     e_dtype = ir::TypeAnnotation(self->dtype);
   }
-  Array<Expr> acc_args{
+  Array<PrimExpr> acc_args{
     e_dtype, self->data, elem_offset,
         extent, make_const(DataType::Int(32), access_mask)};
   return ir::CallNode::make(
@@ -411,9 +414,9 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
 
 Buffer BufferNode::make(Var data,
                         DataType dtype,
-                        Array<Expr> shape,
-                        Array<Expr> strides,
-                        Expr elem_offset,
+                        Array<PrimExpr> shape,
+                        Array<PrimExpr> strides,
+                        PrimExpr elem_offset,
                         std::string name,
                         std::string scope,
                         int data_alignment,
index c30f344..ba5e4ad 100644 (file)
@@ -103,13 +103,13 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
                           << " before dimension " << c;
       std::string shape_name("_shape");
       shape_name.insert(0, 1, c);
-      IterVar axis = IterVarNode::make(Range(Expr(0), Var(shape_name)),
+      IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)),
                                        Var(std::string(1, c)), kDataPar);
       node->axes.push_back(axis);
     } else if (c >= 'a' && c <= 'z') {
       CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size "
                           << factor << " for dimension " << c;
-      IterVar axis = IterVarNode::make(Range(Expr(0), Expr(factor)),
+      IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)),
                                        Var(std::string(1, c)), kDataPar);
       node->axes.push_back(axis);
       factor = 0;
@@ -171,7 +171,7 @@ Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor)
   Array<IterVar> new_layout;
   for (size_t i = 0; i <= this->ndim(); ++i) {
     if (i == target_pos) {
-      new_layout.push_back(IterVarNode::make(Range(Expr(0), Expr(factor)),
+      new_layout.push_back(IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)),
                                              Var(axis.ToSubordinate().name()), kDataPar));
     }
     if (i == this->ndim()) break;
@@ -200,7 +200,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
     p->stream << "Layout(" << l->name << ")";
   });
 
-inline bool GetStoreRule(Array<Expr>* rule,
+inline bool GetStoreRule(Array<PrimExpr>* rule,
                          const Layout& src_layout,
                          const Layout& dst_layout) {
   if (!src_layout.defined() || src_layout.name().empty() ||
@@ -210,17 +210,17 @@ inline bool GetStoreRule(Array<Expr>* rule,
   for (size_t i = 0; i < dst_layout.ndim(); ++i) {
     const auto& store_axis = dst_layout[i];
     const IterVar& store_axis_impl = dst_layout->axes[i];
-    Expr store(0);
+    PrimExpr store(0);
 
     for (size_t j = 0; j < src_layout.ndim(); ++j) {
       const auto& orig_axis = src_layout[j];
       const IterVar& orig_axis_impl = src_layout->axes[j];
       if (store_axis.ToPrimal() == orig_axis.ToPrimal()) {
         if (orig_axis.IsPrimal()) {
-          Expr orig_var = orig_axis_impl->var;
+          PrimExpr orig_var = orig_axis_impl->var;
           const int32_t factor = src_layout.FactorOf(orig_axis);
           if (factor > 0) {
-            orig_var = orig_var * Expr(factor);
+            orig_var = orig_var * PrimExpr(factor);
           }
           store = store + orig_var;
         } else {
@@ -236,7 +236,7 @@ inline bool GetStoreRule(Array<Expr>* rule,
     if (store_axis.IsPrimal()) {
       const int32_t factor = dst_layout.FactorOf(store_axis);
       if (factor > 0) {
-        store = indexdiv(store, Expr(factor));
+        store = indexdiv(store, PrimExpr(factor));
       }
     } else {
       store = indexmod(store, store_axis_impl->dom->extent);
@@ -247,21 +247,21 @@ inline bool GetStoreRule(Array<Expr>* rule,
   return true;
 }
 
-inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
+inline Array<PrimExpr> TransformIndex(const Array<PrimExpr>& src_index,
                                   const Array<IterVar>& src_axis,
-                                  const Array<Expr>& transform_rule) {
-  Array<Expr> result;
-  std::unordered_map<const VarNode*, Expr> bind_map;
+                                  const Array<PrimExpr>& transform_rule) {
+  Array<PrimExpr> result;
+  std::unordered_map<const VarNode*, PrimExpr> bind_map;
   for (size_t i = 0; i < src_index.size(); ++i) {
     bind_map[src_axis[i]->var.get()] = src_index[i];
   }
-  for (Expr rule : transform_rule) {
+  for (PrimExpr rule : transform_rule) {
     result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
   }
   return result;
 }
 
-Array<Expr> BijectiveLayout::ForwardIndex(const Array<Expr>& src_index) const {
+Array<PrimExpr> BijectiveLayout::ForwardIndex(const Array<PrimExpr>& src_index) const {
   CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
   const BijectiveLayoutNode* self = operator->();
   CHECK_EQ(src_index.size(), self->src_layout->axes.size())
@@ -270,7 +270,7 @@ Array<Expr> BijectiveLayout::ForwardIndex(const Array<Expr>& src_index) const {
 }
 
 
-Array<Expr> BijectiveLayout::BackwardIndex(const Array<Expr>& dst_index) const {
+Array<PrimExpr> BijectiveLayout::BackwardIndex(const Array<PrimExpr>& dst_index) const {
   CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
   const BijectiveLayoutNode* self = operator->();
   CHECK_EQ(dst_index.size(), self->dst_layout->axes.size())
@@ -278,19 +278,19 @@ Array<Expr> BijectiveLayout::BackwardIndex(const Array<Expr>& dst_index) const {
   return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule);
 }
 
-inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
+inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
                                   const Array<IterVar>& src_axis,
                                   const Array<IterVar>& target_axis,
-                                  const Array<Expr>& transform_rule) {
+                                  const Array<PrimExpr>& transform_rule) {
   CHECK_EQ(src_shape.size(), src_axis.size());
   // bind variables for original axes
   // for major-axis, bind the corresponding size
   // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
   // e.g., (C * 16 + c) / 32
-  std::unordered_map<const VarNode*, Expr> bind_map;
+  std::unordered_map<const VarNode*, PrimExpr> bind_map;
   std::unordered_set<size_t> symbolic_var_set;
   for (size_t i = 0; i < src_shape.size(); ++i) {
-    Expr orig_shape = src_shape[i];
+    PrimExpr orig_shape = src_shape[i];
     IterVar orig_axis = src_axis[i];
     if (orig_shape.as<ir::AnyNode>()) {
       symbolic_var_set.insert(i);
@@ -305,7 +305,7 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
             << orig_axis->dom->extent << ", get " << orig_shape;
         }
       }
-      bind_map[orig_axis->var.get()] = Expr(0);
+      bind_map[orig_axis->var.get()] = PrimExpr(0);
     } else {
       bind_map[orig_axis->var.get()] = orig_shape;
     }
@@ -313,10 +313,10 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
   // infer the target shape,
   // for major-axis, use the forward/backward_rule directly,
   // for minor-axis, simply use the extent.
-  Array<Expr> result;
+  Array<PrimExpr> result;
   CHECK_EQ(transform_rule.size(), target_axis.size());
   for (size_t i = 0; i < transform_rule.size(); ++i) {
-    Expr rule = transform_rule[i];
+    PrimExpr rule = transform_rule[i];
     IterVar axis = target_axis[i];
     if (!LayoutAxis::Get(axis).IsPrimal()) {
       result.push_back(axis->dom->extent);
@@ -331,14 +331,14 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
   return result;
 }
 
-Array<Expr> BijectiveLayout::ForwardShape(const Array<Expr>& shape) const {
+Array<PrimExpr> BijectiveLayout::ForwardShape(const Array<PrimExpr>& shape) const {
   CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
   const BijectiveLayoutNode* self = operator->();
   return TransformShape(shape, self->src_layout->axes,
                         self->dst_layout->axes, self->forward_rule);
 }
 
-Array<Expr> BijectiveLayout::BackwardShape(const Array<Expr>& shape) const {
+Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) const {
   CHECK(defined()) << "Cannot operate on an undefined bijective layout.";
   const BijectiveLayoutNode* self = operator->();
   return TransformShape(shape, self->dst_layout->axes,
index 58a97ed..a728936 100644 (file)
 
 namespace tvm {
 
-Expr::Expr(int32_t value)
-    : Expr(IntImmNode::make(DataType::Int(32), value)) {}
+PrimExpr::PrimExpr(int32_t value)
+    : PrimExpr(IntImmNode::make(DataType::Int(32), value)) {}
 
-Expr::Expr(float value)
-    : Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
+PrimExpr::PrimExpr(float value)
+    : PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
 
-Expr::Expr(std::string str)
-    : Expr(ir::StringImmNode::make(str)) {}
+PrimExpr::PrimExpr(std::string str)
+    : PrimExpr(ir::StringImmNode::make(str)) {}
 
 Var::Var(std::string name_hint, DataType t)
     : Var(VarNode::make(t, name_hint)) {}
@@ -48,7 +48,7 @@ Var VarNode::make(DataType t, std::string name_hint) {
   return Var(node);
 }
 
-Range::Range(Expr begin, Expr end)
+Range::Range(PrimExpr begin, PrimExpr end)
     : Range(make_object<RangeNode>(
           begin,
           is_zero(begin) ? end : (end - begin))) {
@@ -63,7 +63,7 @@ Integer IntImmNode::make(DataType t, int64_t value) {
   return Integer(node);
 }
 
-Range Range::make_by_min_extent(Expr min, Expr extent) {
+Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
   return Range(make_object<RangeNode>(min, extent));
 }
 
index 34fac72..078ca62 100644 (file)
 namespace tvm {
 
 // simple cast that only checks if type matches and cast
-inline Expr SimpleCast(const DataType& t, Expr value) {
+inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
   if (value.dtype() == t) return value;
   return ir::CastNode::make(t, value);
 }
 
 // The public function with a quick checking path.
-void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) {  // NOLINT(*)
+void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
   if (lhs.dtype() == rhs.dtype()) return;
   DataType ltype = lhs.dtype();
   DataType rtype = rhs.dtype();
@@ -80,7 +80,7 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) {  // NOLINT(*)
 
 
 // maximum and min limits
-Expr max_value(const DataType& dtype) {
+PrimExpr max_value(const DataType& dtype) {
   using namespace ir;
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
@@ -109,10 +109,10 @@ Expr max_value(const DataType& dtype) {
     }
   }
   LOG(FATAL) << "Cannot decide max_value for type" << dtype;
-  return Expr();
+  return PrimExpr();
 }
 
-Expr min_value(const DataType& dtype) {
+PrimExpr min_value(const DataType& dtype) {
   using namespace ir;
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
@@ -135,7 +135,7 @@ Expr min_value(const DataType& dtype) {
     }
   }
   LOG(FATAL) << "Cannot decide min_value for type" << dtype;
-  return Expr();
+  return PrimExpr();
 }
 
 template<typename ValueType>
@@ -152,7 +152,7 @@ inline bool ConstPowerHelper(ValueType val, int *shift) {
   return true;
 }
 
-bool is_const_power_of_two_integer(const Expr& x, int* shift) {
+bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
   if (const auto* op = x.as<ir::IntImmNode>()) {
     return ConstPowerHelper(op->value, shift);
   } else if (const auto* op = x.as<ir::UIntImmNode>()) {
@@ -162,7 +162,7 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
   }
 }
 
-Expr cast(const DataType& t, Expr value) {
+PrimExpr cast(const DataType& t, PrimExpr value) {
   using ir::IntImmNode;
   using ir::UIntImmNode;
   using ir::FloatImmNode;
@@ -200,21 +200,21 @@ Expr cast(const DataType& t, Expr value) {
   }
 }
 
-Expr reinterpret(const DataType& t, Expr value) {
+PrimExpr reinterpret(const DataType& t, PrimExpr value) {
   if (value.dtype() == t) return value;
   return ir::CallNode::make(
     t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic);
 }
 
-Expr operator+(Expr a, Expr b) {
+PrimExpr operator+(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::AddNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::AddNode>(a, b);
   if (ret.defined()) return ret;
   return ir::AddNode::make(a, b);
 }
 
 // negation
-Expr operator-(Expr a) {
+PrimExpr operator-(PrimExpr a) {
   using ir::IntImmNode;
   using ir::FloatImmNode;
   const IntImmNode* pa = a.as<IntImmNode>();
@@ -224,76 +224,76 @@ Expr operator-(Expr a) {
   return make_zero(a.dtype()) - a;
 }
 
-Expr operator-(Expr a, Expr b) {
+PrimExpr operator-(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::SubNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::SubNode>(a, b);
   if (ret.defined()) return ret;
   return ir::SubNode::make(a, b);
 }
 
-Expr operator*(Expr a, Expr b) {
+PrimExpr operator*(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::MulNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::MulNode>(a, b);
   if (ret.defined()) return ret;
   return ir::MulNode::make(a, b);
 }
 
-Expr div(Expr a, Expr b) {
+PrimExpr div(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::DivNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::DivNode>(a, b);
   if (ret.defined()) return ret;
   return ir::DivNode::make(a, b);
 }
 
-Expr truncdiv(Expr a, Expr b) {
+PrimExpr truncdiv(PrimExpr a, PrimExpr b) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
   CHECK(b.dtype().is_int() || b.dtype().is_uint());
   return div(a, b);
 }
 
-Expr truncmod(Expr a, Expr b) {
+PrimExpr truncmod(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::ModNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::ModNode>(a, b);
   if (ret.defined()) return ret;
   return ir::ModNode::make(a, b);
 }
 
-Expr operator/(Expr a, Expr b) {
+PrimExpr operator/(PrimExpr a, PrimExpr b) {
   return div(a, b);
 }
 
-Expr operator%(Expr a, Expr b) {
+PrimExpr operator%(PrimExpr a, PrimExpr b) {
   return truncmod(a, b);
 }
 
 // TODO(tqchen): switch to floordiv
-Expr indexdiv(Expr a, Expr b) {
+PrimExpr indexdiv(PrimExpr a, PrimExpr b) {
   return floordiv(a, b);
 }
 
-Expr indexmod(Expr a, Expr b) {
+PrimExpr indexmod(PrimExpr a, PrimExpr b) {
   return floormod(a, b);
 }
 
-Expr floordiv(Expr a, Expr b) {
+PrimExpr floordiv(PrimExpr a, PrimExpr b) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
   CHECK(b.dtype().is_int() || b.dtype().is_uint());
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
   if (ret.defined()) return ret;
   return ir::FloorDivNode::make(a, b);
 }
 
-Expr floormod(Expr a, Expr b) {
+PrimExpr floormod(PrimExpr a, PrimExpr b) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
   CHECK(b.dtype().is_int() || b.dtype().is_uint());
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
   if (ret.defined()) return ret;
   return ir::FloorModNode::make(a, b);
 }
 
-Expr min(Expr a, Expr b) {
+PrimExpr min(PrimExpr a, PrimExpr b) {
   // inf-aware simplificaiton
   using arith::is_pos_inf;
   using arith::is_neg_inf;
@@ -302,12 +302,12 @@ Expr min(Expr a, Expr b) {
   if (is_pos_inf(b)) return a;
   if (is_neg_inf(b)) return b;
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::MinNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::MinNode>(a, b);
   if (ret.defined()) return ret;
   return ir::MinNode::make(a, b);
 }
 
-Expr max(Expr a, Expr b) {
+PrimExpr max(PrimExpr a, PrimExpr b) {
   // inf-aware simplificaiton
   using arith::is_pos_inf;
   using arith::is_neg_inf;
@@ -316,12 +316,12 @@ Expr max(Expr a, Expr b) {
   if (is_pos_inf(b)) return b;
   if (is_neg_inf(b)) return a;
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::MaxNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::MaxNode>(a, b);
   if (ret.defined()) return ret;
   return ir::MaxNode::make(a, b);
 }
 
-Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
+PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
   using ir::IntImmNode;
   using ir::UIntImmNode;
   CHECK(cond.dtype() == DataType::Bool(1))
@@ -347,7 +347,7 @@ Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
       ir::CallNode::PureIntrinsic);
 }
 
-Expr likely(Expr cond) {
+PrimExpr likely(PrimExpr cond) {
   if (is_const(cond)) return cond;
   return ir::CallNode::make(cond.dtype(),
                             ir::CallNode::likely,
@@ -355,72 +355,72 @@ Expr likely(Expr cond) {
                             ir::CallNode::PureIntrinsic);
 }
 
-Expr operator>(Expr a, Expr b) {
+PrimExpr operator>(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::GTNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::GTNode>(a, b);
   if (ret.defined()) return ret;
   return ir::GTNode::make(a, b);
 }
 
-Expr operator>=(Expr a, Expr b) {
+PrimExpr operator>=(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::GENode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::GENode>(a, b);
   if (ret.defined()) return ret;
   return ir::GENode::make(a, b);
 }
 
-Expr operator<(Expr a, Expr b) {
+PrimExpr operator<(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::LTNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::LTNode>(a, b);
   if (ret.defined()) return ret;
   return ir::LTNode::make(a, b);
 }
 
-Expr operator<=(Expr a, Expr b) {
+PrimExpr operator<=(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::LENode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::LENode>(a, b);
   if (ret.defined()) return ret;
   return ir::LENode::make(a, b);
 }
 
-Expr operator==(Expr a, Expr b) {
+PrimExpr operator==(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::EQNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::EQNode>(a, b);
   if (ret.defined()) return ret;
   return ir::EQNode::make(a, b);
 }
 
-Expr operator!=(Expr a, Expr b) {
+PrimExpr operator!=(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::NENode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::NENode>(a, b);
   if (ret.defined()) return ret;
   return ir::NENode::make(a, b);
 }
 
-Expr operator&&(Expr a, Expr b) {
+PrimExpr operator&&(PrimExpr a, PrimExpr b) {
   CHECK(a.dtype().is_bool());
   CHECK(b.dtype().is_bool());
-  Expr ret = arith::TryConstFold<ir::AndNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::AndNode>(a, b);
   if (ret.defined()) return ret;
   return ir::AndNode::make(a, b);
 }
 
-Expr operator||(Expr a, Expr b) {
+PrimExpr operator||(PrimExpr a, PrimExpr b) {
   CHECK(a.dtype().is_bool());
   CHECK(b.dtype().is_bool());
-  Expr ret = arith::TryConstFold<ir::OrNode>(a, b);
+  PrimExpr ret = arith::TryConstFold<ir::OrNode>(a, b);
   if (ret.defined()) return ret;
   return ir::OrNode::make(a, b);
 }
 
-Expr operator!(Expr a) {
+PrimExpr operator!(PrimExpr a) {
   CHECK(a.dtype().is_bool());
-  Expr ret = arith::TryConstFold<ir::NotNode>(a);
+  PrimExpr ret = arith::TryConstFold<ir::NotNode>(a);
   if (ret.defined()) return ret;
   return ir::NotNode::make(a);
 }
 
-Expr operator>>(Expr a, Expr b) {
+PrimExpr operator>>(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
@@ -433,7 +433,7 @@ Expr operator>>(Expr a, Expr b) {
     a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
-Expr operator<<(Expr a, Expr b) {
+PrimExpr operator<<(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
@@ -446,7 +446,7 @@ Expr operator<<(Expr a, Expr b) {
     a.dtype(), ir::CallNode::shift_left, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
-Expr operator&(Expr a, Expr b) {
+PrimExpr operator&(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
@@ -456,7 +456,7 @@ Expr operator&(Expr a, Expr b) {
     a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
-Expr operator|(Expr a, Expr b) {
+PrimExpr operator|(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
@@ -466,7 +466,7 @@ Expr operator|(Expr a, Expr b) {
     a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
-Expr operator^(Expr a, Expr b) {
+PrimExpr operator^(PrimExpr a, PrimExpr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
@@ -476,20 +476,20 @@ Expr operator^(Expr a, Expr b) {
     a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
-Expr operator~(Expr a) {
+PrimExpr operator~(PrimExpr a) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
   return ir::CallNode::make(
     a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic);
 }
 
-Expr pow(Expr x, Expr y) {
+PrimExpr pow(PrimExpr x, PrimExpr y) {
   BinaryOpMatchTypes(x, y);
   CHECK(x.dtype().is_float()) << "power only applies to float";
   return ir::CallNode::make(
     x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic);
 }
 
-Expr abs(Expr x) {
+PrimExpr abs(PrimExpr x) {
   if (x.dtype().is_int()) {
     using ir::IntImmNode;
     const IntImmNode* px = x.as<IntImmNode>();
@@ -513,7 +513,7 @@ Expr abs(Expr x) {
   }
 }
 
-Expr isnan(Expr x) {
+PrimExpr isnan(PrimExpr x) {
   DataType t = DataType::Bool(x.dtype().lanes());
   if (x.dtype().is_int() || x.dtype().is_uint()) {
     return make_const(t, false);
@@ -537,97 +537,97 @@ Expr isnan(Expr x) {
   }
 }
 
-Expr sum(Expr source, Array<IterVar> rdom) {
+PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::AddNode::make(x, y);
-  Expr identity_element = make_zero(source.dtype());
+  PrimExpr result = ir::AddNode::make(x, y);
+  PrimExpr identity_element = make_zero(source.dtype());
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
   return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
-Expr all(Expr source, Array<IterVar> rdom) {
+PrimExpr all(PrimExpr source, Array<IterVar> rdom) {
   CHECK(source.dtype().is_bool());
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::AndNode::make(x, y);
-  Expr identity_element = make_const(source.dtype(), true);
+  PrimExpr result = ir::AndNode::make(x, y);
+  PrimExpr identity_element = make_const(source.dtype(), true);
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
   return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
-Expr any(Expr source, Array<IterVar> rdom) {
+PrimExpr any(PrimExpr source, Array<IterVar> rdom) {
   CHECK(source.dtype().is_bool());
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::OrNode::make(x, y);
-  Expr identity_element = make_const(source.dtype(), false);
+  PrimExpr result = ir::OrNode::make(x, y);
+  PrimExpr identity_element = make_const(source.dtype(), false);
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
   return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
-Expr max(Expr source, Array<IterVar> rdom) {
+PrimExpr max(PrimExpr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::MaxNode::make(x, y);
-  Expr identity_element = min_value(source.dtype());
+  PrimExpr result = ir::MaxNode::make(x, y);
+  PrimExpr identity_element = min_value(source.dtype());
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
   return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
-Expr min(Expr source, Array<IterVar> rdom) {
+PrimExpr min(PrimExpr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::MinNode::make(x, y);
-  Expr identity_element = max_value(source.dtype());
+  PrimExpr result = ir::MinNode::make(x, y);
+  PrimExpr identity_element = max_value(source.dtype());
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
   return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
-Expr prod(Expr source, Array<IterVar> rdom) {
+PrimExpr prod(PrimExpr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::MulNode::make(x, y);
-  Expr identity_element = make_const(source.dtype(), 1);
+  PrimExpr result = ir::MulNode::make(x, y);
+  PrimExpr identity_element = make_const(source.dtype(), 1);
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
   return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
-Expr fmod(Expr x, Expr y) {
+PrimExpr fmod(PrimExpr x, PrimExpr y) {
   BinaryOpMatchTypes(x, y);
   CHECK(x.dtype().is_float()) << "fmod only applies to float";
   return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic);
 }
 
-Expr floor(Expr x) {
+PrimExpr floor(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value));
   return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
 }
 
-Expr ceil(Expr x) {
+PrimExpr ceil(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value));
   return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
 }
 
-Expr round(Expr x) {
+PrimExpr round(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
   return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
 }
 
-Expr nearbyint(Expr x) {
+PrimExpr nearbyint(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
   return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
 }
 
-Expr trunc(Expr x) {
+PrimExpr trunc(PrimExpr x) {
   using ir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) {
index 6b777cc..ad7f260 100644 (file)
@@ -31,41 +31,41 @@ namespace tvm {
 namespace ir {
 
 // constructors
-Expr UIntImmNode::make(DataType t, uint64_t value) {
+PrimExpr UIntImmNode::make(DataType t, uint64_t value) {
   CHECK(t.is_uint() && t.lanes() == 1)
       << "ValueError: UIntImm can only take scalar";
   ObjectPtr<UIntImmNode> node = make_object<UIntImmNode>();
   node->dtype = t;
   node->value = value;
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr FloatImmNode::make(DataType t, double value) {
+PrimExpr FloatImmNode::make(DataType t, double value) {
   CHECK_EQ(t.lanes(), 1)
       << "ValueError: FloatImm can only take scalar";
   ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
   node->dtype = t;
   node->value = value;
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr StringImmNode::make(std::string value) {
+PrimExpr StringImmNode::make(std::string value) {
   ObjectPtr<StringImmNode> node = make_object<StringImmNode>();
   node->dtype = DataType::Handle();
   node->value = std::move(value);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr CastNode::make(DataType t, Expr value) {
+PrimExpr CastNode::make(DataType t, PrimExpr value) {
   CHECK(value.defined());
   CHECK_EQ(t.lanes(), value.dtype().lanes());
   ObjectPtr<CastNode> node = make_object<CastNode>();
   node->dtype = t;
   node->value = std::move(value);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr AndNode::make(Expr a, Expr b) {
+PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(b.defined()) << "ValueError: b is undefined";
   CHECK(a.dtype().is_bool());
@@ -76,10 +76,10 @@ Expr AndNode::make(Expr a, Expr b) {
   node->dtype = DataType::Bool(a.dtype().lanes());
   node->a = std::move(a);
   node->b = std::move(b);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr OrNode::make(Expr a, Expr b) {
+PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(b.defined()) << "ValueError: b is undefined";
   CHECK(a.dtype().is_bool());
@@ -90,20 +90,20 @@ Expr OrNode::make(Expr a, Expr b) {
   node->dtype = DataType::Bool(a.dtype().lanes());
   node->a = std::move(a);
   node->b = std::move(b);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr NotNode::make(Expr a) {
+PrimExpr NotNode::make(PrimExpr a) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(a.dtype().is_bool());
 
   ObjectPtr<NotNode> node = make_object<NotNode>();
   node->dtype = DataType::Bool(a.dtype().lanes());
   node->a = std::move(a);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) {
+PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
   CHECK(condition.defined()) << "ValueError: condition is undefined";
   CHECK(true_value.defined()) << "ValueError: true_value is undefined";
   CHECK(false_value.defined()) << "ValueError: true_value is undefined";
@@ -116,10 +116,10 @@ Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) {
   node->condition = std::move(condition);
   node->true_value = std::move(true_value);
   node->false_value = std::move(false_value);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
+PrimExpr LoadNode::make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) {
   CHECK(buffer_var.defined());
   CHECK(predicate.defined());
   CHECK(index.defined());
@@ -132,10 +132,10 @@ Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate)
   node->index = std::move(index);
   node->predicate = std::move(predicate);
 
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr RampNode::make(Expr base, Expr stride, int lanes) {
+PrimExpr RampNode::make(PrimExpr base, PrimExpr stride, int lanes) {
   CHECK(base.defined());
   CHECK(stride.defined());
   CHECK(base.dtype().is_scalar());
@@ -148,10 +148,10 @@ Expr RampNode::make(Expr base, Expr stride, int lanes) {
   node->base = base;
   node->stride = stride;
   node->lanes = lanes;
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr BroadcastNode::make(Expr value, int lanes) {
+PrimExpr BroadcastNode::make(PrimExpr value, int lanes) {
   CHECK(value.defined());
   CHECK(value.dtype().is_scalar());
   CHECK_GT(lanes, 1);
@@ -160,10 +160,10 @@ Expr BroadcastNode::make(Expr value, int lanes) {
   node->dtype = value.dtype().with_lanes(lanes);
   node->value = std::move(value);
   node->lanes = lanes;
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr LetNode::make(Var var, Expr value, Expr body) {
+PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) {
   CHECK(value.defined());
   CHECK(body.defined());
   CHECK_EQ(value.dtype(), var.dtype());
@@ -173,7 +173,7 @@ Expr LetNode::make(Var var, Expr value, Expr body) {
   node->var = std::move(var);
   node->value = std::move(value);
   node->body = std::move(body);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
 const char* CallNode::vectorizable_intrinsics[] = {
@@ -192,9 +192,9 @@ bool CallNode::is_vectorizable() const {
   return false;
 }
 
-Expr CallNode::make(DataType dtype,
+PrimExpr CallNode::make(DataType dtype,
                 std::string name,
-                Array<Expr> args,
+                Array<PrimExpr> args,
                 CallType call_type,
                 FunctionRef func,
                 int value_index) {
@@ -215,18 +215,18 @@ Expr CallNode::make(DataType dtype,
   node->call_type = call_type;
   node->func = std::move(func);
   node->value_index = value_index;
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr ShuffleNode::make(Array<Expr> vectors,
-                   Array<Expr> indices) {
+PrimExpr ShuffleNode::make(Array<PrimExpr> vectors,
+                   Array<PrimExpr> indices) {
   CHECK_NE(vectors.size(), 0U);
   CHECK_NE(indices.size(), 0U);
 
   DataType base_type = vectors[0].dtype().element_of();
   int total_lanes = 0;
 
-  for (Expr val : vectors) {
+  for (PrimExpr val : vectors) {
     CHECK(val.dtype().element_of() == base_type);
     total_lanes += val.dtype().lanes();
   }
@@ -236,17 +236,17 @@ Expr ShuffleNode::make(Array<Expr> vectors,
   node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
   node->vectors = std::move(vectors);
   node->indices = std::move(indices);
-  return Expr(node);
+  return PrimExpr(node);
 }
 
-Expr ShuffleNode::make_concat(Array<Expr> vectors) {
+PrimExpr ShuffleNode::make_concat(Array<PrimExpr> vectors) {
   CHECK_NE(vectors.size(), 0);
   if (vectors.size() == 1) {
     return vectors[0];
   }
-  Array<Expr> indices;
+  Array<PrimExpr> indices;
   int index = 0;
-  for (const Expr& e : vectors) {
+  for (const PrimExpr& e : vectors) {
     for (int i = 0; i < e.dtype().lanes(); ++i) {
       indices.push_back(IntImmNode::make(DataType::Int(32), index++));
     }
@@ -254,14 +254,14 @@ Expr ShuffleNode::make_concat(Array<Expr> vectors) {
   return make(vectors, indices);
 }
 
-Expr ShuffleNode::make_extract_element(Expr vector, int index) {
+PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) {
   return make({vector}, {Integer(index)});
 }
 
 CommReducer CommReducerNode::make(Array<Var> lhs,
                                   Array<Var> rhs,
-                                  Array<Expr> result,
-                                  Array<Expr> identity_element) {
+                                  Array<PrimExpr> result,
+                                  Array<PrimExpr> identity_element) {
   auto node = make_object<CommReducerNode>();
   node->lhs = lhs;
   node->rhs = rhs;
@@ -270,22 +270,22 @@ CommReducer CommReducerNode::make(Array<Var> lhs,
   return CommReducer(node);
 }
 
-Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
+Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b) const {
   CHECK_EQ(a.size(), b.size());
   CHECK_EQ(lhs.size(), a.size());
   CHECK_EQ(rhs.size(), b.size());
-  Map<Var, Expr> value_map;
+  Map<Var, PrimExpr> value_map;
   for (size_t i = 0; i < a.size(); ++i) {
     value_map.Set(lhs[i], a[i]);
     value_map.Set(rhs[i], b[i]);
   }
-  return UpdateArray(result, [&value_map] (const Expr& e) {
+  return UpdateArray(result, [&value_map] (const PrimExpr& e) {
       return Substitute(e, value_map);
     });
 }
 
-Expr ReduceNode::make(CommReducer combiner, Array<Expr> source,
-                  Array<IterVar> axis, Expr condition, int value_index) {
+PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
+                  Array<IterVar> axis, PrimExpr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK_EQ(axis[i]->iter_type, kCommReduce)
         << "Can only take axis created by reduce_axis";
@@ -304,15 +304,15 @@ Expr ReduceNode::make(CommReducer combiner, Array<Expr> source,
   n->axis = std::move(axis);
   n->condition = condition;
   n->value_index = value_index;
-  return Expr(n);
+  return PrimExpr(n);
 }
 
-Expr AnyNode::make() {
+PrimExpr AnyNode::make() {
   auto n = make_object<AnyNode>();
-  return Expr(n);
+  return PrimExpr(n);
 }
 
-Stmt LetStmtNode::make(Var var, Expr value, Stmt body) {
+Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
   CHECK(value.defined());
   CHECK(body.defined());
   CHECK_EQ(value.dtype(), var.dtype());
@@ -326,7 +326,7 @@ Stmt LetStmtNode::make(Var var, Expr value, Stmt body) {
 
 Stmt AttrStmtNode::make(ObjectRef node,
                     std::string attr_key,
-                    Expr value,
+                    PrimExpr value,
                     Stmt body) {
   auto n = make_object<AttrStmtNode>();
   n->node = node;
@@ -336,7 +336,7 @@ Stmt AttrStmtNode::make(ObjectRef node,
   return Stmt(n);
 }
 
-Stmt AssertStmtNode::make(Expr condition, Expr message, Stmt body) {
+Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) ||
         message.as<StringImmNode>())
@@ -361,8 +361,8 @@ Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
 }
 
 Stmt ForNode::make(Var loop_var,
-               Expr min,
-               Expr extent,
+               PrimExpr min,
+               PrimExpr extent,
                ForType for_type,
                DeviceAPI device_api,
                Stmt body) {
@@ -383,7 +383,7 @@ Stmt ForNode::make(Var loop_var,
   return Stmt(node);
 }
 
-Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
+Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
   CHECK(value.defined());
   CHECK(index.defined());
   CHECK(predicate.defined());
@@ -398,7 +398,7 @@ Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
   return Stmt(node);
 }
 
-Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
+Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
   CHECK(value_index >=0 && value_index < func->num_outputs())
       << "value index output function return value bound";
   CHECK(value.defined()) << "Provide of undefined value\n";
@@ -417,10 +417,10 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array<Expr
 
 Stmt AllocateNode::make(Var buffer_var,
                     DataType dtype,
-                    Array<Expr> extents,
-                    Expr condition,
+                    Array<PrimExpr> extents,
+                    PrimExpr condition,
                     Stmt body,
-                    Expr new_expr,
+                    PrimExpr new_expr,
                     std::string free_function) {
     for (size_t i = 0; i < extents.size(); ++i) {
       CHECK(extents[i].defined());
@@ -441,7 +441,7 @@ Stmt AllocateNode::make(Var buffer_var,
     return Stmt(node);
 }
 
-int32_t AllocateNode::constant_allocation_size(const Array<Expr>& extents) {
+int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
   int64_t result = 1;
   for (size_t i = 0; i < extents.size(); ++i) {
     if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
@@ -466,7 +466,7 @@ Stmt RealizeNode::make(FunctionRef func,
                    int value_index,
                    DataType dtype,
                    Region bounds,
-                   Expr condition,
+                   PrimExpr condition,
                    Stmt body) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
@@ -510,7 +510,7 @@ SeqStmt::SeqStmt(Array<Stmt> seq) {
   data_ = std::move(node);
 }
 
-Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) {
+Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   CHECK(condition.defined());
   CHECK(then_case.defined());
   // else_case may be null.
@@ -522,7 +522,7 @@ Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) {
   return Stmt(node);
 }
 
-Stmt EvaluateNode::make(Expr value) {
+Stmt EvaluateNode::make(PrimExpr value) {
   CHECK(value.defined());
 
   ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
index f797700..35b4029 100644 (file)
 namespace tvm {
 
 // Tensor
-Expr Tensor::operator()(Array<Var> indices) const {
-  Array<Expr> arr(indices.begin(), indices.end());
+PrimExpr Tensor::operator()(Array<Var> indices) const {
+  Array<PrimExpr> arr(indices.begin(), indices.end());
   return operator()(arr);
 }
 
-Expr Tensor::operator()(Array<Expr> indices) const {
+PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
   using ir::CallNode;
   if (ndim() != 0) {
     CHECK_EQ(ndim(), indices.size())
@@ -55,7 +55,7 @@ Tensor Operation::output(size_t i) const {
   return Tensor(node);
 }
 
-Tensor TensorNode::make(Array<Expr> shape,
+Tensor TensorNode::make(Array<PrimExpr> shape,
                         DataType dtype,
                         Operation op,
                         int value_index) {
@@ -114,7 +114,7 @@ TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
                                             Array<Tensor> tensors,
                                             Array<Region> regions,
                                             Array<IterVar> reduce_axis,
-                                            Array<Expr> scalar_inputs) {
+                                            Array<PrimExpr> scalar_inputs) {
   auto n = make_object<TensorIntrinCallNode>();
   n->intrin = std::move(intrin);
   n->tensors = std::move(tensors);
index 0ad68b1..7c8427a 100644 (file)
@@ -75,10 +75,10 @@ DataType ComputeOpNode::output_dtype(size_t idx) const {
   return body[idx].dtype();
 }
 
-Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
+Array<PrimExpr> BaseComputeOpNode::output_shape(size_t idx) const {
   CHECK_LT(idx, num_outputs());
   // for now, all outputs of a BaseComputeOp have the same shape
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   for (const auto& ivar : this->axis) {
     const Range& r = ivar->dom;
     shape.push_back(r->extent);
@@ -86,7 +86,7 @@ Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
   return shape;
 }
 
-Tensor compute(Array<Expr> shape,
+Tensor compute(Array<PrimExpr> shape,
                FCompute fcompute,
                std::string name,
                std::string tag,
@@ -108,7 +108,7 @@ Tensor compute(Array<Expr> shape,
       name, tag, attrs, axis, {fcompute(args)}).output(0);
 }
 
-Array<Tensor> compute(Array<Expr> shape,
+Array<Tensor> compute(Array<PrimExpr> shape,
                       FBatchCompute fcompute,
                       std::string name,
                       std::string tag,
@@ -138,7 +138,7 @@ Operation ComputeOpNode::make(std::string name,
                               std::string tag,
                               Map<std::string, ObjectRef> attrs,
                               Array<IterVar> axis,
-                              Array<Expr> body) {
+                              Array<PrimExpr> body) {
   if (!attrs.defined()) {
     attrs = Map<std::string, ObjectRef>();
   }
@@ -180,24 +180,24 @@ Operation ComputeOpNode::ReplaceInputs(
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
   VerifyComputeOp(this);
-  Array<Expr> arr;
+  Array<PrimExpr> arr;
   if (this->body[0]->IsInstance<ir::ReduceNode>()) {
     // Specially handle reduce so the replaced op
     // still share all the components
-    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
+    PrimExpr new_reduce = op::ReplaceTensor(this->body[0], rmap);
     if (!new_reduce.same_as(this->body[0])) {
       const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
       for (size_t k = 0; k < this->body.size(); ++k) {
         auto n = make_object<ir::ReduceNode>(*r);
         n->value_index = static_cast<int>(k);
         n->dtype = r->source[k].dtype();
-        arr.push_back(Expr(n));
+        arr.push_back(PrimExpr(n));
       }
     } else {
       arr = this->body;
     }
   } else {
-    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
+    arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) {
         return op::ReplaceTensor(e, rmap);
       });
   }
@@ -229,10 +229,10 @@ void ComputeOpNode::PropBoundToInputs(
           IntSet arg_intset = EvalSet(call->args[i], dom_map);
           const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
           if (arg_interval) {
-            Expr shape_i_min_value = make_zero(t->shape[i].dtype());
-            Expr shape_i_max_value = t->shape[i] - 1;
-            Expr min_value = arg_interval->min_value;
-            Expr max_value = arg_interval->max_value;
+            PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
+            PrimExpr shape_i_max_value = t->shape[i] - 1;
+            PrimExpr min_value = arg_interval->min_value;
+            PrimExpr max_value = arg_interval->max_value;
             // Prefer the shape bounds only when we can prove they are tighter.
             if (arith::is_neg_inf(min_value) ||
                 analyzer->CanProve(shape_i_min_value >= min_value)) {
@@ -290,7 +290,7 @@ Stmt BaseComputeOpNode::BuildRealize(
       if (it != stage->iter_var_attrs.end()) {
         IterVarAttr attr = (*it).second;
         if (attr->dim_align_factor != 0) {
-          Array<Expr> tuple = {static_cast<int>(i),
+          Array<PrimExpr> tuple = {static_cast<int>(i),
                                attr->dim_align_factor,
                                attr->dim_align_offset};
           realize = ir::AttrStmtNode::make(
@@ -315,7 +315,7 @@ void MakeReduction(const ComputeOpNode* op,
                    const Array<Tensor>& tensors,
                    Stmt* init,
                    Stmt* provide) {
-  Array<Expr>  args;
+  Array<PrimExpr>  args;
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
@@ -326,12 +326,12 @@ void MakeReduction(const ComputeOpNode* op,
   CHECK(reduce);
   const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
   CHECK(combiner);
-  Array<Expr> lhs;
+  Array<PrimExpr> lhs;
   for (size_t i = 0; i < size; ++i) {
     lhs.push_back(tensors[i](args));
   }
-  Array<Expr> init_value = combiner->identity_element;
-  Array<Expr> update_value = (*combiner)(lhs, reduce->source);
+  Array<PrimExpr> init_value = combiner->identity_element;
+  Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
   for (size_t i = 0; i < size; ++i) {
     Tensor t = tensors[i];
     inits.emplace_back(ProvideNode::make(
@@ -349,7 +349,7 @@ void MakeReduction(const ComputeOpNode* op,
 // Normal computation.
 Stmt MakeProvide(const ComputeOpNode* op,
                  const Tensor& t) {
-  Array<Expr> args;
+  Array<PrimExpr> args;
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
@@ -555,7 +555,7 @@ class ComputeVerifier final : protected ir::ExprVisitor {
 
   /// Interface to perform compute verification
   void Run() {
-    for (const Expr e : compute_->body) {
+    for (const PrimExpr e : compute_->body) {
       // Check for consistency of top level reductions
       const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
       CHECK((reduce && reduce_) || (!reduce && !reduce_))
@@ -576,7 +576,7 @@ class ComputeVerifier final : protected ir::ExprVisitor {
  protected:
   /// Visitor implementation
   //@{
-  void VisitExpr(const Expr& n) final {
+  void VisitExpr(const PrimExpr& n) final {
     ++level_;
     ExprVisitor::VisitExpr(n);
     --level_;
@@ -608,7 +608,7 @@ Stmt TransformUpdate(const Stage& stage,
                      const ComputeLoopNest& n,
                      Stmt body,
                      Stmt update) {
-  Array<Expr> conds;
+  Array<PrimExpr> conds;
   std::unordered_set<const VarNode*> banned;
   for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
     IterVar iv = stage->leaf_iter_vars[i];
@@ -627,7 +627,7 @@ Stmt TransformUpdate(const Stage& stage,
       banned.insert(iv->var.get());
     }
   }
-  for (const Expr& pred : n.main_predicates) {
+  for (const PrimExpr& pred : n.main_predicates) {
     if (ir::ExprUseVar(pred, banned)) {
       LOG(FATAL) << "Tensorize update transform failed, the condition "
                  << pred << " has a conflict with the reset condition";
index f5735d8..3fe98e8 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -38,17 +38,17 @@ struct ComputeLoopNest {
   // The common number of loops between init and main
   size_t num_common_loop;
   // predicates for the initialize loop
-  std::vector<Expr> init_predicates;
+  std::vector<PrimExpr> init_predicates;
   // Initialization nest involved.
   std::vector<std::vector<Stmt> > init_nest;
   // Value map for the init code
-  std::unordered_map<IterVar, Expr> init_vmap;
+  std::unordered_map<IterVar, PrimExpr> init_vmap;
   // Predicates for the main update loop
-  std::vector<Expr> main_predicates;
+  std::vector<PrimExpr> main_predicates;
   // The general loop nest
   std::vector<std::vector<Stmt> > main_nest;
   // Value map for the IterVar.
-  std::unordered_map<IterVar, Expr> main_vmap;
+  std::unordered_map<IterVar, PrimExpr> main_vmap;
 
   /*!
    * \brief constructor to build ComputeOpNest
index 89d0ca7..9de4bde 100644 (file)
@@ -33,11 +33,11 @@ Stmt MakeCrossThreadReduction(
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map,
     bool debug_keep_trivial_loop) {
-  Array<Expr>  args;
+  Array<PrimExpr>  args;
   for (IterVar iv : self->axis) {
     args.push_back(iv->var);
   }
-  std::unordered_map<IterVar, Expr> value_map;
+  std::unordered_map<IterVar, PrimExpr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, debug_keep_trivial_loop);
   auto conds = schedule::MakeBoundCheck(
@@ -52,11 +52,11 @@ Stmt MakeCrossThreadReduction(
     CHECK(reduce);
     reduces[i] = reduce;
   }
-  Expr cond = reduces[0]->condition;
-  for (Expr v : conds) {
+  PrimExpr cond = reduces[0]->condition;
+  for (PrimExpr v : conds) {
     cond = cond && v;
   }
-  Array<Expr> freduce_args;
+  Array<PrimExpr> freduce_args;
   freduce_args.push_back(make_const(DataType::UInt(32), static_cast<uint32_t>(size)));
   for (size_t i = 0; i < size; ++i) {
     freduce_args.push_back(reduces[0]->source[i]);
@@ -79,7 +79,7 @@ Stmt MakeCrossThreadReduction(
     }
   }
   // Checks for the thread.
-  std::vector<Expr> thread_head_check;
+  std::vector<PrimExpr> thread_head_check;
   if (stage->store_predicate.defined()) {
     thread_head_check.emplace_back(stage->store_predicate);
   }
index ee958da..6fc54a8 100644 (file)
@@ -50,7 +50,7 @@ DataType ExternOpNode::output_dtype(size_t i) const {
   return output_placeholders[i]->dtype;
 }
 
-Array<Expr> ExternOpNode::output_shape(size_t i) const {
+Array<PrimExpr> ExternOpNode::output_shape(size_t i) const {
   return output_placeholders[i]->shape;
 }
 
@@ -162,7 +162,7 @@ Stmt ExternOpNode::BuildProvide(
   Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
   auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
     Array<ObjectRef> bind_spec;
-    Array<Expr> tuple;
+    Array<PrimExpr> tuple;
     bind_spec.push_back(buffer);
     bind_spec.push_back(tensor);
     for (size_t k = 0; k < buffer->shape.size(); ++k) {
index 5364c38..c3be234 100644 (file)
@@ -56,7 +56,7 @@ DataType HybridOpNode::output_dtype(size_t i) const {
   return outputs[i]->dtype;
 }
 
-Array<Expr> HybridOpNode::output_shape(size_t i) const {
+Array<PrimExpr> HybridOpNode::output_shape(size_t i) const {
   return outputs[i]->shape;
 }
 
@@ -222,7 +222,7 @@ namespace op {
 Stmt ApplyLoopShapes(const Stage &stage,
                  const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
   class LoopSpliter : public StmtExprMutator {
-    Expr factor;
+    PrimExpr factor;
     const VarNode *parent;
     IterVar inner, outer;
 
@@ -249,14 +249,14 @@ Stmt ApplyLoopShapes(const Stage &stage,
 
     Stmt VisitStmt_(const ForNode *op) final {
       if (op->loop_var.get() == parent) {
-        std::unordered_map<const VarNode *, Expr> rmap;
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
         rmap[op->loop_var.get()] = inner + outer * factor;
         Stmt ret = ir::Substitute(op->body, rmap);
-        Expr cond = likely(outer * factor < (op->extent - inner));
+        PrimExpr cond = likely(outer * factor < (op->extent - inner));
         ret = IfThenElseNode::make(cond, ret);
-        ret = ForNode::make(inner->var, Expr(0), inner->dom->extent,
+        ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
                         IterVarTypeToForType(inner->iter_type), op->device_api, ret);
-        ret = ForNode::make(outer->var, Expr(0), outer->dom->extent,
+        ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
                         IterVarTypeToForType(outer->iter_type), op->device_api, ret);
         splitted = true;
         return ret;
@@ -270,7 +270,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
     const VarNode *inner;
     const VarNode *outer;
     bool under_outer;
-    Expr extent;
+    PrimExpr extent;
 
    public:
     bool fused;
@@ -283,7 +283,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
     Stmt VisitStmt_(const ForNode* op) final {
       if (op->loop_var.get() == inner) {
         CHECK(under_outer);
-        std::unordered_map<const VarNode *, Expr> rmap;
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
         rmap[op->loop_var.get()] = indexmod(parent, op->extent);
         extent = op->extent;
         fused = true;
@@ -291,15 +291,15 @@ Stmt ApplyLoopShapes(const Stage &stage,
       } else if (op->loop_var.get() == outer) {
         under_outer = true;
         Stmt body = this->VisitStmt(op->body);
-        std::unordered_map<const VarNode *, Expr> rmap;
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
         rmap[op->loop_var.get()] = indexdiv(parent, extent);
         body = ir::Substitute(body, rmap);
         under_outer = false;
-        return ForNode::make(parent->var, Expr(0), extent * op->extent,
+        return ForNode::make(parent->var, PrimExpr(0), extent * op->extent,
                          op->for_type, op->device_api, body);
       } else if (under_outer) {
         Stmt body = this->VisitStmt(op->body);
-        std::unordered_map<const VarNode *, Expr> rmap;
+        std::unordered_map<const VarNode *, PrimExpr> rmap;
         rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
         body = ir::Substitute(body, rmap);
         extent = extent * op->extent;
@@ -342,7 +342,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
             CHECK(Equal(iter_var->dom->extent, op->extent))
               << "Thread extent and loop extent mismatch!\n";
           }
-          std::unordered_map<const VarNode *, Expr> rmap;
+          std::unordered_map<const VarNode *, PrimExpr> rmap;
           rmap[op->loop_var.get()] = iter_var;
           Stmt body = ir::Substitute(op->body, rmap);
           return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
index 31d736d..e108ad3 100644 (file)
@@ -42,14 +42,14 @@ MakeLoopNest(const Stage& stage,
              size_t begin_iter_pos,
              bool new_loop_var,
              const std::unordered_set<IterVar>& skip_iter,
-             std::unordered_map<IterVar, Expr>* p_value_map,
+             std::unordered_map<IterVar, PrimExpr>* p_value_map,
              bool debug_keep_trivial_loop) {
   auto leaf_iter_vars = stage->leaf_iter_vars;
   Stmt no_op = EvaluateNode::make(0);
   // create the loop nest
   std::vector<std::vector<Stmt> > nest;
   nest.resize(leaf_iter_vars.size() + 1);
-  std::unordered_map<IterVar, Expr>& value_map = *p_value_map;
+  std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
 
   for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
     auto iv = leaf_iter_vars[i];
@@ -96,7 +96,7 @@ MakeLoopNest(const Stage& stage,
         CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
         for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
           const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
-          Expr pvalue = it_attr->pragma_values[k];
+          PrimExpr pvalue = it_attr->pragma_values[k];
           if (!pvalue.defined()) {
             pvalue = make_const(DataType::Int(32), 1);
           }
@@ -118,7 +118,7 @@ MakeLoopNest(const Stage& stage,
         nest[i + 1].emplace_back(
             ForNode::make(idx, 0, dom->extent,
                       for_type, DeviceAPI::None, no_op));
-        Expr new_value = dom->min + idx;
+        PrimExpr new_value = dom->min + idx;
         value_map[iv] = new_value;
         nest[i + 1].emplace_back(
             LetStmtNode::make(var, new_value, no_op));
@@ -176,10 +176,10 @@ MakeLoopNest(const Stage& stage,
   return nest;
 }
 
-std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
   Stmt no_op = EvaluateNode::make(0);
   std::vector<Stmt> nest;
-  for (const Expr& cond : predicates) {
+  for (const PrimExpr& cond : predicates) {
     nest.emplace_back(IfThenElseNode::make(cond, no_op));
   }
   return nest;
@@ -191,12 +191,12 @@ class TensorReplacer : public ir::StmtExprMutator {
   explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
       : vmap_(vmap) {}
 
-  Expr VisitExpr_(const ir::CallNode* op) final {
+  PrimExpr VisitExpr_(const ir::CallNode* op) final {
     if (op->call_type == ir::CallNode::Halide) {
       Tensor t = Downcast<Operation>(op->func).output(op->value_index);
       auto it = vmap_.find(t);
       if (it != vmap_.end()) {
-        Expr ret = ir::CallNode::make(
+        PrimExpr ret = ir::CallNode::make(
             op->dtype, it->second->op->name, op->args,
             op->call_type, it->second->op, it->second->value_index);
         found = true;
@@ -219,17 +219,17 @@ Stmt ReplaceTensor(Stmt stmt,
   Stmt ret = repl(stmt);
   return repl.found ? ret : stmt;
 }
-Expr ReplaceTensor(Expr expr,
+PrimExpr ReplaceTensor(PrimExpr expr,
                    const std::unordered_map<Tensor, Tensor>& replace) {
   TensorReplacer repl(replace);
-  Expr ret = repl(expr);
+  PrimExpr ret = repl(expr);
   return repl.found ? ret : expr;
 }
 
 
 Stmt Substitute(Stmt s,
-                const std::unordered_map<IterVar, Expr>& value_map) {
-  std::unordered_map<const VarNode*, Expr> init;
+                const std::unordered_map<IterVar, PrimExpr>& value_map) {
+  std::unordered_map<const VarNode*, PrimExpr> init;
   for (const auto& kv : value_map) {
     init[kv.first->var.get()] = kv.second;
   }
index b57000f..cea050b 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -55,7 +55,7 @@ MakeLoopNest(const Stage& stage,
              size_t begin_iter_pos,
              bool new_loop_var,
              const std::unordered_set<IterVar>& skip_iter,
-             std::unordered_map<IterVar, Expr>* p_value_map,
+             std::unordered_map<IterVar, PrimExpr>* p_value_map,
              bool debug_keep_trivial_loop);
 
 /*!
@@ -64,7 +64,7 @@ MakeLoopNest(const Stage& stage,
  * \param predicates The predicates to be checked.
  * \return List of If nest that checks the predicates.
  */
-std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
+std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
 
 /*!
  * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
@@ -78,7 +78,7 @@ Stmt ReplaceTensor(Stmt stmt,
  * \param expr The expression to be processed.
  * \param replace The replacement rule.
  */
-Expr ReplaceTensor(Expr expr,
+PrimExpr ReplaceTensor(PrimExpr expr,
                    const std::unordered_map<Tensor, Tensor>& replace);
 
 /*!
@@ -88,7 +88,7 @@ Expr ReplaceTensor(Expr expr,
  * \return Substituted result.
  */
 Stmt Substitute(Stmt stmt,
-                const std::unordered_map<IterVar, Expr>& value_map);
+                const std::unordered_map<IterVar, PrimExpr>& value_map);
 
 /*!
  * \brief Converts Halide ForType to its corresponding IterVarType
index 2ec10ca..22e0ad4 100644 (file)
@@ -47,13 +47,13 @@ DataType PlaceholderOpNode::output_dtype(size_t i) const {
   return dtype;
 }
 
-Array<Expr> PlaceholderOpNode::output_shape(size_t i) const {
+Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
   CHECK_EQ(i, 0U);
   return shape;
 }
 
 Operation PlaceholderOpNode::make(std::string name,
-                                  Array<Expr> shape,
+                                  Array<PrimExpr> shape,
                                   DataType dtype) {
   auto n = make_object<PlaceholderOpNode>();
   n->name = name;
@@ -62,7 +62,7 @@ Operation PlaceholderOpNode::make(std::string name,
   return Operation(n);
 }
 
-Tensor placeholder(Array<Expr> shape, DataType dtype, std::string name) {
+Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
 
index c4c0960..f7b16f2 100644 (file)
@@ -38,7 +38,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 TVM_REGISTER_NODE_TYPE(ScanOpNode);
 
-inline bool prove_equal(Expr lhs, Expr rhs) {
+inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) {
   return is_zero(ir::Simplify(lhs - rhs));
 }
 
@@ -57,7 +57,7 @@ DataType ScanOpNode::output_dtype(size_t i) const {
   return update[i]->dtype;
 }
 
-Array<Expr> ScanOpNode::output_shape(size_t i) const {
+Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
   CHECK_LT(i, state_placeholder.size());
   return state_placeholder[i]->shape;
 }
@@ -232,7 +232,7 @@ void ScanOpNode::GatherBound(
   Range r = arith::Union(time_dom).cover_range(sdom);
   (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent(
       sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
-  Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(self);
+  Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
   // Update for spatial axis.
   size_t sp_idx = 0;
   for (size_t i = 0; i < output.size(); ++i) {
@@ -295,7 +295,7 @@ Stmt ScanOpNode::BuildProvide(
       begin_scan = i + 1;
     }
   }
-  std::unordered_map<IterVar, Expr> vmap;
+  std::unordered_map<IterVar, PrimExpr> vmap;
   std::unordered_set<IterVar> empty;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop);
index e0656ea..08df8b7 100644 (file)
@@ -57,7 +57,7 @@ Operation TensorComputeOpNode::make(std::string name,
                                     TensorIntrin intrin,
                                     Array<Tensor> tensors,
                                     Array<Region> regions,
-                                    Array<Expr> scalar_inputs) {
+                                    Array<PrimExpr> scalar_inputs) {
   auto n = make_object<TensorComputeOpNode>();
   n->name = std::move(name);
   n->tag = std::move(tag);
@@ -147,7 +147,7 @@ Stmt TensorComputeOpNode::BuildProvide(
     Buffer buffer = this->intrin->buffers[i];
     Array<ObjectRef> bind_spec{buffer, tensor};
 
-    Array<Expr> tuple;
+    Array<PrimExpr> tuple;
     for (size_t i = 0; i < region.size(); ++i) {
       tuple.push_back(region[i]->min);
       tuple.push_back(region[i]->extent);
@@ -165,7 +165,7 @@ Stmt TensorComputeOpNode::BuildProvide(
     Buffer buffer = this->intrin->buffers[num_inputs + i];
     Array<ObjectRef> bind_spec{buffer, tensor};
 
-    Array<Expr> tuple;
+    Array<PrimExpr> tuple;
     for (size_t i = 0; i < this->axis.size(); ++i) {
       auto ivar = this->axis[i];
       if (i < static_cast<size_t>(this->schedulable_ndim)) {
@@ -186,16 +186,16 @@ Stmt TensorComputeOpNode::BuildProvide(
   }
 
   // Check variable remap
-  std::unordered_map<const VarNode*, Expr> vmap;
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
   ir::ArgBinder binder(&vmap);
 
   // Map the expressions passed in the call to the TensorIntrin, to the placeholder
   // variables
-  Array<Expr> user_expr = this->scalar_inputs;
+  Array<PrimExpr> user_expr = this->scalar_inputs;
   Array<Var> scalar_params = this->intrin->scalar_params;
-  Array<Expr> sp_expr;
+  Array<PrimExpr> sp_expr;
   for (auto sp : scalar_params) {
-    Expr esp = sp;
+    PrimExpr esp = sp;
     sp_expr.push_back(esp);
   }
   CHECK_EQ(sp_expr.size(), user_expr.size());
index 601c444..4460d90 100644 (file)
@@ -144,13 +144,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
       }
     }
   }
-  for (const Expr& pred : n.main_predicates) {
+  for (const PrimExpr& pred : n.main_predicates) {
     if (ir::ExprUseVar(pred, banned)) {
       LOG(FATAL) << "Tensorize failed, split condition "
                  << pred << " relies on var defined inside tensorize scope";
     }
   }
-  for (const Expr& pred : n.init_predicates) {
+  for (const PrimExpr& pred : n.init_predicates) {
     if (ir::ExprUseVar(pred, banned)) {
       LOG(FATAL) << "Tensorize failed, split condition "
                  << pred << " relies on var defined inside tensorize scope";
@@ -161,8 +161,8 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
 // Remap the tensor placeholder, index and inline things.
 class TensorIntrinMatcher final : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const CallNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     if (op->call_type == CallNode::Halide) {
       Tensor t = Downcast<Operation>(op->func).output(op->value_index);
@@ -170,7 +170,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
       if (it != in_remap_.end()) {
         const InputEntry& e = it->second;
         CHECK_EQ(op->args.size(), e.region.size());
-        Array<Expr> args;
+        Array<PrimExpr> args;
         for (size_t i = e.start; i < e.region.size(); ++i) {
           args.push_back(op->args[i] - e.region[i]->min);
         }
@@ -182,17 +182,17 @@ class TensorIntrinMatcher final : public StmtExprMutator {
     return expr;
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = var_remap_.find(op);
     if (it != var_remap_.end()) {
       return it->second;
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
 
-  Expr VisitExpr_(const ReduceNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const ReduceNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<ReduceNode>();
     Array<IterVar> axis;
     for (size_t i = 0; i < op->axis.size(); ++i) {
@@ -301,13 +301,13 @@ class TensorIntrinMatcher final : public StmtExprMutator {
   // input data remap
   std::unordered_map<Tensor, InputEntry> in_remap_;
   // variable remap.
-  std::unordered_map<const VarNode*, Expr> var_remap_;
+  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
   // IterVar remap.
   std::unordered_map<IterVar, IterVar> axis_remap_;
 };
 
 // Try to match tensor dataflow of the stage with the intrinsic
-Array<Expr> MatchTensorizeBody(
+Array<PrimExpr> MatchTensorizeBody(
     const ComputeOpNode* self,
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map,
@@ -317,8 +317,8 @@ Array<Expr> MatchTensorizeBody(
     Map<Var, Range>* compute_intrin_iter_space) {
   TensorIntrinMatcher matcher;
   matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
-  Array<Expr> ret;
-  for (Expr expr : self->body) {
+  Array<PrimExpr> ret;
+  for (PrimExpr expr : self->body) {
     ret.push_back(matcher(expr));
   }
   return ret;
@@ -332,16 +332,16 @@ void VerifyTensorizeBody(
     const std::unordered_map<Tensor, Array<Range> >& in_region,
     const TensorIntrin& intrin) {
   Map<Var, Range> compute_intrin_iter_space;
-  Array<Expr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
+  Array<PrimExpr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
                                         &compute_intrin_iter_space);
   const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
   CHECK(intrin_compute) << "Only support compute intrinsic for now";
   CHECK_EQ(body.size(), intrin_compute->body.size())
       << "Tensorize failed: body size mismatch";
   for (size_t i = 0; i < body.size(); ++i) {
-    Expr lhs = Simplify(body[i], compute_intrin_iter_space);
+    PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space);
     lhs = CanonicalSimplify(lhs, compute_intrin_iter_space);
-    Expr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
+    PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space);
     rhs = CanonicalSimplify(rhs, compute_intrin_iter_space);
     if (lhs.dtype() != rhs.dtype()) {
       LOG(FATAL)
@@ -385,7 +385,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
     auto it = in_region.find(tensor);
     CHECK(it != in_region.end());
     const Array<Range>& region = it->second;
-    Array<Expr> tuple;
+    Array<PrimExpr> tuple;
     for (const Range r : region) {
       tuple.push_back(r->min);
       tuple.push_back(r->extent);
@@ -401,7 +401,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
   CHECK(intrin_compute) << "Only support compute intrinsic for now";
   CHECK_EQ(intrin->inputs.size() + intrin_compute->body.size(), intrin->buffers.size());
   CHECK_EQ(intrin_compute->body.size(), self->body.size());
-  Array<Expr> tuple;
+  Array<PrimExpr> tuple;
   for (IterVar iv : self->axis) {
     auto it = out_dom.find(iv);
     CHECK(it != out_dom.end());
@@ -419,7 +419,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
                        tuple, CallNode::Intrinsic), nop));
   }
   // Check variable remap
-  std::unordered_map<const VarNode*, Expr> vmap;
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
   ir::ArgBinder binder(&vmap);
   CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
       << "Tensorization fail: reduction axis size do not match";
index 340f3a8..2c04de3 100644 (file)
 namespace tvm {
 namespace ir {
 
-void BinderAddAssert(Expr cond,
+void BinderAddAssert(PrimExpr cond,
                      const std::string& arg_name,
                      std::vector<Stmt>* asserts) {
-  Expr scond = Simplify(cond);
+  PrimExpr scond = Simplify(cond);
   if (is_zero(scond)) {
     LOG(FATAL) << "Bind have an unmet assertion: "
                << cond << ", " << " on argument " << arg_name;
@@ -46,8 +46,8 @@ void BinderAddAssert(Expr cond,
   }
 }
 
-bool ArgBinder::Bind_(const Expr& arg,
-                      const Expr& value,
+bool ArgBinder::Bind_(const PrimExpr& arg,
+                      const PrimExpr& value,
                       const std::string& arg_name,
                       bool with_lets) {
   CHECK_EQ(arg.dtype(), value.dtype());
@@ -72,15 +72,15 @@ bool ArgBinder::Bind_(const Expr& arg,
   return false;
 }
 
-void ArgBinder::Bind(const Expr& arg,
-                     const Expr& value,
+void ArgBinder::Bind(const PrimExpr& arg,
+                     const PrimExpr& value,
                      const std::string& arg_name,
                      bool with_let) {
   Bind_(arg, value, arg_name, with_let);
 }
 
-void ArgBinder::BindArray(const Array<Expr>& arg,
-                          const Array<Expr>& value,
+void ArgBinder::BindArray(const Array<PrimExpr>& arg,
+                          const Array<PrimExpr>& value,
                           const std::string& arg_name) {
   CHECK_EQ(arg.size(), value.size())
       << "Argument " << arg_name << " array size mismatch";
@@ -117,9 +117,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   this->Bind(arg->data, value->data, arg_name + ".data");
   if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
     if (arg->offset_factor > 1) {
-      Expr offset = value->elem_offset;
-      Expr factor = make_const(offset.dtype(), arg->offset_factor);
-      Expr zero = make_zero(offset.dtype());
+      PrimExpr offset = value->elem_offset;
+      PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
+      PrimExpr zero = make_zero(offset.dtype());
       BinderAddAssert(truncmod(offset, factor) == zero,
                       arg_name + ".elem_offset", &asserts_);
     }
@@ -153,21 +153,21 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   }
 }
 
-inline Expr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
+inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) {
   return TVMStructGet(t, arr, 0, kind);
 }
 
 void ArgBinder::BindDLTensor(const Buffer& buffer,
-                             const Expr& device_type,
-                             const Expr& device_id,
+                             const PrimExpr& device_type,
+                             const PrimExpr& device_id,
                              const Var& handle,
                              const std::string& arg_name) {
   const DataType tvm_shape_type = DataType::ShapeIndex();
   const DataType tvm_ndim_type = DataType::Int(32);
   const Stmt nop = EvaluateNode::make(0);
   // dimension checks
-  Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
-  Expr a_ndim = make_const(tvm_ndim_type,
+  PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
+  PrimExpr a_ndim = make_const(tvm_ndim_type,
                            static_cast<int64_t>(buffer->shape.size()));
   std::ostringstream ndim_err_msg;
   ndim_err_msg << arg_name
@@ -178,7 +178,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
   type_err_msg << arg_name << ".dtype is expected to be " << dtype;
-  Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
+  PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
                UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
                TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
                UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
@@ -215,17 +215,17 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   init_nest_.emplace_back(LetStmtNode::make(
       v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
       nop));
-  Expr is_null = CallNode::make(
+  PrimExpr is_null = CallNode::make(
     DataType::Bool(1), intrinsic::tvm_handle_is_null,
     {v_strides}, CallNode::PureIntrinsic);
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
-    Expr expect_stride = make_const(stype, 1);
-    Array<Expr> conds;
+    PrimExpr expect_stride = make_const(stype, 1);
+    Array<PrimExpr> conds;
     for (size_t i = buffer->shape.size(); i != 0; --i) {
       size_t k = i - 1;
-      Expr svalue = cast(
+      PrimExpr svalue = cast(
           stype,
           LoadNode::make(tvm_shape_type, v_strides,
                      IntImmNode::make(DataType::Int(32), k), const_true(1)));
@@ -237,19 +237,19 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
                    << " expected to be compact array";
     if (conds.size() != 0) {
       Stmt check =
-          AssertStmtNode::make(arith::ComputeReduce<ir::AndNode>(conds, Expr()),
+          AssertStmtNode::make(arith::ComputeReduce<ir::AndNode>(conds, PrimExpr()),
                            stride_err_msg.str(), EvaluateNode::make(0));
       check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
       asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
     }
   } else if (buffer->buffer_type == kAutoBroadcast) {
     DataType stype = buffer->DefaultIndexType();
-    Expr stride = make_const(stype, 1);
+    PrimExpr stride = make_const(stype, 1);
     for (size_t i = buffer->shape.size(); i != 0; --i) {
       size_t k = i - 1;
       std::ostringstream field_name;
       field_name << v_strides->name_hint << '[' << k << ']';
-      Expr value = cast(buffer->shape[k].dtype(),
+      PrimExpr value = cast(buffer->shape[k].dtype(),
                         LoadNode::make(tvm_shape_type, v_strides,
                                    IntImmNode::make(DataType::Int(32), k), const_true(1)));
       value = tvm::if_then_else(is_null, stride, value);
@@ -288,9 +288,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
                     make_const(DataType::UInt(64), data_bytes))),
               arg_name + ".elem_offset", true)) {
       if (buffer->offset_factor > 1) {
-        Expr offset = buffer->elem_offset;
-        Expr factor = make_const(offset.dtype(), buffer->offset_factor);
-        Expr zero = make_zero(offset.dtype());
+        PrimExpr offset = buffer->elem_offset;
+        PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
+        PrimExpr zero = make_zero(offset.dtype());
         BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_);
       }
     }
index 55d8c22..75006d6 100644 (file)
@@ -62,7 +62,7 @@ class ArgBinder {
    *   ArgBinder will update this def_map when adding new definitions.
    */
   explicit ArgBinder(
-      std::unordered_map<const VarNode*, Expr>* def_map)
+      std::unordered_map<const VarNode*, PrimExpr>* def_map)
       : def_map_(def_map) {
   }
   /*!
@@ -72,8 +72,8 @@ class ArgBinder {
    * \param arg_name argument name.
    * \param with_let Whether add lets during bind
    */
-  void Bind(const Expr& arg,
-            const Expr& value,
+  void Bind(const PrimExpr& arg,
+            const PrimExpr& value,
             const std::string& arg_name,
             bool with_let = false);
   /*!
@@ -82,8 +82,8 @@ class ArgBinder {
    * \param value The target expression value
    * \param arg_name argument name.
    */
-  void BindArray(const Array<Expr>& arg,
-                 const Array<Expr>& value,
+  void BindArray(const Array<PrimExpr>& arg,
+                 const Array<PrimExpr>& value,
                  const std::string& arg_name);
   /*!
    * \brief Bind symbolic buffer to another symbolic buffer
@@ -105,8 +105,8 @@ class ArgBinder {
    * \param arg_name argument name.
    */
   void BindDLTensor(const Buffer& buffer,
-                    const Expr& device_type,
-                    const Expr& device_id,
+                    const PrimExpr& device_type,
+                    const PrimExpr& device_id,
                     const Var& handle,
                     const std::string& arg_name);
 
@@ -133,24 +133,24 @@ class ArgBinder {
     return init_nest_;
   }
   /*! \return Handle data type of the data */
-  const Map<Var, Expr>& def_handle_dtype() const {
+  const Map<Var, PrimExpr>& def_handle_dtype() const {
     return def_handle_dtype_;
   }
 
  private:
   // Internal bind function
-  bool Bind_(const Expr& arg,
-             const Expr& value,
+  bool Bind_(const PrimExpr& arg,
+             const PrimExpr& value,
              const std::string& arg_name,
              bool with_lets);
   /*! \brief The definition map, can be uses to substitute */
-  std::unordered_map<const VarNode*, Expr>* def_map_;
+  std::unordered_map<const VarNode*, PrimExpr>* def_map_;
   /*! \brief defs generated in the current binder */
   std::vector<Var> defs_;
   /*! \brief Initialize nest */
   std::vector<Stmt> init_nest_;
   /*! \brief handle data type in the defintiions */
-  Map<Var, Expr> def_handle_dtype_;
+  Map<Var, PrimExpr> def_handle_dtype_;
   /*! \brief asserts generated */
   std::vector<Stmt> asserts_;
 };
index 84939fc..439c886 100644 (file)
@@ -45,13 +45,13 @@ class BoundCollector : public StmtVisitor {
     StmtVisitor::VisitStmt_(op);
   }
   // Hashtable which maps buffer_var to shape.
-  std::unordered_map<const VarNode *, Expr> mem_to_shape;
+  std::unordered_map<const VarNode *, PrimExpr> mem_to_shape;
 };
 
 class BoundChecker : public StmtExprMutator {
  public:
   explicit BoundChecker(
-      const std::unordered_map<const VarNode *, Expr> &mem_to_shape)
+      const std::unordered_map<const VarNode *, PrimExpr> &mem_to_shape)
       : mem_to_shape_(mem_to_shape) {}
 
   Stmt VisitStmt_(const AllocateNode* op) final {
@@ -62,7 +62,7 @@ class BoundChecker : public StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
       unsafe_rewritten_ = true;
     }
@@ -80,7 +80,7 @@ class BoundChecker : public StmtExprMutator {
     }
     // The collector should has at least one item.
     if (store_scope_bound_collector_.size()) {
-      Expr condition = MakeCondition();
+      PrimExpr condition = MakeCondition();
       if (!condition.as<StringImmNode>()) {
         Stmt nop = EvaluateNode::make(1);
         Stmt then_case =
@@ -94,7 +94,7 @@ class BoundChecker : public StmtExprMutator {
     return GetRef<Stmt>(op);
   }
 
-  Expr VisitExpr_(const LoadNode* op) final {
+  PrimExpr VisitExpr_(const LoadNode* op) final {
     if (CanInstrument(op->index, op->buffer_var)) {
       Collect(op->index, op->buffer_var);
     }
@@ -102,12 +102,12 @@ class BoundChecker : public StmtExprMutator {
   }
 
  private:
-  bool UpdateIsNeeded(const VarExpr& buffer_var) const {
+  bool UpdateIsNeeded(const Var& buffer_var) const {
     return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
   }
 
-  void Update(const VarExpr& buffer_var,
-              const Array<Expr>& new_shape,
+  void Update(const Var& buffer_var,
+              const Array<PrimExpr>& new_shape,
               const DataType& type) {
     // Sanity check at first.
     if (!new_shape.size()) {
@@ -122,7 +122,7 @@ class BoundChecker : public StmtExprMutator {
     }
 
     // Scalarize the shape.
-    Expr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
+    PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
                            CastNode::make(DataType::UInt(64), new_shape[0]));
     for (size_t i = 1; i < new_shape.size(); ++i) {
       // Cast to unsigned to avoid integer overlow at frist.
@@ -132,7 +132,7 @@ class BoundChecker : public StmtExprMutator {
     mem_to_shape_[buffer_var.get()] = shape;
   }
 
-  bool IndexIsValid(const Expr& index) const {
+  bool IndexIsValid(const PrimExpr& index) const {
     if (!index.defined()) {
       return false;
     }
@@ -146,22 +146,22 @@ class BoundChecker : public StmtExprMutator {
     return true;
   }
 
-  bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const {
+  bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const {
     return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
            IndexIsValid(index) && !unsafe_rewritten_;
   }
 
-  void Collect(Expr index, VarExpr buffer_var) {
+  void Collect(PrimExpr index, Var buffer_var) {
     store_scope_bound_collector_.push_back(
         std::make_pair(index, mem_to_shape_[buffer_var.get()]));
   }
 
-  Expr MakeCondition() {
-    Expr condition;
+  PrimExpr MakeCondition() {
+    PrimExpr condition;
     for (size_t i = 0; i < store_scope_bound_collector_.size(); ++i) {
-      std::pair<Expr, Expr> buffer_to_mem = store_scope_bound_collector_[i];
-      Expr index = buffer_to_mem.first;
-      Expr upper_bound = buffer_to_mem.second;
+      std::pair<PrimExpr, PrimExpr> buffer_to_mem = store_scope_bound_collector_[i];
+      PrimExpr index = buffer_to_mem.first;
+      PrimExpr upper_bound = buffer_to_mem.second;
 
       if (const RampNode *ramp_index = index.as<RampNode>()) {
         // In case index is base + stride * i.
@@ -181,9 +181,9 @@ class BoundChecker : public StmtExprMutator {
       upper_bound = CastNode::make(DataType::Int(64), upper_bound);
 
       // Looks like a lower bound should always be zero after normalization.
-      Expr lower_bound = make_zero(DataType::Int(64));
+      PrimExpr lower_bound = make_zero(DataType::Int(64));
 
-      Expr current_condition =
+      PrimExpr current_condition =
           AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound));
       condition =
           !i ? current_condition : AndNode::make(condition, current_condition);
@@ -196,11 +196,11 @@ class BoundChecker : public StmtExprMutator {
   // Whether we face tvm_if_then_else intrinsic.
   bool unsafe_rewritten_{false};
   // Pool which collects the pair of index and shape for specific store/load.
-  std::vector<std::pair<Expr, Expr>> store_scope_bound_collector_;
+  std::vector<std::pair<PrimExpr, PrimExpr>> store_scope_bound_collector_;
   // Error message.
   const char *const error_message_ = "OUT OF THE BOUNDS";
   // Hashtable which maps buffer_var to shape.
-  std::unordered_map<const VarNode *, Expr> mem_to_shape_;
+  std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
 };
 
 Stmt InstrumentBoundCheckers(Stmt stmt) {
index 62ceede..4561dba 100644 (file)
@@ -35,15 +35,15 @@ namespace ir {
 class ContextCallCombiner final : public StmtExprMutator {
  public:
   struct CompareExpr {
-    bool operator()(const Expr& lhs, const Expr& rhs) const {
+    bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
       return Compare(lhs, rhs) < 0;
     }
   };
 
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
       CHECK_EQ(op->args.size(), 1U);
-      Expr ctx = op->args[0];
+      PrimExpr ctx = op->args[0];
       auto it  = ctx_map_.find(ctx);
       if (it != ctx_map_.end()) {
         return it->second;
@@ -68,7 +68,7 @@ class ContextCallCombiner final : public StmtExprMutator {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::coproc_uop_scope) {
       // Map of comparison expression to variable
-      std::map<Expr, Var, CompareExpr> temp;
+      std::map<PrimExpr, Var, CompareExpr> temp;
       std::swap(temp, ctx_map_);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       std::swap(temp, ctx_map_);
@@ -81,7 +81,7 @@ class ContextCallCombiner final : public StmtExprMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Parallel) {
       // Map of comparison expression to variable
-      std::map<Expr, Var, CompareExpr> temp;
+      std::map<PrimExpr, Var, CompareExpr> temp;
       std::swap(temp, ctx_map_);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       std::swap(temp, ctx_map_);
@@ -96,7 +96,7 @@ class ContextCallCombiner final : public StmtExprMutator {
   }
 
  private:
-  static Stmt BuildContext(const std::map<Expr, Var, CompareExpr>& cmap,
+  static Stmt BuildContext(const std::map<PrimExpr, Var, CompareExpr>& cmap,
                            Stmt body) {
     for (const auto& kv : cmap) {
       body = LetStmtNode::make(kv.second, kv.first, body);
@@ -104,7 +104,7 @@ class ContextCallCombiner final : public StmtExprMutator {
     return body;
   }
   // Map of comparison expression to variable
-  std::map<Expr, Var, CompareExpr> ctx_map_;
+  std::map<PrimExpr, Var, CompareExpr> ctx_map_;
 };
 
 LoweredFunc CombineContextCall(LoweredFunc f) {
index a7afd46..4e68793 100644 (file)
@@ -341,8 +341,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
     Range r = arith::Union(wset).cover_range(none);
     CHECK(r.defined())
         << "Cannot deduce write range of " << wvec[0].buffer;
-    Expr min = r->min;
-    Expr extent = r->extent;
+    PrimExpr min = r->min;
+    PrimExpr extent = r->extent;
     return EvaluateNode::make(CallNode::make(
         DataType::Int(32), func,
         {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic));
index 789877f..aff7d3d 100644 (file)
@@ -161,7 +161,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
     });
 
   return IRTransform(parent_for_stmt, nullptr, replace_target_for,
-                     {Expr("For")});
+                     {PrimExpr("For")});
 }
 
 // Remove IfThenElse node from a For node.
@@ -188,10 +188,10 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
     });
 
   then_for = IRTransform(for_stmt, nullptr, replace_then_case,
-                         {Expr("IfThenElse")});
+                         {PrimExpr("IfThenElse")});
   if (if_stmt.as<IfThenElseNode>()->else_case) {
     else_for = IRTransform(for_stmt, nullptr, replace_else_case,
-                           {Expr("IfThenElse")});
+                           {PrimExpr("IfThenElse")});
   }
 
   return std::make_pair(then_for, else_for);
@@ -412,7 +412,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
         *ret = new_for;
       }
     });
-  return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")});
+  return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
 }
 
 Stmt HoistIfThenElse(Stmt stmt) {
index 8f6c06d..6dfa509 100644 (file)
@@ -185,7 +185,7 @@ class InferFragmenter : public StmtMutator {
       std::string shape = std::to_string(info.m) + ", " +
                           std::to_string(info.n) + ", " +
                           std::to_string(info.k);
-      Expr shape_expr = StringImmNode::make(shape);
+      PrimExpr shape_expr = StringImmNode::make(shape);
       Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
       if (info.layout != "") {
         // Add shape attribute to matrix_a and matrix_b
index 0a19c69..e41a868 100644 (file)
@@ -69,7 +69,7 @@ class CopyIntrinInjector : public StmtMutator {
     if (store == nullptr) return false;
     // Expr sel_cond, sel_true_value, sel_false_value;
     // match select or if
-    PVar<Expr> sel_cond, sel_true_value, sel_false_value;
+    PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value;
     bool has_cond =
         if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
         select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
@@ -93,12 +93,12 @@ class CopyIntrinInjector : public StmtMutator {
     for (const ForNode* op : loops) {
       loop_vars.push_back(op->loop_var);
     }
-    Array<Expr> store_strides =
+    Array<PrimExpr> store_strides =
         arith::DetectLinearEquation(store->index, loop_vars);
-    Array<Expr> load_strides =
+    Array<PrimExpr> load_strides =
         arith::DetectLinearEquation(load->index, loop_vars);
     if (load_strides.size()  == 0 || store_strides.size() == 0) return false;
-    Array<Expr> dst_shape;
+    Array<PrimExpr> dst_shape;
     const size_t loop_var_size = loop_vars.size();
     if (loop_var_size == 0) {
       dst_shape.push_back(make_const(DataType::Int(32), 1));
@@ -107,24 +107,24 @@ class CopyIntrinInjector : public StmtMutator {
         dst_shape.push_back(op->extent);
       }
     }
-    Array<Expr> src_shape = dst_shape;
-    Array<Expr> pad_before, pad_after;
-    Expr pad_value;
-    Expr src_elem_offset = load_strides[loop_var_size];
+    Array<PrimExpr> src_shape = dst_shape;
+    Array<PrimExpr> pad_before, pad_after;
+    PrimExpr pad_value;
+    PrimExpr src_elem_offset = load_strides[loop_var_size];
     if (has_cond) {
-      Array<Expr> clip_bound =
+      Array<PrimExpr> clip_bound =
           arith::DetectClipBound(sel_cond.Eval(), loop_vars);
       pad_value = sel_false_value.Eval();
       if (clip_bound.size() == 0) return false;
       CHECK_EQ(src_shape.size(), loop_vars.size());
       CHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
       for (size_t i = 0; i < src_shape.size(); ++i) {
-        Expr min_value = clip_bound[2 * i];
-        Expr max_value = clip_bound[2 * i + 1];
+        PrimExpr min_value = clip_bound[2 * i];
+        PrimExpr max_value = clip_bound[2 * i + 1];
         DataType t = loop_vars[i].dtype();
-        Expr svalue = src_shape[i];
+        PrimExpr svalue = src_shape[i];
         if (min_value.defined()) {
-          Expr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
+          PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
           src_elem_offset = src_elem_offset + pbefore * load_strides[i];
           svalue = svalue - pbefore;
           pad_before.push_back(pbefore);
@@ -132,7 +132,7 @@ class CopyIntrinInjector : public StmtMutator {
           pad_before.push_back(make_zero(t));
         }
         if (max_value.defined()) {
-          Expr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
+          PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
                                            make_zero(t)));
           svalue = svalue - pafter;
           pad_after.push_back(pafter);
@@ -145,8 +145,8 @@ class CopyIntrinInjector : public StmtMutator {
     }
     CHECK_EQ(load_strides.size(), store_strides.size());
     CHECK_EQ(load_strides.size(), loop_var_size + 1);
-    Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
-    Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
+    Array<PrimExpr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
+    Array<PrimExpr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
     if (loop_var_size == 0) {
         src_strides.push_back(make_const(DataType::Int(32), 1));
         dst_strides.push_back(make_const(DataType::Int(32), 1));
index 4bd431e..9ed5606 100644 (file)
@@ -99,11 +99,11 @@ class DoubleBufferInjector : public StmtExprMutator {
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
       it->second.stride = arith::ComputeReduce<MulNode>(
-          op->extents, Expr()) * op->dtype.lanes();
+          op->extents, PrimExpr()) * op->dtype.lanes();
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AllocateNode>();
-      Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
-      for (Expr e : op->extents) {
+      Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
+      for (PrimExpr e : op->extents) {
         new_extents.push_back(e);
       }
       CHECK(it->second.loop != nullptr);
@@ -132,14 +132,14 @@ class DoubleBufferInjector : public StmtExprMutator {
         CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
             << "It is better to split with multiple of 2";
         CHECK(is_zero(old_loop->min));
-        Expr zero = old_loop->min;
-        Expr new_ext =
+        PrimExpr zero = old_loop->min;
+        PrimExpr new_ext =
             old_loop->extent - make_const(old_loop->loop_var.dtype(), 1);
-        Expr factor = make_const(new_ext.dtype(), split_loop_);
-        Expr outer_ext = new_ext / factor;
-        Expr tail_base = outer_ext * factor;
+        PrimExpr factor = make_const(new_ext.dtype(), split_loop_);
+        PrimExpr outer_ext = new_ext / factor;
+        PrimExpr tail_base = outer_ext * factor;
         Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype());
-        std::unordered_map<const VarNode*, Expr> vmap;
+        std::unordered_map<const VarNode*, PrimExpr> vmap;
         std::vector<Stmt> loop_seq;
         for (int32_t i = 0; i < split_loop_; ++i) {
           vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
@@ -152,7 +152,7 @@ class DoubleBufferInjector : public StmtExprMutator {
         std::vector<Stmt> tail_seq;
         Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
         for (int32_t i = 0; i < split_loop_; ++i) {
-          Expr idx = tail_base + make_const(tail_base.dtype(), i);
+          PrimExpr idx = tail_base + make_const(tail_base.dtype(), i);
           vmap[old_loop->loop_var.get()] = idx;
           tail_seq.emplace_back(
               IfThenElseNode::make(idx < old_loop->extent,
@@ -187,8 +187,8 @@ class DoubleBufferInjector : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<LoadNode>();
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
@@ -204,14 +204,14 @@ class DoubleBufferInjector : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     CHECK(!dbuffer_info_.count(op));
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   }
 
  private:
   Stmt MakeProducer(const AttrStmtNode* op) {
-    const VarExpr buffer = Downcast<VarExpr>(op->node);
+    const Var buffer = Downcast<Var>(op->node);
     CHECK_NE(loop_nest_.size(), 0U)
         << "Double buffer scope must be inside a loop";
     auto it = dbuffer_info_.find(buffer.get());
@@ -221,17 +221,17 @@ class DoubleBufferInjector : public StmtExprMutator {
     }
     StorageEntry& e = it->second;
     e.loop = loop_nest_.back();
-    Expr zero = make_const(e.loop->loop_var.dtype(), 0);
-    Expr one = make_const(e.loop->loop_var.dtype(), 1);
-    Expr two = make_const(e.loop->loop_var.dtype(), 2);
-    Expr loop_shift = e.loop->loop_var + one;
+    PrimExpr zero = make_const(e.loop->loop_var.dtype(), 0);
+    PrimExpr one = make_const(e.loop->loop_var.dtype(), 1);
+    PrimExpr two = make_const(e.loop->loop_var.dtype(), 2);
+    PrimExpr loop_shift = e.loop->loop_var + one;
     e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db",
                              e.loop->loop_var.dtype());
     e.switch_read_var = indexmod(e.loop->loop_var, two);
     in_double_buffer_scope_ = true;
     Stmt body = this->VisitStmt(op->body);
     in_double_buffer_scope_ = false;
-    std::unordered_map<const VarNode*, Expr> vmap;
+    std::unordered_map<const VarNode*, PrimExpr> vmap;
     vmap[e.switch_write_var.get()] = zero;
     vmap[e.loop->loop_var.get()] = zero;
     loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
@@ -245,13 +245,13 @@ class DoubleBufferInjector : public StmtExprMutator {
   // Storage entry for those who need double buffering.
   struct StorageEntry {
     // The size of the buffer
-    Expr stride;
+    PrimExpr stride;
     // The loop we need
     const ForNode* loop{nullptr};
     // The switch variable.
-    VarExpr switch_write_var;
+    Var switch_write_var;
     // The switch variable for reading.
-    Expr switch_read_var;
+    PrimExpr switch_read_var;
     // The storage scope.
     std::string scope;
   };
index c58a91d..d7abed8 100644 (file)
@@ -79,7 +79,7 @@ class PrefetchInjector : public StmtMutator {
   }
 
  private:
-  std::vector<VarExpr> loop_nest_;
+  std::vector<Var> loop_nest_;
   std::unordered_map<const VarNode *, IntSet> vectorized_;
   static const Range none;
 };
index 8eeee9d..83fc582 100644 (file)
@@ -36,7 +36,7 @@ class ExprTouched final : public StmtExprVisitor {
                        bool check_write)
       : touched_var_(touched), check_write_(check_write) {}
 
-  void VisitExpr(const Expr& n) final {
+  void VisitExpr(const PrimExpr& n) final {
     // early stopping
     if (expr_touched_ && !check_write_) return;
     StmtExprVisitor::VisitExpr(n);
@@ -205,20 +205,20 @@ class VTInjector : public StmtExprMutator {
     return stmt;
   }
   // Variable
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     CHECK(!alloc_remap_.count(op))
         << "Buffer address may get rewritten in virtual thread";
     if (touched_var_.count(op)) {
       visit_touched_var_ = true;
     }
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   }
-  Expr RewriteIndex(Expr index, Expr alloc_extent) const {
+  PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
     return index + var_ * alloc_extent;
   }
   // Load
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<LoadNode>();
     if (touched_var_.count(op->buffer_var.get())) {
       visit_touched_var_ = true;
@@ -233,7 +233,7 @@ class VTInjector : public StmtExprMutator {
     }
   }
   // Expression.
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       CHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
@@ -241,9 +241,9 @@ class VTInjector : public StmtExprMutator {
       auto it = alloc_remap_.find(buffer);
       if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
       visit_touched_var_ = true;
-      Expr offset = this->VisitExpr(op->args[2]);
-      Expr extent = this->VisitExpr(op->args[3]);
-      Expr stride =
+      PrimExpr offset = this->VisitExpr(op->args[2]);
+      PrimExpr extent = this->VisitExpr(op->args[3]);
+      PrimExpr stride =
           it->second / make_const(offset.dtype(), dtype.lanes());
       offset = stride * var_ + offset;
       return CallNode::make(
@@ -251,7 +251,7 @@ class VTInjector : public StmtExprMutator {
           {op->args[0], op->args[1], offset, extent, op->args[4]},
           op->call_type);
     } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
-      return allow_share_ ? GetRef<Expr>(op) : var_;
+      return allow_share_ ? GetRef<PrimExpr>(op) : var_;
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
@@ -280,7 +280,7 @@ class VTInjector : public StmtExprMutator {
   }
   // Attribute
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    Expr value = this->VisitExpr(op->value);
+    PrimExpr value = this->VisitExpr(op->value);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     } else if (!allow_share_ && !vt_loop_injected_ &&
@@ -299,7 +299,7 @@ class VTInjector : public StmtExprMutator {
   }
   // LetStmt
   Stmt VisitStmt_(const LetStmtNode* op) final {
-    Expr value = this->VisitExpr(op->value);
+    PrimExpr value = this->VisitExpr(op->value);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
@@ -315,7 +315,7 @@ class VTInjector : public StmtExprMutator {
   // For
   Stmt VisitStmt_(const ForNode* op) final {
     CHECK(is_zero(op->min));
-    Expr extent = this->VisitExpr(op->extent);
+    PrimExpr extent = this->VisitExpr(op->extent);
     if (visit_touched_var_ && !vt_loop_injected_) {
       Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
       ++max_loop_depth_;
@@ -334,7 +334,7 @@ class VTInjector : public StmtExprMutator {
   }
   // IfThenElse
   Stmt VisitStmt_(const IfThenElseNode* op) final {
-    Expr condition = this->VisitExpr(op->condition);
+    PrimExpr condition = this->VisitExpr(op->condition);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
@@ -374,15 +374,15 @@ class VTInjector : public StmtExprMutator {
     if (op->new_expr.defined() && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
-    Expr condition = this->VisitExpr(op->condition);
+    PrimExpr condition = this->VisitExpr(op->condition);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
 
     bool changed = false;
-    Array<Expr> extents;
+    Array<PrimExpr> extents;
     for (size_t i = 0; i < op->extents.size(); i++) {
-      Expr new_ext = this->VisitExpr(op->extents[i]);
+      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
       if (visit_touched_var_ && !vt_loop_injected_) {
         return InjectVTLoop(GetRef<Stmt>(op), true);
       }
@@ -395,11 +395,11 @@ class VTInjector : public StmtExprMutator {
     // always rewrite if not allow sharing.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      Expr stride = arith::ComputeReduce<MulNode>(
-          op->extents, Expr()) * op->dtype.lanes();
-      Array<Expr> other;
+      PrimExpr stride = arith::ComputeReduce<MulNode>(
+          op->extents, PrimExpr()) * op->dtype.lanes();
+      Array<PrimExpr> other;
       other.push_back(make_const(op->extents[0].dtype(), num_threads_));
-      for (Expr e : extents) {
+      for (PrimExpr e : extents) {
         other.push_back(e);
       }
       extents = other;
@@ -448,7 +448,7 @@ class VTInjector : public StmtExprMutator {
     } else {
       // insert a for loop
       Var idx(var_->name_hint + ".s", var_->dtype);
-      Map<Var, Expr> values{{var_, idx}};
+      Map<Var, PrimExpr> values{{var_, idx}};
       stmt = Substitute(stmt, values);
       return ForNode::make(idx, make_zero(idx.dtype()),
                        make_const(idx.dtype(), num_threads_),
@@ -474,7 +474,7 @@ class VTInjector : public StmtExprMutator {
   // Whether allow shareding.
   bool allow_share_;
   // The allocations that get touched -> extent
-  std::unordered_map<const VarNode*, Expr> alloc_remap_;
+  std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
 };
 
 
index 4a087dd..fad3f17 100644 (file)
@@ -32,11 +32,11 @@ namespace ir {
 // ConvertSSA need to be applied after this pass
 class IRInline final : public StmtExprMutator {
  public:
-  IRInline(FunctionRef f, Array<Var> args, Expr body)
+  IRInline(FunctionRef f, Array<Var> args, PrimExpr body)
       : f_(f), args_(args), body_(body) {}
 
-  Expr VisitExpr_(const CallNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
 
     if (op->func == f_) {
@@ -53,7 +53,7 @@ class IRInline final : public StmtExprMutator {
           expr = LetNode::make(args_[i], op->args[i], expr);
         }
       } else {
-        Map<Var, Expr> vmap;
+        Map<Var, PrimExpr> vmap;
         for (size_t i = 0; i < args_.size(); ++i) {
           vmap.Set(args_[i], op->args[i]);
         }
@@ -69,13 +69,13 @@ class IRInline final : public StmtExprMutator {
  private:
   FunctionRef f_;
   Array<Var> args_;
-  Expr body_;
+  PrimExpr body_;
 };
 
 Stmt Inline(Stmt stmt,
             FunctionRef f,
             Array<Var> args,
-            Expr body) {
+            PrimExpr body) {
   CHECK_EQ(f->num_outputs(), 1)
       << "can only inline output single value operation";
   Stmt ret = IRInline(f, args, body)(std::move(stmt));
index a1218f2..6eacb14 100644 (file)
 namespace tvm {
 namespace ir {
 
-using ExprComparator = ExprFunctor<void(const Expr& n, const Expr &other)>;
+using ExprComparator = ExprFunctor<void(const PrimExpr& n, const PrimExpr &other)>;
 using StmtComparator = StmtFunctor<void(const Stmt& n, const Stmt &other)>;
 
 #define DEFINE_BIOP_EXPR_CMP_(OP)                                 \
-  void VisitExpr_(const OP* op, const Expr& other) final {        \
+  void VisitExpr_(const OP* op, const PrimExpr& other) final {    \
     const OP* rhs = other.as<OP>();                               \
-    if (CompareExpr(op->a, rhs->a) != 0) return;                      \
-    if (CompareExpr(op->b, rhs->b) != 0) return;                      \
+    if (CompareExpr(op->a, rhs->a) != 0) return;                  \
+    if (CompareExpr(op->b, rhs->b) != 0) return;                  \
   }
 
 // Deep comparison to check if two IR graph are equivalent
@@ -47,19 +47,19 @@ class IRDeepCompare :
     return order_ == 0;
   }
 
-  bool Equal(const Expr& lhs, const Expr& rhs) {
+  bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) {
     tie_def_ = true;
     VisitExpr(lhs, rhs);
     return order_ == 0;
   }
 
-  int Compare(const Expr& lhs, const Expr& rhs) {
+  int Compare(const PrimExpr& lhs, const PrimExpr& rhs) {
     tie_def_ = false;
     VisitExpr(lhs, rhs);
     return order_;
   }
 
-  void VisitExpr(const Expr& n, const Expr& other) override {
+  void VisitExpr(const PrimExpr& n, const PrimExpr& other) override {
     if (order_ != 0) return;
     if (n.same_as(other)) return;
     if (CompareValue(n->type_index(), other->type_index()) != 0) return;
@@ -193,7 +193,7 @@ class IRDeepCompare :
   }
 
   // Exprs
-  void VisitExpr_(const VarNode* op, const Expr& other) final {
+  void VisitExpr_(const VarNode* op, const PrimExpr& other) final {
     const VarNode* rhs = other.as<VarNode>();
     auto it = vmap_.find(op);
     if (it != vmap_.end()) op = it->second;
@@ -203,14 +203,14 @@ class IRDeepCompare :
       order_ = +1;
     }
   }
-  void VisitExpr_(const LoadNode* op, const Expr& other) final {
+  void VisitExpr_(const LoadNode* op, const PrimExpr& other) final {
     const LoadNode* rhs = other.as<LoadNode>();
     if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
     if (CompareExpr(op->index, rhs->index) != 0) return;
     if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
   }
 
-  void VisitExpr_(const LetNode* op, const Expr& other) final {
+  void VisitExpr_(const LetNode* op, const PrimExpr& other) final {
     const LetNode* rhs = other.as<LetNode>();
     if (tie_def_) {
       vmap_[op->var.get()] = rhs->var.get();
@@ -221,7 +221,7 @@ class IRDeepCompare :
     if (CompareExpr(op->body, rhs->body) != 0) return;
   }
 
-  void VisitExpr_(const CallNode* op, const Expr& other) final {
+  void VisitExpr_(const CallNode* op, const PrimExpr& other) final {
     const CallNode* rhs = other.as<CallNode>();
     if (CompareString(op->name, rhs->name)) return;
     if (CompareArray(op->args, rhs->args)) return;
@@ -230,7 +230,7 @@ class IRDeepCompare :
     if (CompareValue(op->value_index, rhs->value_index) != 0) return;
   }
 
-  void VisitExpr_(const ReduceNode *op, const Expr& other) final {
+  void VisitExpr_(const ReduceNode *op, const PrimExpr& other) final {
     const ReduceNode* rhs = other.as<ReduceNode>();
     if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
     if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
@@ -248,51 +248,51 @@ class IRDeepCompare :
     if (CompareArray(op->source, rhs->source) != 0) return;
   }
 
-  void VisitExpr_(const IntImmNode *op, const Expr& other) final {
+  void VisitExpr_(const IntImmNode *op, const PrimExpr& other) final {
     CompareValue(op->value, other.as<IntImmNode>()->value);
   }
 
-  void VisitExpr_(const UIntImmNode *op, const Expr& other) final {
+  void VisitExpr_(const UIntImmNode *op, const PrimExpr& other) final {
     CompareValue(op->value, other.as<UIntImmNode>()->value);
   }
 
-  void VisitExpr_(const FloatImmNode *op, const Expr& other) final {
+  void VisitExpr_(const FloatImmNode *op, const PrimExpr& other) final {
     CompareValue(op->value, other.as<FloatImmNode>()->value);
   }
 
-  void VisitExpr_(const StringImmNode *op, const Expr& other) final {
+  void VisitExpr_(const StringImmNode *op, const PrimExpr& other) final {
     CompareString(op->value, other.as<StringImmNode>()->value);
   }
 
-  void VisitExpr_(const CastNode *op, const Expr& other) final {
+  void VisitExpr_(const CastNode *op, const PrimExpr& other) final {
     CompareExpr(op->value, other.as<CastNode>()->value);
   }
 
-  void VisitExpr_(const NotNode *op, const Expr& other) final {
+  void VisitExpr_(const NotNode *op, const PrimExpr& other) final {
     CompareExpr(op->a, other.as<NotNode>()->a);
   }
 
-  void VisitExpr_(const SelectNode *op, const Expr& other) final {
+  void VisitExpr_(const SelectNode *op, const PrimExpr& other) final {
     const SelectNode* rhs = other.as<SelectNode>();
     if (CompareExpr(op->condition, rhs->condition) != 0) return;
     if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
     if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
   }
 
-  void VisitExpr_(const RampNode *op, const Expr& other) final {
+  void VisitExpr_(const RampNode *op, const PrimExpr& other) final {
     const RampNode* rhs = other.as<RampNode>();
     if (CompareExpr(op->base, rhs->base) != 0) return;
     if (CompareExpr(op->stride, rhs->stride) != 0) return;
     if (CompareValue(op->lanes, rhs->lanes) != 0) return;
   }
 
-  void VisitExpr_(const BroadcastNode *op, const Expr& other) final {
+  void VisitExpr_(const BroadcastNode *op, const PrimExpr& other) final {
     const BroadcastNode* rhs = other.as<BroadcastNode>();
     if (CompareExpr(op->value, rhs->value) != 0) return;
     if (CompareValue(op->lanes, rhs->lanes) != 0) return;
   }
 
-  void VisitExpr_(const ShuffleNode *op, const Expr& other) final {
+  void VisitExpr_(const ShuffleNode *op, const PrimExpr& other) final {
     const ShuffleNode* rhs = other.as<ShuffleNode>();
     if (CompareArray(op->vectors, rhs->vectors) != 0) return;
     if (CompareArray(op->indices, rhs->indices) != 0) return;
@@ -317,7 +317,7 @@ class IRDeepCompare :
   DEFINE_BIOP_EXPR_CMP_(OrNode)
 
  private:
-  int CompareExpr(const Expr& lhs, const Expr& rhs) {
+  int CompareExpr(const PrimExpr& lhs, const PrimExpr& rhs) {
     if (order_ != 0) return order_;
     if (!lhs.defined() && rhs.defined()) {
       order_ = -1; return order_;
@@ -341,7 +341,7 @@ class IRDeepCompare :
     return order_;
   }
 
-  int CompareArray(const Array<Expr>& lhs, const Array<Expr>& rhs) {
+  int CompareArray(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
     if (order_ != 0) return order_;
     if (CompareValue(lhs.size(), rhs.size()) != 0) return order_;
     for (size_t i = 0; i < lhs.size(); ++i) {
@@ -438,7 +438,7 @@ bool Equal(const Stmt& lhs, const Stmt& rhs) {
   return IRDeepCompare().Equal(lhs, rhs);
 }
 
-bool Equal(const Expr& lhs, const Expr& rhs) {
+bool Equal(const PrimExpr& lhs, const PrimExpr& rhs) {
   // quick pass for constant expressions.
   if (const int64_t *a = as_const_int(lhs)) {
     if (const int64_t *b = as_const_int(rhs)) {
@@ -455,7 +455,7 @@ bool Equal(const Expr& lhs, const Expr& rhs) {
   return IRDeepCompare().Equal(lhs, rhs);
 }
 
-int Compare(const Expr& lhs, const Expr& rhs) {
+int Compare(const PrimExpr& lhs, const PrimExpr& rhs) {
   return IRDeepCompare().Compare(lhs, rhs);
 }
 
index b7a7362..67acec6 100644 (file)
@@ -31,7 +31,7 @@ class IRApplyVisit :
  public:
   explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
 
-  void VisitExpr(const Expr& node) final {
+  void VisitExpr(const PrimExpr& node) final {
     if (visited_.count(node.get()) != 0) return;
     visited_.insert(node.get());
     ExprVisitor::VisitExpr(node);
@@ -57,7 +57,7 @@ void PostOrderVisit(const ObjectRef& node,
     visitor(Downcast<Stmt>(node));
   } else {
     IRApplyVisit visitor(fvisit);
-    visitor(Downcast<Expr>(node));
+    visitor(Downcast<PrimExpr>(node));
   }
 }
 
@@ -77,8 +77,8 @@ class IRTransformer final :
       return this->BaseVisitStmt(s);
     });
   }
-  Expr VisitExpr(const Expr& expr) final {
-    return MutateInternal<Expr>(expr, [this](const Expr& e) {
+  PrimExpr VisitExpr(const PrimExpr& expr) final {
+    return MutateInternal<PrimExpr>(expr, [this](const PrimExpr& e) {
       return this->BaseVisitExpr(e);
     });
   }
@@ -89,7 +89,7 @@ class IRTransformer final :
   Stmt BaseVisitStmt(const Stmt& s) {
     return StmtMutator::VisitStmt(s);
   }
-  Expr BaseVisitExpr(const Expr& e) {
+  PrimExpr BaseVisitExpr(const PrimExpr& e) {
     return ExprMutator::VisitExpr(e);
   }
 
@@ -120,9 +120,9 @@ class IRTransformer final :
 Stmt IRTransform(Stmt ir_node,
                  const runtime::PackedFunc& f_preorder,
                  const runtime::PackedFunc& f_postorder,
-                 const Array<Expr>& only_enable) {
+                 const Array<PrimExpr>& only_enable) {
   std::unordered_set<uint32_t> only_type_index;
-  for (Expr s : only_enable) {
+  for (PrimExpr s : only_enable) {
     only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
   }
   IRTransformer transform(f_preorder, f_postorder, only_type_index);
@@ -154,7 +154,7 @@ void StmtVisitor::VisitStmt_(const ForNode* op) {
 }
 
 void StmtVisitor::VisitStmt_(const AllocateNode* op) {
-  VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); });
+  VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
   this->VisitStmt(op->body);
   this->VisitExpr(op->condition);
   if (op->new_expr.defined()) {
@@ -189,7 +189,7 @@ void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) {
 }
 
 void StmtVisitor::VisitStmt_(const ProvideNode* op) {
-  VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
+  VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
   this->VisitExpr(op->value);
 }
 
@@ -232,7 +232,7 @@ void ExprVisitor::VisitExpr_(const LetNode* op) {
 }
 
 void ExprVisitor::VisitExpr_(const CallNode* op) {
-  VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
+  VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
 
 #define DEFINE_BINOP_VISIT_(OP)                           \
@@ -269,7 +269,7 @@ void ExprVisitor::VisitExpr_(const ReduceNode* op) {
       this->VisitExpr(r->dom->min);
       this->VisitExpr(r->dom->extent);
     });
-  VisitArray(op->source, [this](const Expr& e) { this->VisitExpr(e); });
+  VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
   this->VisitExpr(op->condition);
 }
 
@@ -293,8 +293,8 @@ void ExprVisitor::VisitExpr_(const RampNode* op) {
 }
 
 void ExprVisitor::VisitExpr_(const ShuffleNode* op) {
-  VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); });
-  VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); });
+  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
+  VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
 
 void ExprVisitor::VisitExpr_(const BroadcastNode* op) {
@@ -320,8 +320,8 @@ inline Array<T> MutateArray(const Array<T>& arr,
 
 class StmtMutator::Internal {
  public:
-  static Array<Expr> Mutate(StmtMutator* self, const Array<Expr>& arr) {
-    auto fmutate = [self](const Expr& e) { return self->VisitExpr(e); };
+  static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
+    auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
     return MutateArray(arr, fmutate, self->allow_copy_on_write_);
   }
 
@@ -332,8 +332,8 @@ class StmtMutator::Internal {
 
   static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
     auto fmutate = [self](const Range& r) {
-      Expr min = self->VisitExpr(r->min);
-      Expr extent = self->VisitExpr(r->extent);
+      PrimExpr min = self->VisitExpr(r->min);
+      PrimExpr extent = self->VisitExpr(r->extent);
       if (min.same_as(r->min) && extent.same_as(r->extent)) {
         return r;
       } else {
@@ -345,7 +345,7 @@ class StmtMutator::Internal {
 };
 
 Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
-  Expr value = this->VisitExpr(op->value);
+  PrimExpr value = this->VisitExpr(op->value);
   Stmt body = this->VisitStmt(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
@@ -359,7 +359,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
-  Expr value = this->VisitExpr(op->value);
+  PrimExpr value = this->VisitExpr(op->value);
   Stmt body = this->VisitStmt(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
@@ -373,8 +373,8 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const ForNode* op) {
-  Expr min = this->VisitExpr(op->min);
-  Expr extent = this->VisitExpr(op->extent);
+  PrimExpr min = this->VisitExpr(op->min);
+  PrimExpr extent = this->VisitExpr(op->extent);
   Stmt body = this->VisitStmt(op->body);
   if (min.same_as(op->min) &&
       extent.same_as(op->extent) &&
@@ -390,10 +390,10 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
-  Array<Expr> extents = Internal::Mutate(this, op->extents);
+  Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
   Stmt body = this->VisitStmt(op->body);
-  Expr condition = this->VisitExpr(op->condition);
-  Expr new_expr;
+  PrimExpr condition = this->VisitExpr(op->condition);
+  PrimExpr new_expr;
   if (op->new_expr.defined()) {
     new_expr = this->VisitExpr(op->new_expr);
   }
@@ -413,7 +413,7 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
-  Expr condition = this->VisitExpr(op->condition);
+  PrimExpr condition = this->VisitExpr(op->condition);
   Stmt then_case = this->VisitStmt(op->then_case);
   Stmt else_case;
   if (op->else_case.defined()) {
@@ -433,9 +433,9 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
-  Expr value = this->VisitExpr(op->value);
-  Expr index = this->VisitExpr(op->index);
-  Expr predicate = this->VisitExpr(op->predicate);
+  PrimExpr value = this->VisitExpr(op->value);
+  PrimExpr index = this->VisitExpr(op->index);
+  PrimExpr predicate = this->VisitExpr(op->predicate);
   if (value.same_as(op->value) &&
       index.same_as(op->index) &&
       predicate.same_as(op->predicate)) {
@@ -450,8 +450,8 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
-  Array<Expr> args = Internal::Mutate(this, op->args);
-  Expr value = this->VisitExpr(op->value);
+  Array<PrimExpr> args = Internal::Mutate(this, op->args);
+  PrimExpr value = this->VisitExpr(op->value);
   if (args.same_as(op->args) &&
       value.same_as(op->value)) {
     return GetRef<Stmt>(op);
@@ -466,7 +466,7 @@ Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
 Stmt StmtMutator::VisitStmt_(const RealizeNode* op) {
   Region bounds = Internal::Mutate(this, op->bounds);
   Stmt body = this->VisitStmt(op->body);
-  Expr condition = this->VisitExpr(op->condition);
+  PrimExpr condition = this->VisitExpr(op->condition);
   if (bounds.same_as(op->bounds) &&
       body.same_as(op->body) &&
       condition.same_as(op->condition)) {
@@ -549,8 +549,8 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op,
 }
 
 Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
-  Expr condition = this->VisitExpr(op->condition);
-  Expr message = this->VisitExpr(op->message);
+  PrimExpr condition = this->VisitExpr(op->condition);
+  PrimExpr message = this->VisitExpr(op->message);
   Stmt body = this->VisitStmt(op->body);
 
   if (condition.same_as(op->condition) &&
@@ -578,7 +578,7 @@ Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
-  Expr value = this->VisitExpr(op->value);
+  PrimExpr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
     return GetRef<Stmt>(op);
   } else {
@@ -593,37 +593,37 @@ Stmt StmtMutator::VisitStmt_(const FreeNode* op) {
 }
 
 
-Expr ExprMutator::VisitExpr_(const VarNode* op) {
-  return GetRef<Expr>(op);
+PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
+  return GetRef<PrimExpr>(op);
 }
 
-Expr ExprMutator::VisitExpr_(const LoadNode* op) {
-  Expr index = this->VisitExpr(op->index);
-  Expr predicate = this->VisitExpr(op->predicate);
+PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
+  PrimExpr index = this->VisitExpr(op->index);
+  PrimExpr predicate = this->VisitExpr(op->predicate);
   if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return LoadNode::make(op->dtype, op->buffer_var, index, predicate);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const LetNode* op) {
-  Expr value = this->VisitExpr(op->value);
-  Expr body = this->VisitExpr(op->body);
+PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
+  PrimExpr value = this->VisitExpr(op->value);
+  PrimExpr body = this->VisitExpr(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return LetNode::make(op->var, value, body);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const CallNode* op) {
-  auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); };
-  Array<Expr> args = MutateArray(op->args, fmutate);
+PrimExpr ExprMutator::VisitExpr_(const CallNode* op) {
+  auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+  Array<PrimExpr> args = MutateArray(op->args, fmutate);
 
   if (args.same_as(op->args)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return CallNode::make(op->dtype,
                       op->name,
@@ -635,8 +635,8 @@ Expr ExprMutator::VisitExpr_(const CallNode* op) {
 }
 
 #define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP)                    \
-  Expr ExprMutator::VisitExpr_(const OP *op) {                    \
-    return GetRef<Expr>(op);                                      \
+  PrimExpr ExprMutator::VisitExpr_(const OP *op) {                    \
+    return GetRef<PrimExpr>(op);                                      \
   }
 
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
@@ -645,12 +645,12 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
 DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
 
 #define DEFINE_BIOP_EXPR_MUTATE_(OP)                                    \
-  Expr ExprMutator::VisitExpr_(const OP* op) {                          \
-    Expr a = this->VisitExpr(op->a);                                    \
-    Expr b = this->VisitExpr(op->b);                                    \
+  PrimExpr ExprMutator::VisitExpr_(const OP* op) {                          \
+    PrimExpr a = this->VisitExpr(op->a);                                    \
+    PrimExpr b = this->VisitExpr(op->b);                                    \
     if (a.same_as(op->a) &&                                             \
         b.same_as(op->b)) {                                             \
-      return GetRef<Expr>(op);                                          \
+      return GetRef<PrimExpr>(op);                                          \
     } else {                                                            \
       return OP::make(a, b);                                            \
     }                                                                   \
@@ -674,11 +674,11 @@ DEFINE_BIOP_EXPR_MUTATE_(GENode);
 DEFINE_BIOP_EXPR_MUTATE_(AndNode);
 DEFINE_BIOP_EXPR_MUTATE_(OrNode);
 
-Expr ExprMutator::VisitExpr_(const ReduceNode* op) {
+PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
   auto fitervar =  [this](const IterVar& v) {
     Range r = v->dom;
-    Expr min = this->VisitExpr(r->min);
-    Expr extent = this->VisitExpr(r->extent);
+    PrimExpr min = this->VisitExpr(r->min);
+    PrimExpr extent = this->VisitExpr(r->extent);
     if (min.same_as(r->min) &&
         extent.same_as(r->extent)) {
       return v;
@@ -690,77 +690,77 @@ Expr ExprMutator::VisitExpr_(const ReduceNode* op) {
   };
   Array<IterVar> axis = MutateArray(op->axis, fitervar);
 
-  auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
-  Array<Expr> source = MutateArray(op->source, fexpr);
+  auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+  Array<PrimExpr> source = MutateArray(op->source, fexpr);
 
-  Expr condition = this->VisitExpr(op->condition);
+  PrimExpr condition = this->VisitExpr(op->condition);
 
   if (axis.same_as(op->axis) &&
       source.same_as(op->source) &&
       condition.same_as(op->condition)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return ReduceNode::make(
       op->combiner, source, axis, condition, op->value_index);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const CastNode* op) {
-  Expr value = this->VisitExpr(op->value);
+PrimExpr ExprMutator::VisitExpr_(const CastNode* op) {
+  PrimExpr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return CastNode::make(op->dtype, value);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const NotNode* op) {
-  Expr a = this->VisitExpr(op->a);
+PrimExpr ExprMutator::VisitExpr_(const NotNode* op) {
+  PrimExpr a = this->VisitExpr(op->a);
   if (a.same_as(op->a)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return NotNode::make(a);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const SelectNode* op) {
-  Expr condition = this->VisitExpr(op->condition);
-  Expr true_value = this->VisitExpr(op->true_value);
-  Expr false_value = this->VisitExpr(op->false_value);
+PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) {
+  PrimExpr condition = this->VisitExpr(op->condition);
+  PrimExpr true_value = this->VisitExpr(op->true_value);
+  PrimExpr false_value = this->VisitExpr(op->false_value);
   if (condition.same_as(op->condition) &&
       true_value.same_as(op->true_value) &&
       false_value.same_as(op->false_value)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return SelectNode::make(condition, true_value, false_value);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const RampNode* op) {
-  Expr base = this->VisitExpr(op->base);
-  Expr stride = this->VisitExpr(op->stride);
+PrimExpr ExprMutator::VisitExpr_(const RampNode* op) {
+  PrimExpr base = this->VisitExpr(op->base);
+  PrimExpr stride = this->VisitExpr(op->stride);
   if (base.same_as(op->base) &&
       stride.same_as(op->stride)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return RampNode::make(base, stride, op->lanes);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const BroadcastNode* op) {
-  Expr value = this->VisitExpr(op->value);
+PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
+  PrimExpr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return BroadcastNode::make(value, op->lanes);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const ShuffleNode* op) {
-  auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
+PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
+  auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
   auto vectors = MutateArray(op->vectors, fexpr);
   if (vectors.same_as(op->vectors)) {
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   } else {
     return ShuffleNode::make(vectors, op->indices);
   }
index 74d5781..f1a0195 100644 (file)
@@ -81,10 +81,10 @@ inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
  * \param kind The data kind.
  * \return the get expression.
  */
-inline Expr TVMStructGet(
+inline PrimExpr TVMStructGet(
     DataType dtype, Var handle, int index,
     intrinsic::TVMStructFieldKind kind) {
-  Array<Expr> args ={
+  Array<PrimExpr> args ={
     handle,
     make_const(DataType::Int(32), index),
     make_const(DataType::Int(32), static_cast<int>(kind))};
@@ -97,7 +97,7 @@ inline Expr TVMStructGet(
  * \param dtype The data type.
  * \param offset the offset index.
  */
-inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
+inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) {
   return CallNode::make(
       DataType::Handle(), intrinsic::tvm_address_of,
       {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
@@ -111,7 +111,7 @@ inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
  * \param dtype The data type.
  * \param offset the offset index.
  */
-inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) {
+inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) {
   if (dtype.lanes() != 1) {
     offset = offset * make_const(offset.dtype(), dtype.lanes());
     offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
@@ -133,8 +133,8 @@ inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) {
  */
 inline Stmt TVMStructSet(
     Var handle, int index,
-    intrinsic::TVMStructFieldKind kind, Expr value) {
-  Array<Expr> args ={
+    intrinsic::TVMStructFieldKind kind, PrimExpr value) {
+  Array<PrimExpr> args ={
     handle,
     make_const(DataType::Int(32), index),
     make_const(DataType::Int(32), static_cast<int>(kind)),
@@ -182,7 +182,7 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
  * \param base The result base.
  * \return true if pattern match success and store the base to base.
  */
-inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
+inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
   const RampNode* r = index.as<RampNode>();
   if (!r) return false;
   if (!is_one(r->stride)) return false;
index 9a97031..7b760fa 100644 (file)
@@ -55,7 +55,7 @@ class AttrScopeLifter : public StmtMutator {
           attr_node_, attr_key_, attr_value_, op->body);
       // undefine them
       attr_node_ = ObjectRef();
-      attr_value_ = Expr();
+      attr_value_ = PrimExpr();
       return AllocateNode::make(
         op->buffer_var, op->dtype,
         op->extents, op->condition, body,
@@ -78,11 +78,11 @@ class AttrScopeLifter : public StmtMutator {
   Stmt VisitStmt_(const SeqStmtNode* op) final {
     // remember the decorations.
     std::vector<ObjectRef> attr_node;
-    std::vector<Expr> attr_value;
+    std::vector<PrimExpr> attr_value;
 
     auto fmutate = [&](const Stmt& s) {
       attr_node_ = ObjectRef();
-      attr_value_ = Expr();
+      attr_value_ = PrimExpr();
       Stmt ret = this->VisitStmt(s);
       attr_node.push_back(attr_node_);
       attr_value.push_back(attr_value_);
@@ -123,7 +123,7 @@ class AttrScopeLifter : public StmtMutator {
       begin = end;
     }
     attr_node_ = ObjectRef();
-    attr_value_ = Expr();
+    attr_value_ = PrimExpr();
     return SeqStmt::Flatten(reorg);
   }
 
@@ -133,7 +133,7 @@ class AttrScopeLifter : public StmtMutator {
     }
     Stmt then_case = this->VisitStmt(op->then_case);
     ObjectRef first_node;
-    Expr first_value;
+    PrimExpr first_value;
     std::swap(first_node, attr_node_);
     std::swap(first_value, attr_value_);
     Stmt else_case = this->VisitStmt(op->else_case);
@@ -159,7 +159,7 @@ class AttrScopeLifter : public StmtMutator {
             attr_node_, attr_key_, attr_value_, else_case);
         // undefine them
         attr_node_ = ObjectRef();
-        attr_value_ = Expr();
+        attr_value_ = PrimExpr();
       }
       if (then_case.same_as(op->then_case) &&
           else_case.same_as(op->else_case)) {
@@ -172,7 +172,7 @@ class AttrScopeLifter : public StmtMutator {
 
  private:
   // value comparison that also compares content of int constant
-  static bool ValueSame(const Expr& a, const Expr& b) {
+  static bool ValueSame(const PrimExpr& a, const PrimExpr& b) {
     if (a.same_as(b)) return true;
     if (!a.defined() || !b.defined()) return false;
     if (a->type_index() != b->type_index()) return false;
@@ -188,7 +188,7 @@ class AttrScopeLifter : public StmtMutator {
 
   std::string attr_key_;
   ObjectRef attr_node_;
-  Expr attr_value_;
+  PrimExpr attr_value_;
 };
 
 Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
index 7d9ce62..adcd5ec 100644 (file)
@@ -49,7 +49,7 @@ struct PartitionKeyHash {
 // condition cond is proven to have value cond_value (true or false) in interval.
 using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
 
-bool ExprUseVars(Expr expr, const std::unordered_set<const VarNode*>& vars) {
+bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
   bool success = false;
   PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
     if (const VarNode* v = node.as<VarNode>()) {
@@ -152,7 +152,7 @@ class CandidateSelector final : public StmtExprVisitor {
 // (currently, "likely" conditions) has fixed true or false value
 class PartitionFinder : public StmtExprVisitor {
  public:
-  explicit PartitionFinder(VarExpr current_var,
+  explicit PartitionFinder(Var current_var,
     const std::unordered_map<const VarNode*, IntSet>& hint_map,
     const std::unordered_map<const VarNode*, IntSet>& relax_map)
       : current_var_(current_var), hint_map_(hint_map),  relax_map_(relax_map) {
@@ -194,7 +194,7 @@ class PartitionFinder : public StmtExprVisitor {
 
   void VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(CallNode::likely)) {
-      Expr cond = op->args[0];
+      PrimExpr cond = op->args[0];
       if (ExprUseVars(cond,
           std::unordered_set<const VarNode*>({current_var_.get()}))) {
         // For cond, find out the interval, if exists, in which we can prove that cond is
@@ -206,7 +206,7 @@ class PartitionFinder : public StmtExprVisitor {
           // cond is true within interval
           partitions[{cond.get(), true}] = interval;
         }
-        Expr inverse_cond = InverseCond(cond);
+        PrimExpr inverse_cond = InverseCond(cond);
         if (inverse_cond.defined()) {
           IntSet interval =
                   DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
@@ -224,8 +224,8 @@ class PartitionFinder : public StmtExprVisitor {
   Partition partitions;
 
  private:
-  Expr InverseCond(const Expr& cond) {
-    Expr inverse_cond;
+  PrimExpr InverseCond(const PrimExpr& cond) {
+    PrimExpr inverse_cond;
     if (const LTNode* op = cond.as<LTNode>()) {
       // a < b -> a >= b
       inverse_cond = GENode::make(op->a, op->b);
@@ -248,7 +248,7 @@ class PartitionFinder : public StmtExprVisitor {
     return inverse_cond;
   }
 
-  VarExpr current_var_;
+  Var current_var_;
   std::unordered_set<const VarNode*> out_vars_;
   std::unordered_map<const VarNode*, IntSet> hint_map_;
   std::unordered_map<const VarNode*, IntSet> relax_map_;
@@ -260,7 +260,7 @@ class ConditionEliminator : public StmtExprMutator {
   explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
     : ps_(ps), cond_value_(cond_value) {}
 
-  Expr VisitExpr(const Expr& e) final {
+  PrimExpr VisitExpr(const PrimExpr& e) final {
     if (ps_.find(e.get()) != ps_.end()) {
       return VisitExpr(cond_value_ ? const_true() : const_false());
     }
@@ -277,7 +277,7 @@ class ConditionEliminator : public StmtExprMutator {
 class ThreadPartitionInserter : public StmtMutator {
  public:
   explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps,
-    Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
+    PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
@@ -287,7 +287,7 @@ class ThreadPartitionInserter : public StmtMutator {
       if (innermost_thread_scope_) {
         Stmt simplified_body = ConditionEliminator(ps_)(op->body);
         Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body);
-        Expr value = this->VisitExpr(op->value);
+        PrimExpr value = this->VisitExpr(op->value);
         stmt = AttrStmtNode::make(op->node, op->attr_key, value, body);
       }
       innermost_thread_scope_ = false;
@@ -299,7 +299,7 @@ class ThreadPartitionInserter : public StmtMutator {
 
  private:
   const std::unordered_set<const Object*>& ps_;
-  Expr cond_;
+  PrimExpr cond_;
   bool innermost_thread_scope_;
 };
 
@@ -363,15 +363,15 @@ class LoopPartitioner : public StmtMutator {
   }
 
  private:
-  Stmt TryPartition(const Object* op, const Stmt& stmt, VarExpr var,
-      Expr min, Expr max, Stmt body, bool partition_thread_scope);
+  Stmt TryPartition(const Object* op, const Stmt& stmt, Var var,
+      PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope);
 
   std::pair<IntSet, std::unordered_set<const Object*>>
   GetIntervalAndCondset(const Partition &partitions,
                         const arith::IntervalSet &for_interval,
                         bool cond_value);
 
-  inline Stmt MakeFor(const Object* op, Expr extent, Stmt body);
+  inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
 
   /* Candidate IRs that may be partitioned potentially */
   std::unordered_map<const VarNode*, IntSet> hint_map_;
@@ -452,9 +452,9 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
  */
 Stmt LoopPartitioner::TryPartition(const Object* node,
                                    const Stmt& stmt,
-                                   VarExpr var,
-                                   Expr min,
-                                   Expr max,
+                                   Var var,
+                                   PrimExpr min,
+                                   PrimExpr max,
                                    Stmt body,
                                    bool partition_thread_scope) {
   using namespace arith;
@@ -496,13 +496,13 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
 
   // Calculating pre-subrange and generating code for it.
   // pre-subrange = [min, body_begin)
-  Expr body_begin;
+  PrimExpr body_begin;
   Stmt pre_stmt;
   bool pre_stmt_recurse = true;
   if (middle_interval_i->HasLowerBound()) {
     body_begin = ir::Simplify(middle_interval.min());
     if (!analyzer_.CanProve(body_begin == min)) {
-      Expr cond = (body_begin - min >= 0);
+      PrimExpr cond = (body_begin - min >= 0);
       if (!analyzer_.CanProve(cond)) {
         LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the pre doubt loop";
@@ -521,14 +521,14 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
 
   // Calculating post-subrange and generating code for it.
   // post-subrange = [post_doubt_begin, max+1)
-  Expr post_doubt_begin;
+  PrimExpr post_doubt_begin;
   Stmt post_stmt;
   bool post_stmt_recurse = true;
   if (middle_interval_i->HasUpperBound()) {
     post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
     if (!analyzer_.CanProve(middle_interval.max() == max)) {
       // require the extent to be non-negative
-      Expr cond = (max - post_doubt_begin + 1 >= 0);
+      PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
       if (!analyzer_.CanProve(cond)) {
         LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the post doubt loop";
@@ -571,7 +571,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
     }
     s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
   } else {
-    Expr cond = const_true();
+    PrimExpr cond = const_true();
     if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
     if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
     s = ThreadPartitionInserter(cond_set, cond)(stmt);
@@ -580,7 +580,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
   return s;
 }
 
-inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) {
+inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) {
   const ForNode *for_node = static_cast<const ForNode*>(node);
   CHECK(for_node);
   if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
@@ -594,7 +594,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body)
 
 class RemoveLikelyTags : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const CallNode *op) final {
+  PrimExpr VisitExpr_(const CallNode *op) final {
     if (op->is_intrinsic(CallNode::likely)) {
       CHECK_EQ(op->args.size(), 1);
       return StmtExprMutator::VisitExpr(op->args[0]);
index ded17d4..98eaf8c 100644 (file)
@@ -41,13 +41,13 @@ class CustomDatatypesLowerer : public StmtExprMutator {
  public:
   explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
 
-  inline Expr VisitExpr_(const CastNode* op) final {
+  inline PrimExpr VisitExpr_(const CastNode* op) final {
     auto type_code = op->dtype.code();
     auto src_type_code = op->value.dtype().code();
     // If either datatype is a registered custom datatype, we must lower.
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
                        datatype::Registry::Global()->GetTypeRegistered(src_type_code);
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CastNode>();
     if (toBeLowered) {
       auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
@@ -59,9 +59,9 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-  inline Expr VisitExpr_(const FloatImmNode* imm) final {
+  inline PrimExpr VisitExpr_(const FloatImmNode* imm) final {
     auto type_code = imm->dtype.code();
-    auto e = GetRef<Expr>(imm);
+    auto e = GetRef<PrimExpr>(imm);
     if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
       auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
       CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
@@ -85,9 +85,9 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return stmt;
   }
 
-  inline Expr VisitExpr_(const LoadNode* load) final {
+  inline PrimExpr VisitExpr_(const LoadNode* load) final {
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
-    Expr expr = StmtExprMutator::VisitExpr_(load);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(load);
     load = expr.as<LoadNode>();
     if (toBeLowered) {
       auto new_load_type = DataType::UInt(load->dtype.bits());
@@ -97,10 +97,10 @@ class CustomDatatypesLowerer : public StmtExprMutator {
   }
 
 #define DEFINE_MUTATE__(OP, NodeName)                                              \
-  inline Expr VisitExpr_(const NodeName* op) final {                                     \
+  inline PrimExpr VisitExpr_(const NodeName* op) final {                                     \
     auto type_code = op->dtype.code();                                             \
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
-    Expr expr = StmtExprMutator::VisitExpr_(op);                                   \
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);                                   \
     op = expr.as<NodeName>();                                                            \
     if (toBeLowered) {                                                             \
       auto lower = datatype::Get##OP##LowerFunc(target_, type_code);               \
index b46bf18..ed8be8b 100644 (file)
@@ -53,16 +53,16 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     }
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->call_type == CallNode::Intrinsic ||
         op->call_type == CallNode::PureIntrinsic) {
-      Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
+      PrimExpr r = ApplyPattern(op->name, GetRef<PrimExpr>(op));
       if (r.defined()) return r;
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const AddNode* op) final {
+  PrimExpr VisitExpr_(const AddNode* op) final {
     if (const MulNode* mb = op->b.as<MulNode>()) {
       return MakeFMA(mb->a, mb->b, op->a, op);
     } else if (const MulNode* ma = op->a.as<MulNode>()) {
@@ -73,9 +73,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
 
   // We use floordiv for integer analysis,
   // but will need to lower them to native truncdiv instructions
-  Expr VisitExpr_(const FloorDivNode* op) final {
-    auto e = GetRef<Expr>(op);
-    Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr VisitExpr_(const FloorDivNode* op) final {
+    auto e = GetRef<PrimExpr>(op);
+    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
     op = ret.as<FloorDivNode>();
     if (op == nullptr) return ret;
     int shift;
@@ -95,8 +95,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
         return truncdiv(op->a, op->b);
       } else {
         DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
-        Expr rdiv = truncdiv(op->a, op->b);
-        Expr rmod = truncmod(op->a, op->b);
+        PrimExpr rdiv = truncdiv(op->a, op->b);
+        PrimExpr rmod = truncmod(op->a, op->b);
         // condition on b >= 0.
         // truncmod(a, b) < 0 will implies ceildiv,
         // So we need to correct these cases.
@@ -112,16 +112,16 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
       DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
       // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
       // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
-      Expr rdiv = truncdiv(op->a, op->b);
-      Expr rmod = truncmod(op->a, op->b);
+      PrimExpr rdiv = truncdiv(op->a, op->b);
+      PrimExpr rmod = truncmod(op->a, op->b);
       return ir::SelectNode::make(
           (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
           rdiv, rdiv - make_const(dtype, 1));
     }
   }
 
-  Expr VisitExpr_(const FloorModNode* op) final {
-    Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+  PrimExpr VisitExpr_(const FloorModNode* op) final {
+    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
     op = ret.as<FloorModNode>();
     if (op == nullptr) return ret;
     // Lower floordiv to native truncdiv.
@@ -146,7 +146,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
         // NOTE:condition on b >= 0.
         // mod(a, b) < 0 will imply we are doing ceildiv,
         // So we need to correct these cases.
-        Expr rmod = truncmod(op->a, op->b);
+        PrimExpr rmod = truncmod(op->a, op->b);
         if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
           // (rmod >> shift) & b
           // -> (rmod >= 0 ? 0: -1) & b
@@ -159,7 +159,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     } else {
       // uncommon case
       DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
-      Expr rmod = truncmod(op->a, op->b);
+      PrimExpr rmod = truncmod(op->a, op->b);
       // b > 0 && rmod >= 0 -> rmod
       // b > 0 && rmod < 0  -> rmod + b
       // b < 0 && rmod < 0 -> rmod
@@ -170,11 +170,11 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     }
   }
 
-  Expr VisitExpr_(const MaxNode* op) final {
+  PrimExpr VisitExpr_(const MaxNode* op) final {
     using namespace arith;
-    PVar<Expr> x, y;
+    PVar<PrimExpr> x, y;
     PVar<Integer> c;
-    auto e = GetRef<Expr>(op);
+    auto e = GetRef<PrimExpr>(op);
     if (max(floordiv(x, y), c).Match(e) &&
         c.Eval()->value >= 0 &&
         analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
@@ -183,20 +183,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const EQNode* op) final {
+  PrimExpr VisitExpr_(const EQNode* op) final {
     using namespace arith;
-    PVar<Expr> x, y;
-    auto e = GetRef<Expr>(op);
+    PVar<PrimExpr> x, y;
+    auto e = GetRef<PrimExpr>(op);
     if ((floormod(x, y) == 0).Match(e)) {
       return VisitExpr((truncmod(x, y) == 0).Eval());
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const NENode* op) final {
+  PrimExpr VisitExpr_(const NENode* op) final {
     using namespace arith;
-    PVar<Expr> x, y;
-    auto e = GetRef<Expr>(op);
+    PVar<PrimExpr> x, y;
+    auto e = GetRef<PrimExpr>(op);
     if ((floormod(x, y) != 0).Match(e)) {
       return VisitExpr((truncmod(x, y) != 0).Eval());
     }
@@ -204,7 +204,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   }
 
  private:
-  Expr SwapBroadcastCast(const Expr& e) {
+  PrimExpr SwapBroadcastCast(const PrimExpr& e) {
     // Try to change broadcast(cast(x)) to cast(broadcast(x))
     // For some targets, LLVM will generate more efficient FMA
     // instruction with the latter. For example, vmla vs. vmlal
@@ -228,7 +228,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
         };
 
         if (should_swap()) {
-          Expr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
+          PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
           return CastNode::make(bcast->dtype, new_bcast);
         }
       }
@@ -236,26 +236,26 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     return e;
   }
 
-  Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
+  PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c,
                const AddNode* op) {
     // emit fma instruction: a * b + c
-    Expr lhs = SwapBroadcastCast(a);
-    Expr rhs = SwapBroadcastCast(b);
+    PrimExpr lhs = SwapBroadcastCast(a);
+    PrimExpr rhs = SwapBroadcastCast(b);
 
     if (fma_ != nullptr && op->dtype.is_float()) {
-      Expr r = (*fma_)(CallNode::make(
+      PrimExpr r = (*fma_)(CallNode::make(
           op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
       if (r.defined()) return this->VisitExpr(r);
     } else {
       if (!lhs.same_as(a) || !rhs.same_as(b)) {
-        Expr mul = this->VisitExpr(MulNode::make(lhs, rhs));
+        PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs));
         return AddNode::make(mul, this->VisitExpr(c));
       }
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr ApplyPattern(const std::string& name, const Expr& e) {
+  PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) {
     for (size_t i = 0; i < patterns_.size(); ++i) {
       std::string& p = patterns_[i];
       size_t psize = p.length();
@@ -265,14 +265,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
       p.resize(psize);
       // if pattern exists.
       if (f != nullptr) {
-        Expr r = (*f)(e);
+        PrimExpr r = (*f)(e);
         CHECK(r.defined()) << "intrinsic rule must always return valid Expr";
         if (!r.same_as(e)) {
           return this->VisitExpr(r);
         }
       }
     }
-    return Expr();
+    return PrimExpr();
   }
 
   // patterns
index d38d1da..a0b07c2 100644 (file)
@@ -93,7 +93,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return stmt;
     }
   }
-  Expr VisitExpr_(const LoadNode* op) final {
+  PrimExpr VisitExpr_(const LoadNode* op) final {
     auto it = load_remap_.find(op->buffer_var.get());
     if (it != load_remap_.end()) {
       CHECK(is_zero(op->index));
@@ -123,10 +123,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     const UIntImmNode *size_of_args = call->args[0].as<UIntImmNode>();
     CHECK(size_of_args) << call->args[0]->GetTypeKey();
     CHECK_EQ(size, size_of_args->value);
-    Array<Expr> inits = combiner->identity_element;
-    std::vector<Expr> values(size);
+    Array<PrimExpr> inits = combiner->identity_element;
+    std::vector<PrimExpr> values(size);
     std::vector<DataType> types(size);
-    Expr cond  = call->args[size+1];
+    PrimExpr cond  = call->args[size+1];
     for (size_t idx = 0; idx < size; ++idx) {
       values[idx] = call->args[1+idx];
       if (!is_one(cond)) {
@@ -175,13 +175,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // the size of each index.
     int reduce_extent, group_extent;
     int threadx_extent = 1;
-    Expr reduce_index = FlattenThread(vred, &reduce_extent);
-    Expr group_index = FlattenThread(vpar, &group_extent);
+    PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
+    PrimExpr group_index = FlattenThread(vpar, &group_extent);
     if (reduce_extent == 1) {
       // special case, no reduction is needed.
       std::vector<Stmt> stores(size);
       for (size_t i = 0; i < size; ++i) {
-        Expr pred = const_true(types[i].lanes());
+        PrimExpr pred = const_true(types[i].lanes());
         Var buffer_var = Downcast<Var>(call->args[2+size+i]);
         stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
       }
@@ -198,7 +198,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     seq.emplace_back(SyncThread("shared"));
     for (size_t idx = 0; idx < size; ++idx) {
       shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
-      Expr pred = const_true(types[idx].lanes());
+      PrimExpr pred = const_true(types[idx].lanes());
       seq.emplace_back(StoreNode::make(
           shared_bufs[idx], values[idx],
           BufIndex(reduce_index, group_index, reduce_extent), pred));
@@ -209,13 +209,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         reduce_index, group_index, reduce_extent, threadx_extent));
     for (size_t idx = 0; idx < size; ++idx) {
       CHECK(!load_remap_.count(buffers[idx]));
-      Expr pred = const_true(types[idx].lanes());
+      PrimExpr pred = const_true(types[idx].lanes());
       load_remap_[buffers[idx]] = LoadNode::make(
         types[idx], shared_bufs[idx],
         BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
       alloc_remap_[buffers[idx]] = AllocateNode::make(
         shared_bufs[idx], types[idx],
-        {Expr(group_extent), Expr(reduce_extent)},
+        {PrimExpr(group_extent), PrimExpr(reduce_extent)},
         pred, EvaluateNode::make(0));
     }
     return SeqStmt::Flatten(seq);
@@ -224,8 +224,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   Stmt MakeBufAllreduce(const CommReducerNode *combiner,
                         const std::vector<DataType>& types,
                         const Array<Var>& shared_bufs,
-                        Expr reduce_index,
-                        Expr group_index,
+                        PrimExpr reduce_index,
+                        PrimExpr group_index,
                         int reduce_extent,
                         int threadx_extent) {
     // Get next power of two
@@ -237,17 +237,17 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     std::vector<Stmt> seq;
 
     size_t size = shared_bufs.size();
-    Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
+    PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
     // make reduction
     auto freduce = [&](int offset) {
-      Array<Expr> a, b;
+      Array<PrimExpr> a, b;
       for (size_t i = 0; i < size; ++i) {
         b.push_back(LoadNode::make(types[i], shared_bufs[i],
           BufIndex(reduce_index + offset, group_index, reduce_extent),
           const_true()));
         a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true()));
       }
-      Array<Expr> ret = (*combiner)(a, b);
+      Array<PrimExpr> ret = (*combiner)(a, b);
       std::vector<Stmt> stores(size);
       for (size_t i = 0; i < size; ++i) {
         stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true());
@@ -258,7 +258,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (reduce_align > reduce_extent) {
       // reduction with the boundary condition
       reduce_align = reduce_align >> 1;
-      Expr cond = reduce_index < (reduce_extent - reduce_align);
+      PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
       seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
@@ -267,13 +267,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     while (reduce_align > threadx_extent ||
            reduce_align > warp_size_) {
       reduce_align =  reduce_align >> 1;
-      Expr cond = reduce_index < reduce_align;
+      PrimExpr cond = reduce_index < reduce_align;
       seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
     // in warp synchronization.
     std::vector<Stmt> in_warp_seq;
-    Expr in_warp_cond = reduce_index < (reduce_align >> 1);
+    PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
     while (reduce_align > 1) {
       reduce_align = reduce_align >> 1;
       in_warp_seq.emplace_back(freduce(reduce_align));
@@ -288,7 +288,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
   // Flatten the thread index.
   // Also return a warp number,
-  Expr FlattenThread(const std::vector<ThreadEntry>& tvec,
+  PrimExpr FlattenThread(const std::vector<ThreadEntry>& tvec,
                      int* out_total_extent) {
     int& total_extent = *out_total_extent;
     total_extent = 1;
@@ -296,7 +296,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return make_zero(DataType::Int(32));
     }
 
-    Expr ret;
+    PrimExpr ret;
     for (const ThreadEntry& e : tvec) {
       if (ret.defined()) {
         ret = ret + e.iv->var * total_extent;
@@ -316,7 +316,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
                    CallNode::Intrinsic));
   }
   // The local buffer index.
-  static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
+  static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) {
     if (!is_zero(group_index)) {
       return ir::Simplify(group_index * reduce_extent + reduce_index);
     } else {
@@ -330,7 +330,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::vector<const AttrStmtNode*> thread_extents_;
   std::vector<const CommReducerNode*> reduce_combiner_;
   // The load remap
-  std::unordered_map<const VarNode *, Expr> load_remap_;
+  std::unordered_map<const VarNode *, PrimExpr> load_remap_;
   // Allocate remap
   std::unordered_map<const VarNode *, Stmt> alloc_remap_;
 };
index a9b401f..8e7f1d8 100644 (file)
 namespace tvm {
 namespace ir {
 
-inline Expr ConstInt32(size_t index) {
+inline PrimExpr ConstInt32(size_t index) {
   CHECK_LE(index, std::numeric_limits<int>::max());
   return make_const(DataType::Int(32), static_cast<int>(index));
 }
 
-inline Expr StackAlloca(std::string type, size_t num) {
-  Array<Expr> args = {StringImmNode::make(type), ConstInt32(num)};
+inline PrimExpr StackAlloca(std::string type, size_t num) {
+  Array<PrimExpr> args = {StringImmNode::make(type), ConstInt32(num)};
   return CallNode::make(
       DataType::Handle(),
       intrinsic::tvm_stack_alloca,
@@ -103,7 +103,7 @@ class BuiltinLower : public StmtExprMutator {
         }
       }
     }
-    Expr total_bytes = make_const(op->extents[0].dtype(), nbytes);
+    PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
     for (size_t i = 0; i < op->extents.size(); ++i) {
       total_bytes = total_bytes * op->extents[i];
     }
@@ -134,7 +134,7 @@ class BuiltinLower : public StmtExprMutator {
                        CallNode::Extern),
         body);
 
-    Expr free_op = CallNode::make(DataType::Int(32),
+    PrimExpr free_op = CallNode::make(DataType::Int(32),
                                   "TVMBackendFreeWorkspace",
                                   {cast(DataType::Int(32), device_type_),
                                    cast(DataType::Int(32), device_id_),
@@ -163,7 +163,7 @@ class BuiltinLower : public StmtExprMutator {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
       return MakeCallPacked(op);
     } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
@@ -179,10 +179,10 @@ class BuiltinLower : public StmtExprMutator {
     }
   }
   // call shape
-  Expr MakeShape(const CallNode* op) {
+  PrimExpr MakeShape(const CallNode* op) {
     size_t stack_begin = run_shape_stack_;
     run_shape_stack_ += op->args.size();
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     for (size_t i = 0; i < op->args.size(); ++i) {
       prep_seq_.emplace_back(
@@ -192,16 +192,16 @@ class BuiltinLower : public StmtExprMutator {
     return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
   }
   // make array
-  Expr MakeArray(const CallNode* op) {
+  PrimExpr MakeArray(const CallNode* op) {
     size_t idx = run_array_stack_;
     run_array_stack_ += 1;
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
-    Expr strides = op->args[2];
+    PrimExpr strides = op->args[2];
     if (!strides.defined() || is_zero(strides)) {
       strides = make_zero(DataType::Handle());
     }
@@ -221,7 +221,7 @@ class BuiltinLower : public StmtExprMutator {
                      make_const(DataType::UInt(16), dtype.lanes())));
     // set byte offset
     int data_bytes = GetVectorBytes(dtype);
-    Expr byte_offset = op->args[5];
+    PrimExpr byte_offset = op->args[5];
     if (!is_zero(byte_offset)) {
       byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
     }
@@ -239,17 +239,17 @@ class BuiltinLower : public StmtExprMutator {
     return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
   }
   // call packed.
-  Expr MakeCallPacked(const CallNode* op) {
+  PrimExpr MakeCallPacked(const CallNode* op) {
     size_t restore_shape_stack = run_shape_stack_;
     size_t restore_array_stack = run_array_stack_;
     size_t arg_stack_begin = run_arg_stack_;
     run_arg_stack_ += op->args.size();
     // Specially handle the buffer packed intrinsic
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     for (size_t i = 1; i < op->args.size(); ++i) {
-      Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
-      Expr arg = op->args[i];
+      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
+      PrimExpr arg = op->args[i];
       DataType t = arg.dtype();
       DataType api_type = APIType(t);
       if (t != api_type) {
@@ -275,7 +275,7 @@ class BuiltinLower : public StmtExprMutator {
     run_shape_stack_ = restore_shape_stack;
     run_array_stack_ = restore_array_stack;
     run_arg_stack_ = arg_stack_begin;
-    Array<Expr> packed_args = {
+    Array<PrimExpr> packed_args = {
       op->args[0],
       stack_value_,
       stack_tcode_,
@@ -287,18 +287,18 @@ class BuiltinLower : public StmtExprMutator {
         packed_args, CallNode::Intrinsic);
   }
 
-  Expr MakeCallTracePacked(const CallNode *op) {
+  PrimExpr MakeCallTracePacked(const CallNode *op) {
     size_t restore_shape_stack = run_shape_stack_;
     size_t restore_array_stack = run_array_stack_;
     size_t arg_stack_begin = run_arg_stack_;
     run_arg_stack_ += op->args.size();
     size_t args_size = op->args.size();
     CHECK_GT(args_size, 0);
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     for (size_t i = 1; i < op->args.size(); ++i) {
-      Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
-      Expr arg = op->args[i];
+      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
+      PrimExpr arg = op->args[i];
       DataType t = arg.dtype();
       DataType api_type = APIType(t);
       if (t != api_type) {
@@ -323,7 +323,7 @@ class BuiltinLower : public StmtExprMutator {
     // Update the top of the stack, so we can use more than one
     // packed function's arguments with the one stack.
     run_arg_stack_ = arg_stack_begin + args_size - 1;
-    Array<Expr> packed_args = {
+    Array<PrimExpr> packed_args = {
       op->args[0],
       stack_value_,
       stack_tcode_,
@@ -338,7 +338,7 @@ class BuiltinLower : public StmtExprMutator {
   }
 
  private:
-  bool IsArrayHandle(const Expr& arg) {
+  bool IsArrayHandle(const PrimExpr& arg) {
     // specially set array handle.
     if (const CallNode* buf = arg.as<CallNode>()) {
       if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
@@ -351,8 +351,8 @@ class BuiltinLower : public StmtExprMutator {
 
   // The prepration sequence to be emitted.
   std::vector<Stmt> prep_seq_;
-  Expr device_type_;
-  Expr device_id_;
+  PrimExpr device_type_;
+  PrimExpr device_id_;
   // Var handle for each stack.
   Var stack_shape_;
   Var stack_array_;
index 75f128e..6a1c3c4 100644 (file)
@@ -96,7 +96,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
       if (op->value.dtype().lanes() == 1) {
         UpdatePattern(op->index);
       } else {
-        Expr base;
+        PrimExpr base;
         CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
             << "LowerWarpMemory failed due to store index=" << op->index
             << ", can only handle continuous store";
@@ -107,13 +107,13 @@ class WarpStoreCoeffFinder : private StmtVisitor {
     }
   }
 
-  void UpdatePattern(const Expr& index) {
-    Array<Expr> m =
+  void UpdatePattern(const PrimExpr& index) {
+    Array<PrimExpr> m =
         arith::DetectLinearEquation(index, {warp_index_});
     CHECK_EQ(m.size(), 2U)
         << "LowerWarpMemory failed due to store index=" << index;
     int coeff = 0;
-    Expr mcoeff = analyzer_->canonical_simplify(m[0]);
+    PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
 
     CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
         << "LowerWarpMemory failed due to store index=" << index
@@ -211,7 +211,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
  protected:
-  Expr Mutate_(const VarNode* op) {
+  PrimExpr Mutate_(const VarNode* op) {
     CHECK(op != buffer_)
         << "Cannot access address of warp memory directly";
     return StmtExprMutator::VisitExpr_(op);
@@ -219,7 +219,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
 
   Stmt VisitStmt_(const StoreNode* op) {
     if (op->buffer_var.get() == buffer_) {
-      Expr local_index, group;
+      PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
       return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate);
     } else {
@@ -227,15 +227,15 @@ class WarpAccessRewriter : protected StmtExprMutator {
     }
   }
 
-  Expr Mutate_(const LoadNode* op) {
+  PrimExpr Mutate_(const LoadNode* op) {
     if (op->buffer_var.get() == buffer_) {
-      Expr local_index, group;
+      PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
       // invariance: local index must do not contain warp id
       CHECK(!ExprUseVar(local_index, {warp_index_.get()}))
           << "LowerWarpMemory failed to rewrite load to shuffle for index "
           << op->index << " local_index=" << local_index;
-      Expr load_value = LoadNode::make(
+      PrimExpr load_value = LoadNode::make(
           op->dtype, op->buffer_var, local_index, op->predicate);
       return CallNode::make(load_value.dtype(),
                         intrinsic::tvm_warp_shuffle,
@@ -250,27 +250,27 @@ class WarpAccessRewriter : protected StmtExprMutator {
   // local index is the index in the local
   // source index is the corresponding source index
   // in this access pattern.
-  std::pair<Expr, Expr> SplitIndexByGroup(const Expr& index) {
+  std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
     if (index.dtype().lanes() != 1) {
-      Expr base, local_index, group;
+      PrimExpr base, local_index, group;
       CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
       std::tie(local_index, group) = SplitIndexByGroup(base);
       local_index =
           RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
       return std::make_pair(local_index, group);
     }
-    Expr m = make_const(index.dtype(), warp_coeff_);
+    PrimExpr m = make_const(index.dtype(), warp_coeff_);
 
     // simple case, warp index is on the highest.
     if (warp_group_ == 1) {
-      Expr x = analyzer_->canonical_simplify(indexmod(index, m));
-      Expr z = analyzer_->canonical_simplify(indexdiv(index, m));
+      PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
+      PrimExpr z = analyzer_->canonical_simplify(indexdiv(index, m));
       return std::make_pair(x, z);
     } else {
-      Expr x = analyzer_->canonical_simplify(indexmod(index, m));
-      Expr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
+      PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
+      PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_);
       y = y * m + x;
-      Expr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
+      PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)),
                         m);
       return std::make_pair(analyzer_->canonical_simplify(y),
                             analyzer_->canonical_simplify(z));
index 56609bb..d5c73a2 100644 (file)
@@ -35,7 +35,7 @@
 namespace tvm {
 namespace ir {
 
-inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
+inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
   return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
 }
 
@@ -62,18 +62,18 @@ LoweredFunc MakeAPI(Stmt body,
   // seq_init gives sequence of initialization
   // seq_check gives sequence of later checks after init
   std::vector<Stmt> seq_init, seq_check;
-  std::unordered_map<const VarNode*, Expr> vmap;
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
   ArgBinder binder(&vmap);
   // ---------------------------
   // local function definitions
   // load i-th argument as type t
   auto f_arg_value = [&](DataType t, int i) {
-    Array<Expr> call_args{v_packed_args,
+    Array<PrimExpr> call_args{v_packed_args,
                           IntImmNode::make(DataType::Int(32), i),
                           IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
     // load 64 bit version
     DataType api_type = APIType(t);
-    Expr res = CallNode::make(
+    PrimExpr res = CallNode::make(
         api_type, intrinsic::tvm_struct_get, call_args,
         CallNode::PureIntrinsic);
     // cast to the target version.
@@ -189,7 +189,7 @@ LoweredFunc MakeAPI(Stmt body,
       StringImmNode::make(name + "_compute_"), body);
   // Set device context
   if (vmap.count(device_id.get())) {
-    Expr node = StringImmNode::make("default");
+    PrimExpr node = StringImmNode::make("default");
     CHECK(vmap.count(device_type.get()));
     seq_check.push_back(AttrStmtNode::make(
         node, attr::device_context_id, device_id, nop));
@@ -226,7 +226,7 @@ class DeviceTypeBinder: public StmtExprMutator {
     if (op->attr_key == attr::device_context_type) {
       if (const VarNode* var = op->value.as<VarNode>()) {
         var_ = var;
-        Expr value = make_const(op->value.dtype(), device_type_);
+        PrimExpr value = make_const(op->value.dtype(), device_type_);
         Stmt body = StmtExprMutator::VisitStmt_(op);
         var_ = nullptr;
         std::ostringstream os;
@@ -251,9 +251,9 @@ class DeviceTypeBinder: public StmtExprMutator {
     return res;
   }
 
-  Expr VisitExpr_(const NENode* op) final {
+  PrimExpr VisitExpr_(const NENode* op) final {
     // eager check NE for device check
-    Expr res = StmtExprMutator::VisitExpr_(op);
+    PrimExpr res = StmtExprMutator::VisitExpr_(op);
     op = res.as<NENode>();
     if (ir::Equal(op->a, op->b)) {
       return make_const(op->dtype, false);
@@ -261,11 +261,11 @@ class DeviceTypeBinder: public StmtExprMutator {
     return res;
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     if (op == var_) {
       return make_const(op->dtype, device_type_);
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
 
index 2a486b5..4201e78 100644 (file)
@@ -63,7 +63,7 @@ class ThreadAxisRewriter : private StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = vmap_.find(op);
     if (it != vmap_.end()) return it->second;
     return StmtExprMutator::VisitExpr_(op);
@@ -75,7 +75,7 @@ class ThreadAxisRewriter : private StmtExprMutator {
 };
 
 LoweredFunc
-RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
+RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
   std::unordered_map<std::string, IterVar> tmap;
   for (const auto& kv : thread_map) {
     const StringImmNode* str = kv.first.as<StringImmNode>();
index 3c9114d..eecbe30 100644 (file)
@@ -126,16 +126,16 @@ class NoOpRemover : public StmtMutator {
   }
 
  private:
-  Stmt MakeEvaluate(Expr value) {
+  Stmt MakeEvaluate(PrimExpr value) {
     if (HasSideEffect(value)) {
       return EvaluateNode::make(value);
     } else {
       return EvaluateNode::make(0);
     }
   }
-  Stmt MakeEvaluate(const Array<Expr>& values) {
+  Stmt MakeEvaluate(const Array<PrimExpr>& values) {
     Stmt stmt;
-    for (Expr e : values) {
+    for (PrimExpr e : values) {
       if (HasSideEffect(e)) {
         if (stmt.defined()) {
           stmt = SeqStmt({stmt, EvaluateNode::make(e)});
index c38fac1..224a81c 100644 (file)
@@ -31,7 +31,7 @@ namespace ir {
 
 // For now, rewrite unsafe select expression to if_then_else
 // TODO(tqchen) pattern matching to support masked load
-class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
+class UnsafeExprDetector : public ExprFunctor<bool(const PrimExpr& n)> {
  public:
   // select itself is always considered safe if condition is safe
   // Because we will issue guard to make sure it is.
@@ -45,7 +45,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
       const LoadNode* l = op->args[0].as<LoadNode>();
       return this->VisitExpr(l->index);
     } else if (op->is_pure()) {
-      for (Expr e : op->args) {
+      for (PrimExpr e : op->args) {
         if (VisitExpr(e)) return true;
       }
       return false;
@@ -90,7 +90,7 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
     return VisitExpr(op->base) && VisitExpr(op->stride);
   }
   bool VisitExpr_(const ShuffleNode* op) final {
-    for (Expr e : op->vectors) {
+    for (PrimExpr e : op->vectors) {
       if (VisitExpr(e)) return true;
     }
     return false;
@@ -110,8 +110,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
 
 class UnsafeSelectRewriter : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const SelectNode* op) {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const SelectNode* op) {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<SelectNode>();
     UnsafeExprDetector unsafe;
     bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
index 3233e50..9737f70 100644 (file)
@@ -30,7 +30,7 @@ namespace ir {
 
 class IRSideEffect : public ExprVisitor {
  public:
-  void VisitExpr(const Expr& e) final {
+  void VisitExpr(const PrimExpr& e) final {
     if (has_side_effect_) return;
     ExprVisitor::VisitExpr(e);
   }
@@ -46,7 +46,7 @@ class IRSideEffect : public ExprVisitor {
   bool has_side_effect_{false};
 };
 
-bool HasSideEffect(const Expr& e) {
+bool HasSideEffect(const PrimExpr& e) {
   IRSideEffect v;
   v(e);
   return v.has_side_effect_;
@@ -55,45 +55,45 @@ bool HasSideEffect(const Expr& e) {
 class IRSubstitue : public StmtExprMutator {
  public:
   explicit IRSubstitue(
-      const std::unordered_map<const VarNode*, Expr>& smap)
+      const std::unordered_map<const VarNode*, PrimExpr>& smap)
       : smap_(smap) {
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = smap_.find(op);
     if (it != smap_.end()) {
       return it->second;
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
 
  private:
-  const std::unordered_map<const VarNode*, Expr>& smap_;
+  const std::unordered_map<const VarNode*, PrimExpr>& smap_;
 };
 
 Stmt Substitute(Stmt stmt,
-                const std::unordered_map<const VarNode*, Expr>& value_map) {
+                const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
   if (value_map.size() == 0) return stmt;
   return IRSubstitue(value_map)(std::move(stmt));
 }
 
-Expr Substitute(Expr expr,
-                const std::unordered_map<const VarNode*, Expr>& value_map) {
+PrimExpr Substitute(PrimExpr expr,
+                const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
   if (value_map.size() == 0) return expr;
   return IRSubstitue(value_map)(std::move(expr));
 }
 
-Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
-  std::unordered_map<const VarNode*, Expr> vmap;
+Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map) {
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
   for (const auto& kv : value_map) {
     vmap[kv.first.get()] = kv.second;
   }
   return Substitute(stmt, vmap);
 }
 
-Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
-  std::unordered_map<const VarNode*, Expr> vmap;
+PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map) {
+  std::unordered_map<const VarNode*, PrimExpr> vmap;
   for (const auto& kv : value_map) {
     vmap[kv.first.get()] = kv.second;
   }
@@ -102,7 +102,7 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
 
 class VarTouchVisitor : public ExprVisitor {
  public:
-  void VisitExpr(const Expr& e) final {
+  void VisitExpr(const PrimExpr& e) final {
     if (use_var_) return;
     ExprVisitor::VisitExpr(e);
   }
@@ -146,13 +146,13 @@ class ExprUseVSetVisitor : public VarTouchVisitor {
   const std::unordered_set<const VarNode*>& vset_;
 };
 
-bool ExprUseVar(const Expr& e, const Var& v) {
+bool ExprUseVar(const PrimExpr& e, const Var& v) {
   ExprUseVarVisitor visitor(v.get());
   visitor(e);
   return visitor.use_var_;
 }
 
-bool ExprUseVar(const Expr& e,
+bool ExprUseVar(const PrimExpr& e,
                 const std::unordered_set<const VarNode*>& vset) {
   ExprUseVSetVisitor visitor(vset);
   visitor(e);
index f71f13b..7309c72 100644 (file)
@@ -46,7 +46,7 @@ class IRUseDefAnalysis : public StmtExprMutator {
         thread_extent_.push_back(op->value);
       }
 
-      Expr value = op->value;
+      PrimExpr value = op->value;
       if (visit_thread_extent_) {
         value = this->VisitExpr(value);
       }
@@ -68,7 +68,7 @@ class IRUseDefAnalysis : public StmtExprMutator {
         !HasSideEffect(op->value)) {
       return body;
     } else {
-      Expr value = this->VisitExpr(op->value);
+      PrimExpr value = this->VisitExpr(op->value);
       if (body.same_as(op->body) &&
           value.same_as(op->value)) {
         return GetRef<Stmt>(op);
@@ -93,30 +93,30 @@ class IRUseDefAnalysis : public StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Expr VisitExpr_(const LetNode* op) final {
+  PrimExpr VisitExpr_(const LetNode* op) final {
     this->HandleDef(op->var.get());
-    Expr body = this->VisitExpr(op->body);
+    PrimExpr body = this->VisitExpr(op->body);
     // eliminate unreferenced let
     if (use_count_.at(op->var.get()) == 0 &&
         !HasSideEffect(op->value)) {
       return body;
     } else {
-      Expr value = this->VisitExpr(op->value);
+      PrimExpr value = this->VisitExpr(op->value);
       if (body.same_as(op->body) &&
           value.same_as(op->value)) {
-        return GetRef<Expr>(op);
+        return GetRef<PrimExpr>(op);
       } else {
         return LetNode::make(op->var, value, body);
       }
     }
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
-    this->HandleUse(GetRef<Expr>(op));
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    this->HandleUse(GetRef<PrimExpr>(op));
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const LoadNode* op) final {
+  PrimExpr VisitExpr_(const LoadNode* op) final {
     this->HandleUse(op->buffer_var);
     return StmtExprMutator::VisitExpr_(op);
   }
@@ -132,7 +132,7 @@ class IRUseDefAnalysis : public StmtExprMutator {
     def_count_[v] = 1;
   }
 
-  void HandleUse(const Expr& v) {
+  void HandleUse(const PrimExpr& v) {
     CHECK(v.as<VarNode>());
     Var var = Downcast<Var>(v);
     auto it = use_count_.find(var.get());
@@ -151,7 +151,7 @@ class IRUseDefAnalysis : public StmtExprMutator {
   bool visit_thread_extent_{true};
   Array<Var> undefined_;
   Array<IterVar> thread_axis_;
-  Array<Expr> thread_extent_;
+  Array<PrimExpr> thread_extent_;
   std::unordered_map<const VarNode*, int> use_count_;
   std::unordered_map<const VarNode*, int> def_count_;
 };
@@ -218,12 +218,12 @@ class HostDeviceSplitter : public StmtMutator {
       }
     }
     LoweredFunc f_device(n);
-    Array<Expr> call_args;
+    Array<PrimExpr> call_args;
     call_args.push_back(StringImmNode::make(f_device->name));
     for (Var arg : n->args) {
       call_args.push_back(arg);
     }
-    for (Expr ext : m.thread_extent_) {
+    for (PrimExpr ext : m.thread_extent_) {
       call_args.push_back(ext);
     }
     device_funcs_.emplace_back(f_device);
@@ -236,7 +236,7 @@ class HostDeviceSplitter : public StmtMutator {
   std::string name_;
   // the device functions
   std::vector<LoweredFunc> device_funcs_;
-  std::unordered_map<const VarNode*, Expr> handle_data_type_;
+  std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
 };
 
 
index 3dafb40..8375e80 100644 (file)
@@ -37,7 +37,7 @@ class IRVerifySSA final : public StmtExprVisitor {
  public:
   bool is_ssa{true};
 
-  void VisitExpr(const Expr& n) final {
+  void VisitExpr(const PrimExpr& n) final {
     if (!is_ssa) return;
     StmtExprVisitor::VisitExpr(n);
   }
@@ -76,20 +76,20 @@ class IRVerifySSA final : public StmtExprVisitor {
 
 class IRConvertSSA final : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     if (scope_.count(op)) {
       return scope_[op].back();
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
-  Expr VisitExpr_(const LetNode* op) final {
-    const VarExpr& v = op->var;
+  PrimExpr VisitExpr_(const LetNode* op) final {
+    const Var& v = op->var;
     if (defined_.count(v.get())) {
-      Expr value = this->VisitExpr(op->value);
-      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+      PrimExpr value = this->VisitExpr(op->value);
+      Var new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
-      Expr body = this->VisitExpr(op->body);
+      PrimExpr body = this->VisitExpr(op->body);
       scope_[v.get()].pop_back();
       return LetNode::make(new_var, value, body);
     } else {
@@ -97,8 +97,8 @@ class IRConvertSSA final : public StmtExprMutator {
       return StmtExprMutator::VisitExpr_(op);
     }
   }
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<LoadNode>();
     if (scope_.count(op->buffer_var.get())) {
       return LoadNode::make(
@@ -120,10 +120,10 @@ class IRConvertSSA final : public StmtExprMutator {
     }
   }
   Stmt VisitStmt_(const LetStmtNode* op) final {
-    const VarExpr& v = op->var;
+    const Var& v = op->var;
     if (defined_.count(v.get())) {
-      Expr value = this->VisitExpr(op->value);
-      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+      PrimExpr value = this->VisitExpr(op->value);
+      Var new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Stmt body = this->VisitStmt(op->body);
       scope_[v.get()].pop_back();
@@ -134,9 +134,9 @@ class IRConvertSSA final : public StmtExprMutator {
     }
   }
   Stmt VisitStmt_(const ForNode* op) final {
-    const VarExpr& v = op->loop_var;
+    const Var& v = op->loop_var;
     if (defined_.count(v.get())) {
-      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+      Var new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
@@ -149,9 +149,9 @@ class IRConvertSSA final : public StmtExprMutator {
     }
   }
   Stmt VisitStmt_(const AllocateNode* op) final {
-    const VarExpr& v = op->buffer_var;
+    const Var& v = op->buffer_var;
     if (defined_.count(v.get())) {
-      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
+      Var new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
@@ -191,7 +191,7 @@ class IRConvertSSA final : public StmtExprMutator {
   }
 
  private:
-  std::unordered_map<const VarNode*, std::vector<VarExpr> > scope_;
+  std::unordered_map<const VarNode*, std::vector<Var> > scope_;
   std::unordered_set<const VarNode*> defined_;
 };
 
index cb779f9..d98299f 100644 (file)
@@ -187,8 +187,8 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
     CHECK_EQ(op->args.size(), 5U);
     DataType dtype = op->args[0].dtype();
     const VarNode* buffer = op->args[1].as<VarNode>();
-    Expr offset = op->args[2];
-    Expr extent = op->args[3];
+    PrimExpr offset = op->args[2];
+    PrimExpr extent = op->args[3];
     const IntImmNode* flag = op->args[4].as<IntImmNode>();
     StorageScope scope = GetScope(buffer);
     // The buffer scope.
@@ -197,7 +197,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
       AccessEntry e;
       e.threads = env_threads();
       e.dtype = dtype;
-      e.buffer = Downcast<VarExpr>(op->args[1]);
+      e.buffer = Downcast<Var>(op->args[1]);
       e.touched = arith::IntSet::range(
           Range::make_by_min_extent(offset, extent));
       e.scope = scope;
@@ -277,7 +277,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       return MakeAccessPtr(op);
     } else {
@@ -287,15 +287,15 @@ class StorageAccessInfoLower : public StmtExprMutator {
 
  private:
   // tvm_access_ptr
-  Expr MakeAccessPtr(const CallNode* op) {
+  PrimExpr MakeAccessPtr(const CallNode* op) {
     // Specially handle the buffer packed intrinsic
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     CHECK_EQ(op->args.size(), 5U);
     DataType dtype = op->args[0].dtype();
     const VarNode* buffer = op->args[1].as<VarNode>();
     Var buffer_var = Downcast<Var>(op->args[1]);
-    Expr offset = op->args[2];
+    PrimExpr offset = op->args[2];
     auto it = storage_info_.find(buffer);
     if (it != storage_info_.end() && it->second.info.defined()) {
       return MakeTaggedAccessPtr(
@@ -307,10 +307,10 @@ class StorageAccessInfoLower : public StmtExprMutator {
     return AddressOffset(buffer_var, dtype, offset);
   }
 
-  Expr MakeTaggedAccessPtr(DataType ptr_type,
+  PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
                            Var buffer_var,
                            DataType dtype,
-                           Expr offset,
+                           PrimExpr offset,
                            const MemoryInfo& info) {
     if (ptr_type.is_handle()) {
       CHECK(info->head_address.defined())
index ea828ff..08c61aa 100644 (file)
@@ -70,7 +70,7 @@ class StorageFlattener : public StmtExprMutator {
     if (it != var_remap_.end() &&
         !it->second.same_as(op->buffer_var)) {
       CHECK(it->second.as<VarNode>());
-      VarExpr buf_var = Downcast<VarExpr>(it->second);
+      Var buf_var = Downcast<Var>(it->second);
       return StoreNode::make(buf_var, op->value, op->index, op->predicate);
     } else {
       return stmt;
@@ -167,7 +167,7 @@ class StorageFlattener : public StmtExprMutator {
       // create a buffer entry
       BufferEntry e;
       e.bounds = op->bounds;
-      Array<Expr> shape;
+      Array<PrimExpr> shape;
       for (auto r : e.bounds) {
         shape.push_back(r->extent);
       }
@@ -198,29 +198,29 @@ class StorageFlattener : public StmtExprMutator {
               << "Allocation exceed bound of memory tag " << skey.to_string();
         }
       }
-      Array<Expr> strides;
+      Array<PrimExpr> strides;
       if (dim_align_.count(key) != 0 && shape.size() != 0) {
-        std::vector<Expr> rstrides;
+        std::vector<PrimExpr> rstrides;
         const std::vector<DimAlignInfo>& avec = dim_align_[key];
         int first_dim = 0;
-        Expr stride = make_const(shape[first_dim].dtype(), 1);
+        PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
         for (size_t i = shape.size(); i != 0; --i) {
           size_t dim = i - 1;
           if (dim < avec.size() && avec[dim].align_factor != 0) {
-            Expr factor = make_const(stride.dtype(), avec[dim].align_factor);
-            Expr offset = make_const(stride.dtype(), avec[dim].align_offset);
+            PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+            PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
             stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
             stride = ir::Simplify(stride);
           }
           rstrides.push_back(stride);
           stride = stride * shape[dim];
         }
-        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+        strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
       }
 
       e.buffer = BufferNode::make(
           Var(key.GetName(), DataType::Handle()),
-          op->dtype, shape, strides, Expr(),
+          op->dtype, shape, strides, PrimExpr(),
           key.GetName(), skey.to_string(),
           align, 0, kDefault);
 
@@ -262,31 +262,31 @@ class StorageFlattener : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<LoadNode>();
     auto it = var_remap_.find(op->buffer_var.get());
     if (it != var_remap_.end() &&
         !it->second.same_as(op->buffer_var)) {
       CHECK(it->second.as<VarNode>());
-      VarExpr buf_var = Downcast<VarExpr>(it->second);
+      Var buf_var = Downcast<Var>(it->second);
       return LoadNode::make(op->dtype, buf_var, op->index, op->predicate);
     } else {
       return expr;
     }
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = var_remap_.find(op);
     if (it != var_remap_.end()) {
       return it->second;
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     if (op != nullptr && op->call_type == CallNode::Halide) {
       TensorKey key{op->func, op->value_index};
@@ -332,20 +332,20 @@ class StorageFlattener : public StmtExprMutator {
       block_size *= shape;
       starts--;
     }
-    Expr stride(elem_cnt / block_size);
+    PrimExpr stride(elem_cnt / block_size);
 
-    Array<Expr> args;
-    std::vector<VarExpr> vars;
+    Array<PrimExpr> args;
+    std::vector<Var> vars;
 
     for (int i = op->bounds.size() - 1; i > starts; --i) {
       args.push_back(op->bounds[i]->min);
     }
     auto &func_name = op->func->func_name();
-    vars.push_back(VarExpr(
+    vars.push_back(Var(
         "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
     args.push_back(op->bounds[starts]->min + stride * vars.back());
     for (int i = starts - 1; i >= 0; --i) {
-      vars.push_back(VarExpr(
+      vars.push_back(Var(
           "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
       args.push_back(vars.back() + op->bounds[i]->min);
     }
@@ -354,13 +354,13 @@ class StorageFlattener : public StmtExprMutator {
         stmt = ForNode::make(
             vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
       } else {
-        Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
-        Expr address = CallNode::make(
+        PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
+        PrimExpr address = CallNode::make(
             DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
-        Expr prefetch = CallNode::make(
+        PrimExpr prefetch = CallNode::make(
             op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
         stmt = EvaluateNode::make(prefetch);
-        Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
+        PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
         stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
       }
     }
@@ -416,7 +416,7 @@ class StorageFlattener : public StmtExprMutator {
     const BufferEntry& be = buf_map_.at(key);
     CHECK(!be.released);
     CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
-    Array<Expr> begins, extents;
+    Array<PrimExpr> begins, extents;
     if (be.bounds.size() != 0) {
       CHECK_EQ(tuple->args.size(), be.bounds.size() * 2);
       for (size_t i = 0; i < be.buffer->shape.size(); ++i) {
@@ -467,9 +467,9 @@ class StorageFlattener : public StmtExprMutator {
     // Whether we are out of allocation bounds and buffer get released.
     bool released{false};
     // relative index
-    inline Array<Expr> RelIndex(Array<Expr> args) const {
+    inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
       if (bounds.size() != 0) {
-        Array<Expr> index;
+        Array<PrimExpr> index;
         CHECK_EQ(bounds.size(), args.size());
         for (size_t i = 0; i < bounds.size(); ++i) {
           index.push_back(args[i] - bounds[i]->min);
@@ -481,7 +481,7 @@ class StorageFlattener : public StmtExprMutator {
     }
   };
 
-  bool ShapeIsValid(const Array<Expr> &shape) {
+  bool ShapeIsValid(const Array<PrimExpr> &shape) {
     // Zero-dimensional tensor does not need boundary check.
     if (!shape.size())
       return false;
@@ -495,9 +495,9 @@ class StorageFlattener : public StmtExprMutator {
     return true;
   }
 
-  Expr MakeBound(const DataType &type, const Array<Expr> &shape) {
+  PrimExpr MakeBound(const DataType &type, const Array<PrimExpr> &shape) {
     // We have already checked the shape size to be greater then 0.
-    Expr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
+    PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
     for (size_t i = 1; i < shape.size(); ++i) {
       bound = MulNode::make(
           bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
@@ -507,7 +507,7 @@ class StorageFlattener : public StmtExprMutator {
 
   // The buffer assignment map
   // Variable remap
-  std::unordered_map<const VarNode*, Expr> var_remap_;
+  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
   // Buffer map
   std::unordered_map<TensorKey, BufferEntry> buf_map_;
   // Dimension alignment
@@ -517,7 +517,7 @@ class StorageFlattener : public StmtExprMutator {
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
   // Collects shapes.
-  std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
+  std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
   // bounds populator. We really need the analyzer from it.
   // However
   IRVisitorWithAnalyzer* bounded_analyzer_;
index 928be4b..7a4b13c 100644 (file)
@@ -261,7 +261,7 @@ class InplaceOpVerifier : public StmtExprVisitor {
     if (!result_) return;
     StmtExprVisitor::VisitStmt(n);
   }
-  void VisitExpr(const Expr& n) final {
+  void VisitExpr(const PrimExpr& n) final {
     if (!result_) return;
     StmtExprVisitor::VisitExpr(n);
   }
@@ -376,8 +376,8 @@ class StoragePlanRewriter : public StmtExprMutator {
                        RemapIndex(op->value.dtype(), op->index, it->second),
                        op->predicate);
   }
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<LoadNode>();
     auto it = alloc_map_.find(op->buffer_var.get());
     if (it == alloc_map_.end()) return expr;
@@ -386,7 +386,7 @@ class StoragePlanRewriter : public StmtExprMutator {
                       RemapIndex(op->dtype, op->index, it->second),
                       op->predicate);
   }
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = alloc_map_.find(op);
     if (it != alloc_map_.end()) {
       if (it->second->bits_offset != 0) {
@@ -394,10 +394,10 @@ class StoragePlanRewriter : public StmtExprMutator {
       }
       return it->second->alloc_var;
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       CHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
@@ -407,8 +407,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         return StmtExprMutator::VisitExpr_(op);
       }
       const StorageEntry* se = it->second;
-      Expr offset = this->VisitExpr(op->args[2]);
-      Expr extent = this->VisitExpr(op->args[3]);
+      PrimExpr offset = this->VisitExpr(op->args[2]);
+      PrimExpr extent = this->VisitExpr(op->args[3]);
       uint64_t elem_bits = dtype.bits() * dtype.lanes();
       CHECK_EQ(se->bits_offset % elem_bits, 0U);
       if (se->bits_offset != 0) {
@@ -488,7 +488,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     // The replacement allocation, if any.
     Stmt new_alloc;
     // The var expr of new allocation.
-    VarExpr alloc_var;
+    Var alloc_var;
     // The allocation element type.
     DataType elem_type;
     // This is non-zero if this allocate is folded into another one
@@ -529,7 +529,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     return MergeNest(nest, body);
   }
   // Remap the index
-  Expr RemapIndex(DataType dtype, Expr index, StorageEntry* e) {
+  PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
     if (e->bits_offset == 0) return index;
     uint64_t elem_bits = dtype.bits() * dtype.lanes();
     CHECK_EQ(e->bits_offset % elem_bits, 0U);
@@ -577,7 +577,7 @@ class StoragePlanRewriter : public StmtExprMutator {
         }
         if (e->allocs.size() == 1) {
           // simply use the original allocation.
-          Expr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
+          PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
                                               make_const(DataType::Int(32), 1));
           e->new_alloc = AllocateNode::make(
               e->alloc_var, alloc_type, {sz},
@@ -590,9 +590,10 @@ class StoragePlanRewriter : public StmtExprMutator {
           }
         } else {
           // Build a merged allocation
-          Expr combo_size;
+          PrimExpr combo_size;
           for (const AllocateNode* op : e->allocs) {
-            Expr sz = arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
+            PrimExpr sz = arith::ComputeReduce<MulNode>(
+                op->extents, make_const(DataType::Int(32), 1));
             auto nbits = op->dtype.bits() * op->dtype.lanes();
             if (const auto* imm = sz.as<IntImmNode>()) {
               if (imm->value > std::numeric_limits<int>::max() / nbits) {
@@ -663,7 +664,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       }
     }
     uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
-    Expr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
+    PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
                                  (total_bits + type_bits - 1) / type_bits);
     e->new_alloc = AllocateNode::make(
         e->alloc_var, e->elem_type, {alloc_size}, const_true(),
@@ -936,7 +937,7 @@ class StoragePlanRewriter : public StmtExprMutator {
 // if all its access is the same vector type.
 class VectorAllocRewriter : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const LoadNode* op) final {
+  PrimExpr VisitExpr_(const LoadNode* op) final {
     UpdateTypeMap(op->buffer_var.get(), op->dtype);
     return StmtExprMutator::VisitExpr_(op);
   }
@@ -945,7 +946,7 @@ class VectorAllocRewriter : public StmtExprMutator {
     UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
     return StmtExprMutator::VisitStmt_(op);
   }
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       DataType dtype = op->args[0].dtype();
       const VarNode* buffer = op->args[1].as<VarNode>();
@@ -964,7 +965,7 @@ class VectorAllocRewriter : public StmtExprMutator {
         tvec[0].lanes() % op->dtype.lanes() == 0 &&
         tvec[0].lanes() != op->dtype.lanes()) {
       int factor = tvec[0].lanes() / op->dtype.lanes();
-      Array<Expr> extents = op->extents;
+      Array<PrimExpr> extents = op->extents;
       arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
       if (me->base % factor == 0 && me->coeff % factor == 0) {
         extents.Set(extents.size() - 1,
@@ -999,13 +1000,13 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
     if (arg.dtype().is_handle()) {
       const auto& tvec = rewriter.acc_map_[arg.get()];
       if (tvec.size() == 1) {
-        Expr dtype = make_const(tvec[0], 0);
+        PrimExpr dtype = make_const(tvec[0], 0);
         n->handle_data_type.Set(arg, dtype);
       } else {
         // always set data type to be non vectorized so
         // load/store can still work via scalarization
         if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
-          Expr dtype = make_const(tvec[0].with_lanes(1), 0);
+          PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
           n->handle_data_type.Set(arg, dtype);
         }
       }
index 7edf98b..2358ce9 100644 (file)
@@ -222,7 +222,7 @@ class ThreadSyncInserter : public StmtExprMutator {
       return StmtExprMutator::VisitStmt(stmt);
     }
   }
-  Expr VisitExpr_(const LoadNode* op) final {
+  PrimExpr VisitExpr_(const LoadNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
         GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
@@ -247,8 +247,8 @@ class ThreadSyncInserter : public StmtExprMutator {
       // first thread scope.
       if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
         ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
-        num_blocks_ = Expr();
-        is_lead_ = Expr();
+        num_blocks_ = PrimExpr();
+        is_lead_ = PrimExpr();
       }
       return ret;
     } else if (op->attr_key == attr::storage_scope) {
@@ -261,9 +261,9 @@ class ThreadSyncInserter : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      Expr expr = StmtExprMutator::VisitExpr_(op);
+      PrimExpr expr = StmtExprMutator::VisitExpr_(op);
       op = expr.as<CallNode>();
       CHECK_EQ(op->args.size(), 5U);
       const VarNode* buffer_var = op->args[1].as<VarNode>();
@@ -300,7 +300,7 @@ class ThreadSyncInserter : public StmtExprMutator {
   // private functions.
   Stmt InitGlobalBarrier(const AttrStmtNode* op) {
     CHECK(op != nullptr);
-    Array<Expr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
+    Array<PrimExpr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
     Stmt prep = EvaluateNode::make(
         CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
     Stmt body = op->body;
@@ -332,7 +332,7 @@ class ThreadSyncInserter : public StmtExprMutator {
           num_blocks_ = (num_blocks_.defined() ?
                          attr->value * num_blocks_ : attr->value);
         } else if (s.rank == 1) {
-          Expr cond = iv->var == make_zero(iv->var.dtype());
+          PrimExpr cond = iv->var == make_zero(iv->var.dtype());
           is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
         }
       }
@@ -351,14 +351,14 @@ class ThreadSyncInserter : public StmtExprMutator {
   // The storage scope of each buffer
   std::unordered_map<const VarNode*, StorageScope> storage_scope_;
   // The read write statistics of storage
-  std::unordered_map<VarExpr, Entry, ObjectHash, ObjectEqual> rw_stats_;
+  std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> rw_stats_;
   // The statistics for global barrier
   bool in_thread_env_{false};
   // memorized results
   std::vector<const AttrStmtNode*> thread_extents_;
   size_t num_work_dim_{0};
-  Expr num_blocks_;
-  Expr is_lead_;
+  PrimExpr num_blocks_;
+  PrimExpr is_lead_;
 };
 
 Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
index b2658d9..bb57fe8 100644 (file)
@@ -59,14 +59,14 @@ std::string simplify_name(std::string input) {
   }
 }
 
-Expr unpack_type_cast(const Expr &input, const DataType &target_type) {
+PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) {
   auto cast = input.as<CastNode>();
   if (cast == nullptr) {
     return input;
   } else if (cast->dtype == target_type) {
     return cast->value;
   }
-  return Expr();
+  return PrimExpr();
 }
 
 // MMAMatcher matches C = Cast(A)*Cast(B)+C,
@@ -217,14 +217,14 @@ class MMAMatcher: public StmtVisitor {
     buf_name_.insert(std::make_pair(load_a, buffer_a.name));
     buf_name_.insert(std::make_pair(load_b, buffer_b.name));
     mma_sync_.insert(std::make_pair(op,
-      Array<Expr>{load_a_expr, load_b_expr, add->a}));
+      Array<PrimExpr>{load_a_expr, load_b_expr, add->a}));
 
     return true;
   }
 
   std::unordered_map<TensorKey, BufferInfo> buf_map_;
   std::unordered_map<const Object*, std::string> storage_scope_;
-  std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
+  std::unordered_map<const ProvideNode*, Array<PrimExpr>> mma_sync_;
   std::unordered_map<const Object*, std::string> buf_name_;
   std::unordered_set<std::string> frag_reg_;
   bool matched_{false};
@@ -243,7 +243,7 @@ class BodyVisitor : public StmtExprVisitor {
     if (comm_add == nullptr || op->combiner->result.size() > 1) {
       return;
     }
-    for (Expr source : op->source) {
+    for (PrimExpr source : op->source) {
       auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<MulNode>();
       auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<MulNode>();
       if (mul_0 == nullptr && mul_1 == nullptr) {
@@ -263,7 +263,7 @@ class BodyVisitor : public StmtExprVisitor {
   friend class ScheduleAnalyser;
 
  private:
-  std::unordered_map<std::string, Array<Expr>> args_;
+  std::unordered_map<std::string, Array<PrimExpr>> args_;
   bool tensorcore_candidate_{false};
 };
 
@@ -294,7 +294,7 @@ class ScheduleAnalyser {
       reduce_axis_var = reduce_axis[0]->var.as<VarNode>();
 
       BodyVisitor body_visitor;
-      for (Expr expr : compute->body) {
+      for (PrimExpr expr : compute->body) {
         body_visitor(expr);
       }
       if (!body_visitor.tensorcore_candidate_) {
@@ -347,7 +347,7 @@ class ScheduleAnalyser {
       if (it0->second == "matrix_a" && it1->second == "matrix_b") {
         return true;
       } else if (it0->second == "matrix_b" && it1->second == "matrix_a") {
-        mma_sync.second = Array<Expr>{operands[1], operands[0], operands[2]};
+        mma_sync.second = Array<PrimExpr>{operands[1], operands[0], operands[2]};
       } else {
         return false;
       }
@@ -361,7 +361,7 @@ class ScheduleAnalyser {
  private:
   std::unordered_map<std::string, std::string> matrix_abc_;
   std::unordered_map<std::string, std::string> matrix_major_;
-  std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
+  std::unordered_map<const ProvideNode*, Array<PrimExpr>> mma_sync_;
   std::unordered_map<const Object*, std::string> buf_name_;
 };
 
@@ -457,12 +457,12 @@ class BufferAnalyser : public StmtExprVisitor {
       }
     }
 
-    Array<Expr> strides;
+    Array<PrimExpr> strides;
     if (bi.strides.size() > 0) {
       strides = bi.strides;
     } else {
       for (size_t i = 1; i < bi.shape.size(); ++i) {
-        Expr stride = IntImmNode::make(DataType::Int(32), 1);
+        PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
         for (size_t j = bi.shape.size() - 1; j >= i; --j) {
           stride = MulNode::make(stride, bi.shape[j]);
         }
@@ -473,7 +473,7 @@ class BufferAnalyser : public StmtExprVisitor {
     strides_.insert(std::make_pair(key.GetName(), strides));
 
     if (frag_reg_.count(bi.name)) {
-      Expr dst = CallNode::make(bi.dtype,
+      PrimExpr dst = CallNode::make(bi.dtype,
                             bi.name,
                             op->args,
                             CallNode::Halide,
@@ -535,7 +535,7 @@ class BufferAnalyser : public StmtExprVisitor {
 
     const CallNode* value = op->value.as<CallNode>();
     if (value != nullptr && frag_reg_.count(value->name)) {
-      Expr dst = CallNode::make(bi.dtype,
+      PrimExpr dst = CallNode::make(bi.dtype,
                             bi.name,
                             op->args,
                             CallNode::Halide,
@@ -570,12 +570,12 @@ class BufferAnalyser : public StmtExprVisitor {
         }
       }
 
-      Array<Expr> strides;
+      Array<PrimExpr> strides;
       if (bi.strides.size() > 0) {
         strides = bi.strides;
       } else {
         for (size_t i = 1; i < bi.shape.size(); ++i) {
-          Expr stride = IntImmNode::make(DataType::Int(32), 1);
+          PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
           for (size_t j = bi.shape.size() - 1; j >= i; --j) {
             stride = MulNode::make(stride, bi.shape[j]);
           }
@@ -616,22 +616,22 @@ class BufferAnalyser : public StmtExprVisitor {
       BufferInfo bi;
 
       bi.bounds = op->bounds;
-      Array<Expr> shape;
+      Array<PrimExpr> shape;
       for (auto r : bi.bounds) {
         shape.push_back(r->extent);
       }
 
-      Array<Expr> strides;
+      Array<PrimExpr> strides;
       if (dim_align_.count(key) != 0 && shape.size() != 0) {
-        std::vector<Expr> rstrides;
+        std::vector<PrimExpr> rstrides;
         const std::vector<DimAlignInfo>& avec = dim_align_[key];
         int first_dim = 0;
-        Expr stride = make_const(shape[first_dim].dtype(), 1);
+        PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
         for (size_t i = shape.size(); i != 0; --i) {
           size_t dim = i - 1;
           if (dim < avec.size() && avec[dim].align_factor != 0) {
-            Expr factor = make_const(stride.dtype(), avec[dim].align_factor);
-            Expr offset = make_const(stride.dtype(), avec[dim].align_offset);
+            PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
+            PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
             stride = stride + \
               indexmod(factor + offset - indexmod(stride, factor), factor);
             stride = ir::Simplify(stride);
@@ -639,7 +639,7 @@ class BufferAnalyser : public StmtExprVisitor {
           rstrides.push_back(stride);
           stride = stride * shape[dim];
         }
-        strides = Array<Expr>(rstrides.rbegin(), rstrides.rend());
+        strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
       }
 
       bi.name = key.GetName();
@@ -689,14 +689,14 @@ class BufferAnalyser : public StmtExprVisitor {
   struct BufferInfo {
     std::string name;
     DataType dtype;
-    Array<Expr> strides;
-    Array<Expr> shape;
+    Array<PrimExpr> strides;
+    Array<PrimExpr> shape;
     Region bounds;
     bool external{false};
     bool released{false};
-    inline Array<Expr> RelIndex(Array<Expr> args) const {
+    inline Array<PrimExpr> RelIndex(Array<PrimExpr> args) const {
       if (bounds.size() != 0) {
-        Array<Expr> index;
+        Array<PrimExpr> index;
         CHECK_EQ(bounds.size(), args.size());
         for (size_t i = 0; i < bounds.size(); ++i) {
           index.push_back(args[i] - bounds[i]->min);
@@ -744,9 +744,9 @@ class BufferAnalyser : public StmtExprVisitor {
   std::unordered_map<std::string, std::string> matrix_abc_;
   std::unordered_map<std::string, std::string> matrix_major_;
   std::unordered_set<std::string> frag_reg_;
-  std::unordered_map<std::string, Array<Expr>> strides_;
-  std::unordered_map<const ProvideNode*, Expr> frag_load_;
-  std::unordered_map<const ProvideNode*, Expr> frag_store_;
+  std::unordered_map<std::string, Array<PrimExpr>> strides_;
+  std::unordered_map<const ProvideNode*, PrimExpr> frag_load_;
+  std::unordered_map<const ProvideNode*, PrimExpr> frag_store_;
   std::unordered_map<std::string, int> thread_extent_;
   IndexVisitor index_visitor;
   Tile warp_tile_;
@@ -758,19 +758,19 @@ class BufferAnalyser : public StmtExprVisitor {
 // ThreadIdxMutator does the thread index unification inside a warp
 class ThreadIdxMutator : public StmtExprMutator {
  public:
-  explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
+  explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {}
 
-  Expr VisitExpr_(const VarNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<VarNode>();
     if (op != nullptr) {
       if (op->name_hint == "threadIdx.x") {
-        Expr zero = IntImmNode::make(DataType::Int(32), 0);
+        PrimExpr zero = IntImmNode::make(DataType::Int(32), 0);
         return zero;
       }
       if (op->name_hint == "threadIdx.y") {
-        Expr div = DivNode::make(expr, warp_y_);
-        Expr mul = MulNode::make(div, warp_y_);
+        PrimExpr div = DivNode::make(expr, warp_y_);
+        PrimExpr mul = MulNode::make(div, warp_y_);
         return mul;
       }
     }
@@ -778,7 +778,7 @@ class ThreadIdxMutator : public StmtExprMutator {
   }
 
  private:
-  Expr warp_y_;
+  PrimExpr warp_y_;
 };
 
 // TensorCoreIRMutator mutates the AST for TensorCore CodeGen
@@ -856,11 +856,11 @@ class TensorCoreIRMutator : public StmtExprMutator {
     auto it = mma_sync_.find(op);
     if (it != mma_sync_.end()) {
       const auto &operands = it->second;
-      Expr a = operands[0];
+      PrimExpr a = operands[0];
       auto ca = a.as<CallNode>();
-      Expr b = operands[1];
+      PrimExpr b = operands[1];
       auto cb = b.as<CallNode>();
-      Expr c = operands[2];
+      PrimExpr c = operands[2];
       auto cc = c.as<CallNode>();
 
       ObjectPtr<BufferNode> buffer_node_a = make_object<BufferNode>();
@@ -900,7 +900,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
     auto it2 = frag_load_.find(op);
     if (it2 != frag_load_.end()) {
-      Expr dst = it2->second;
+      PrimExpr dst = it2->second;
       if (op->value.as<FloatImmNode>() != nullptr ||
           op->value.as<IntImmNode>() != nullptr) {
         auto call = dst.as<CallNode>();
@@ -931,19 +931,19 @@ class TensorCoreIRMutator : public StmtExprMutator {
           << "Cannot find stride for " << value->name;
       auto strides = it->second;
       CHECK_GE(strides.size(), 2);
-      Expr stride = strides[strides.size()-2];
+      PrimExpr stride = strides[strides.size()-2];
 
       // thread index unification inside a warp
-      Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+      PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
-      Expr mutated_value = thread_idx_mutator(op->value);
-      Expr src = CallNode::make(value->dtype,
+      PrimExpr mutated_value = thread_idx_mutator(op->value);
+      PrimExpr src = CallNode::make(value->dtype,
                             "&",
                             {mutated_value},
                             CallNode::Extern);
 
       auto call = dst.as<CallNode>();
-      Expr matrix_major;
+      PrimExpr matrix_major;
       auto iter2 = matrix_major_.find(simplify_name(call->name));
       CHECK(iter2 != matrix_major_.end())
           << "Can not determine matrix major for " << call->name;
@@ -980,11 +980,11 @@ class TensorCoreIRMutator : public StmtExprMutator {
           << "Cannot find stride for " << key.GetName();
       auto strides = it->second;
       CHECK_GE(strides.size(), 2);
-      Expr stride = strides[strides.size()-2];
+      PrimExpr stride = strides[strides.size()-2];
 
-      Expr dst = it3->second;
+      PrimExpr dst = it3->second;
       // thread index unification inside a warp
-      Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
+      PrimExpr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       dst = thread_idx_mutator(dst);
       dst = CallNode::make(DataType::Handle(),
@@ -1027,7 +1027,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
           int ori_extent_value = ori_extent->value;
           scaled_extent_value = ori_extent_value / scale_factor;
         }
-        Expr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
+        PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
         stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type,
           op->device_api, op->body);
       }
@@ -1036,13 +1036,13 @@ class TensorCoreIRMutator : public StmtExprMutator {
   }
 
  private:
-  Array<Expr> get_tile_size_(const std::string &name) {
+  Array<PrimExpr> get_tile_size_(const std::string &name) {
       auto it = matrix_abc_.find(name);
       auto it2 = matrix_major_.find(name);
       CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end())
           << "Cannot find matrix info for " << name;
-      Expr size0 = make_const(DataType::Int(32), 16);
-      Expr size1 = make_const(DataType::Int(32), 16);
+      PrimExpr size0 = make_const(DataType::Int(32), 16);
+      PrimExpr size1 = make_const(DataType::Int(32), 16);
       if (it->second == "matrix_a" && it2->second == "col_major") {
         size0 = make_const(DataType::Int(32), warp_tile_.k);
         size1 = make_const(DataType::Int(32), warp_tile_.m);
@@ -1063,7 +1063,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
         size0 = make_const(DataType::Int(32), warp_tile_.n);
         size1 = make_const(DataType::Int(32), warp_tile_.m);
       }
-      Array<Expr> tile_size = {size0, size1};
+      Array<PrimExpr> tile_size = {size0, size1};
       return tile_size;
   }
 
@@ -1073,13 +1073,13 @@ class TensorCoreIRMutator : public StmtExprMutator {
       DataType datatype) {
     auto it = bounds_.find(key);
     CHECK(it != bounds_.end());
-    Array<Expr> min_bound;
+    Array<PrimExpr> min_bound;
     for (auto i : it->second) {
       min_bound.push_back(i->min);
     }
 
     CHECK_GE(it->second.size(), 2);
-    Array<Expr> shape;
+    Array<PrimExpr> shape;
     for (size_t i = 0; i < it->second.size() - 2; ++i) {
       shape.push_back(it->second[i]->extent);
     }
@@ -1087,9 +1087,9 @@ class TensorCoreIRMutator : public StmtExprMutator {
     shape.push_back(tile_size[0]);
     shape.push_back(tile_size[1]);
 
-    Array<Expr> strides;
+    Array<PrimExpr> strides;
     for (size_t i = 1; i < shape.size(); ++i) {
-      Expr stride = IntImmNode::make(DataType::Int(32), 1);
+      PrimExpr stride = IntImmNode::make(DataType::Int(32), 1);
       for (size_t j = shape.size() - 1; j >= i; --j) {
         stride = MulNode::make(stride, shape[j]);
       }
@@ -1097,7 +1097,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
     }
     strides.push_back(make_const(DataType::Int(32), 1));
 
-    Expr elem_offset = IntImmNode::make(DataType::Int(32), 0);
+    PrimExpr elem_offset = IntImmNode::make(DataType::Int(32), 0);
     CHECK_EQ(call->args.size(), min_bound.size());
     for (size_t i = 0; i < min_bound.size(); i++) {
       elem_offset = AddNode::make(
@@ -1126,7 +1126,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
     tensor_node->dtype = datatype;
     Tensor tensor(tensor_node);
 
-    Array<Expr> args;
+    Array<PrimExpr> args;
     for (size_t i = 0; i < call->args.size(); ++i) {
       args.push_back(call->args[i]);
       args.push_back(shape[i]);
@@ -1144,12 +1144,12 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
   std::unordered_map<std::string, std::string> matrix_abc_;
   std::unordered_map<std::string, std::string> matrix_major_;
-  std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
-  std::unordered_map<std::string, Array<Expr>> strides_;
+  std::unordered_map<const ProvideNode*, Array<PrimExpr>> mma_sync_;
+  std::unordered_map<std::string, Array<PrimExpr>> strides_;
   std::unordered_set<std::string> frag_reg_;
   std::unordered_map<const VarNode*, unsigned> loop_scaling_;
-  std::unordered_map<const ProvideNode*, Expr> frag_load_;
-  std::unordered_map<const ProvideNode*, Expr> frag_store_;
+  std::unordered_map<const ProvideNode*, PrimExpr> frag_load_;
+  std::unordered_map<const ProvideNode*, PrimExpr> frag_store_;
   std::unordered_map<TensorKey, Region> bounds_;
   Tile warp_tile_;
   int warp_threads_y_{-1};
index e2e7ad0..b2c50f7 100644 (file)
@@ -143,7 +143,7 @@ class LoopUnroller : public StmtExprMutator {
     CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
     if (value == 0) return EvaluateNode::make(0);
     Stmt body = op->body;
-    Map<Var, Expr> vmap;
+    Map<Var, PrimExpr> vmap;
     Array<Stmt> unrolled;
     for (int i = 0; i < value; ++i) {
       vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
@@ -157,7 +157,7 @@ class LoopUnroller : public StmtExprMutator {
   // returns the extent of the loop if it's a constant integer, otherwise return -1
   int GetExtent(const ForNode* op) {
     // constant folding.
-    Expr extent = ir::Simplify(op->extent);
+    PrimExpr extent = ir::Simplify(op->extent);
     const IntImmNode  *v1 = extent.as<IntImmNode>();
     const UIntImmNode *v2 = extent.as<UIntImmNode>();
     int value = -1;
index 450c6ba..c9f3441 100644 (file)
@@ -33,7 +33,7 @@
 namespace tvm {
 namespace ir {
 
-inline Expr BroadcastTo(Expr e, int lanes) {
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
   if (e.dtype().lanes() == lanes) return e;
   if (const BroadcastNode* op = e.as<BroadcastNode>()) {
     if (lanes % op->lanes == 0) {
@@ -59,8 +59,8 @@ class VecAllocAccess : public StmtExprMutator {
   VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
       : buf_(buf), var_(var), var_lanes_(var_lanes) {}
   // Load
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr expr = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<LoadNode>();
     if (op->buffer_var.get() == buf_) {
       return LoadNode::make(op->dtype, op->buffer_var,
@@ -111,18 +111,18 @@ class Vectorizer : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const AddNode* op) final {
+  PrimExpr VisitExpr_(const AddNode* op) final {
     return AddSubVec(op);
   }
-  Expr VisitExpr_(const SubNode* op) final {
+  PrimExpr VisitExpr_(const SubNode* op) final {
     return AddSubVec(op);
   }
-  Expr VisitExpr_(const MulNode* op) final {
-    Expr a = this->VisitExpr(op->a);
-    Expr b = this->VisitExpr(op->b);
+  PrimExpr VisitExpr_(const MulNode* op) final {
+    PrimExpr a = this->VisitExpr(op->a);
+    PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) &&
         b.same_as(op->b)) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
       if (lanes != 1) {
@@ -141,51 +141,51 @@ class Vectorizer : public StmtExprMutator {
     }
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const DivNode* op) final {
+  PrimExpr VisitExpr_(const DivNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const ModNode* op) final {
+  PrimExpr VisitExpr_(const ModNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const FloorDivNode* op) final {
+  PrimExpr VisitExpr_(const FloorDivNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const FloorModNode* op) final {
+  PrimExpr VisitExpr_(const FloorModNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const MinNode* op) final {
+  PrimExpr VisitExpr_(const MinNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const MaxNode* op) final {
+  PrimExpr VisitExpr_(const MaxNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const EQNode* op) final {
+  PrimExpr VisitExpr_(const EQNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const NENode* op) final {
+  PrimExpr VisitExpr_(const NENode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const LTNode* op) final {
+  PrimExpr VisitExpr_(const LTNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const LENode* op) final {
+  PrimExpr VisitExpr_(const LENode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const GTNode* op) final {
+  PrimExpr VisitExpr_(const GTNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const GENode* op) final {
+  PrimExpr VisitExpr_(const GENode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const AndNode* op) final {
+  PrimExpr VisitExpr_(const AndNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const OrNode* op) final {
+  PrimExpr VisitExpr_(const OrNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const RampNode* op) final {
-    Expr base = this->VisitExpr(op->base);
-    Expr stride = this->VisitExpr(op->stride);
+  PrimExpr VisitExpr_(const RampNode* op) final {
+    PrimExpr base = this->VisitExpr(op->base);
+    PrimExpr stride = this->VisitExpr(op->stride);
     if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
       const RampNode* base_ramp = base.as<RampNode>();
       if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
@@ -195,7 +195,7 @@ class Vectorizer : public StmtExprMutator {
     int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
     base = BroadcastTo(base, lanes);
     stride = BroadcastTo(stride, lanes);
-    Array<Expr> elems;
+    Array<PrimExpr> elems;
     for (int i = 0; i < lanes; ++i) {
       elems.push_back(
           RampNode::make(ShuffleNode::make_extract_element(base, i),
@@ -204,14 +204,14 @@ class Vectorizer : public StmtExprMutator {
     }
     return ShuffleNode::make_concat(elems);
   }
-  Expr VisitExpr_(const SelectNode *op) final {
-    Expr cond = this->VisitExpr(op->condition);
-    Expr t = this->VisitExpr(op->true_value);
-    Expr f = this->VisitExpr(op->false_value);
+  PrimExpr VisitExpr_(const SelectNode *op) final {
+    PrimExpr cond = this->VisitExpr(op->condition);
+    PrimExpr t = this->VisitExpr(op->true_value);
+    PrimExpr f = this->VisitExpr(op->false_value);
     if (cond.same_as(op->condition) &&
         t.same_as(op->true_value) &&
         f.same_as(op->false_value)) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(std::max(
           cond.dtype().lanes(),
@@ -219,37 +219,37 @@ class Vectorizer : public StmtExprMutator {
       return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
     }
   }
-  Expr VisitExpr_(const CastNode *op) final {
-    Expr value = this->VisitExpr(op->value);
+  PrimExpr VisitExpr_(const CastNode *op) final {
+    PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value);
     }
   }
   // Variable
-  Expr VisitExpr_(const VarNode* v) final {
+  PrimExpr VisitExpr_(const VarNode* v) final {
     if (v == var_.get()) {
       return ramp_;
     } else if (lets_.count(v)) {
         return lets_[v];
     } else {
-      return GetRef<Expr>(v);
+      return GetRef<PrimExpr>(v);
     }
   }
   // IfThenElse expr
-  Expr MutateIfThenElseExpr_(const CallNode *op) {
-    Expr cond = this->VisitExpr(op->args[0]);
+  PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
+    PrimExpr cond = this->VisitExpr(op->args[0]);
     if (cond.dtype().is_vector())  {
       need_scalarize_ = true;
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
-    Expr t = this->VisitExpr(op->args[1]);
-    Expr f = this->VisitExpr(op->args[2]);
+    PrimExpr t = this->VisitExpr(op->args[1]);
+    PrimExpr f = this->VisitExpr(op->args[2]);
     if (cond.same_as(op->args[0]) &&
         t.same_as(op->args[1]) &&
         f.same_as(op->args[2])) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
       t = BroadcastTo(t, lanes);
@@ -260,33 +260,33 @@ class Vectorizer : public StmtExprMutator {
     }
   }
   // Call
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->name == intrinsic::tvm_if_then_else) {
       return MutateIfThenElseExpr_(op);
     }
     if (!op->is_vectorizable()) {
       // Cannot vectorize this op
-      Array<Expr> new_args;
+      Array<PrimExpr> new_args;
       for (auto arg : op->args) {
         auto new_arg = this->VisitExpr(arg);
         if (new_arg.dtype().is_vector()) {
           need_scalarize_ = true;
-          return GetRef<Expr>(op);
+          return GetRef<PrimExpr>(op);
         }
         new_args.push_back(new_arg);
       }
       if (op->args.same_as(new_args)) {
-        return GetRef<Expr>(op);
+        return GetRef<PrimExpr>(op);
       } else {
         return CallNode::make(
             op->dtype, op->name, new_args, op->call_type, op->func, op->value_index);
       }
     } else {
       int lane = 0;
-      Array<Expr> new_args = MutateArray(op->args, &lane);
+      Array<PrimExpr> new_args = MutateArray(op->args, &lane);
       // normal code path.
       if (op->args.same_as(new_args)) {
-        return GetRef<Expr>(op);
+        return GetRef<PrimExpr>(op);
       } else {
         return CallNode::make(
             op->dtype.with_lanes(lane), op->name, new_args,
@@ -295,11 +295,11 @@ class Vectorizer : public StmtExprMutator {
     }
   }
   // Load
-  Expr VisitExpr_(const LoadNode* op) final {
-    Expr index = this->VisitExpr(op->index);
-    Expr pred = this->VisitExpr(op->predicate);
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    PrimExpr index = this->VisitExpr(op->index);
+    PrimExpr pred = this->VisitExpr(op->predicate);
     if (index.same_as(op->index) && pred.same_as(op->predicate)) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
       return LoadNode::make(
@@ -310,18 +310,18 @@ class Vectorizer : public StmtExprMutator {
     }
   }
   // Let
-  Expr VisitExpr_(const LetNode* op) final {
-    Expr value = this->VisitExpr(op->value);
+  PrimExpr VisitExpr_(const LetNode* op) final {
+    PrimExpr value = this->VisitExpr(op->value);
     CHECK(!lets_.count(op->var.get())) << "not SSA";
     if (value.dtype().lanes() != op->value.dtype().lanes()) {
       Var v(op->var->name_hint, value.dtype());
       lets_[op->var.get()] = v;
       return LetNode::make(v, value, this->VisitExpr(op->body));
     } else {
-      Expr body = this->VisitExpr(op->body);
+      PrimExpr body = this->VisitExpr(op->body);
       if (value.same_as(op->value) &&
           body.same_as(op->body)) {
-        return GetRef<Expr>(op);
+        return GetRef<PrimExpr>(op);
       } else {
         return LetNode::make(op->var, value, body);
       }
@@ -329,9 +329,9 @@ class Vectorizer : public StmtExprMutator {
   }
   // Provide
   Stmt VisitStmt_(const ProvideNode* op) final {
-    Expr new_value = this->VisitExpr(op->value);
+    PrimExpr new_value = this->VisitExpr(op->value);
     int lane = new_value.dtype().lanes();
-    Array<Expr> new_args = MutateArray(op->args, &lane);
+    Array<PrimExpr> new_args = MutateArray(op->args, &lane);
     if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
       return GetRef<Stmt>(op);
     } else {
@@ -341,9 +341,9 @@ class Vectorizer : public StmtExprMutator {
   }
   // Store
   Stmt VisitStmt_(const StoreNode* op) final {
-    Expr value = this->VisitExpr(op->value);
-    Expr index = this->VisitExpr(op->index);
-    Expr pred = this->VisitExpr(op->predicate);
+    PrimExpr value = this->VisitExpr(op->value);
+    PrimExpr index = this->VisitExpr(op->index);
+    PrimExpr pred = this->VisitExpr(op->predicate);
     if (value.same_as(op->value) && index.same_as(op->index)) {
       return GetRef<Stmt>(op);
     } else {
@@ -362,7 +362,7 @@ class Vectorizer : public StmtExprMutator {
     }
     CHECK(is_zero(op->min));
     CHECK(!op->extent.dtype().is_vector());
-    Expr extent = this->VisitExpr(op->extent);
+    PrimExpr extent = this->VisitExpr(op->extent);
     if (extent.dtype().is_vector()) {
       return Scalarize(GetRef<Stmt>(op));
     }
@@ -379,7 +379,7 @@ class Vectorizer : public StmtExprMutator {
   // IfThenElse
   Stmt VisitStmt_(const IfThenElseNode* op) final {
     CHECK(!op->condition.dtype().is_vector());
-    Expr condition = this->VisitExpr(op->condition);
+    PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_vector()) {
       return Scalarize(GetRef<Stmt>(op));
     }
@@ -407,14 +407,14 @@ class Vectorizer : public StmtExprMutator {
       LOG(WARNING) << "Cannot vectorize with new expr";
       return Scalarize(GetRef<Stmt>(op));
     }
-    Expr condition = this->VisitExpr(op->condition);
+    PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_vector()) {
       LOG(WARNING) << "Cannot handle vector extent in alloc ";
       return Scalarize(GetRef<Stmt>(op));
     }
-    Array<Expr> extents;
+    Array<PrimExpr> extents;
     for (size_t i = 0; i < op->extents.size(); i++) {
-      Expr new_ext = this->VisitExpr(op->extents[i]);
+      PrimExpr new_ext = this->VisitExpr(op->extents[i]);
       if (new_ext.dtype().is_vector()) {
         LOG(WARNING) << "Cannot handle vector extent in alloc ";
         return Scalarize(GetRef<Stmt>(op));
@@ -435,7 +435,7 @@ class Vectorizer : public StmtExprMutator {
   // scalarize the statment
   Stmt Scalarize(Stmt stmt) {
     Var idx(var_->name_hint + ".s", var_->dtype);
-    Map<Var, Expr> values{{var_, idx}};
+    Map<Var, PrimExpr> values{{var_, idx}};
     stmt = Substitute(stmt, values);
     return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
   }
@@ -448,21 +448,21 @@ class Vectorizer : public StmtExprMutator {
   // the lanes.
   int var_lanes_;
   // ramp representing the var.
-  Expr ramp_;
+  PrimExpr ramp_;
   // flag to mark requirment of scalarization.
   bool need_scalarize_{false};
   // The lets
-  std::unordered_map<const VarNode*, Expr> lets_;
+  std::unordered_map<const VarNode*, PrimExpr> lets_;
   // mutate array, with given lane requirement
   // when finished, p_lane updates the lane requirement.
-  Array<Expr> MutateArray(Array<Expr> arr, int* p_lanes) {
+  Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
     if (arr.size() == 0) return arr;
     int& lanes = *p_lanes;
     bool changed = false;
-    std::vector<Expr> new_arr(arr.size());
+    std::vector<PrimExpr> new_arr(arr.size());
     for (size_t i = 0; i < arr.size(); i++) {
-      Expr old_elem = arr[i];
-      Expr new_elem = this->VisitExpr(old_elem);
+      PrimExpr old_elem = arr[i];
+      PrimExpr new_elem = this->VisitExpr(old_elem);
       if (!new_elem.same_as(old_elem)) changed = true;
       new_arr[i] = new_elem;
       lanes = std::max(lanes, new_elem.dtype().lanes());
@@ -475,27 +475,27 @@ class Vectorizer : public StmtExprMutator {
       }
     }
     if (!changed) return arr;
-    return Array<Expr>(new_arr);
+    return Array<PrimExpr>(new_arr);
   }
   template<typename T>
-  Expr BinaryVec(const T* op) {
-    Expr a = this->VisitExpr(op->a);
-    Expr b = this->VisitExpr(op->b);
+  PrimExpr BinaryVec(const T* op) {
+    PrimExpr a = this->VisitExpr(op->a);
+    PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) &&
         b.same_as(op->b)) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
       return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
     }
   }
   template<typename T>
-  Expr AddSubVec(const T* op) {
-    Expr a = this->VisitExpr(op->a);
-    Expr b = this->VisitExpr(op->b);
+  PrimExpr AddSubVec(const T* op) {
+    PrimExpr a = this->VisitExpr(op->a);
+    PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) &&
         b.same_as(op->b)) {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     } else {
       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
       if (lanes != 1) {
index 96f231e..24f3e19 100644 (file)
@@ -100,7 +100,7 @@ class GPUCodeVerifier : public StmtVisitor {
         visited_shared_buffers_.insert(op->node.as<tvm::VarNode>());
       }
     } else if (op->attr_key == attr::thread_extent) {
-      VarExpr var = op->node.as<tvm::IterVarNode>()->var;
+      Var var = op->node.as<tvm::IterVarNode>()->var;
       const auto *extent = op->value.as<IntImmNode>();
       CHECK(extent);
 
@@ -169,7 +169,7 @@ class GPUCodeVerifier : public StmtVisitor {
 };
 
 bool VerifyGPUCode(Stmt stmt,
-                   Map<std::string, Expr> constraints) {
+                   Map<std::string, PrimExpr> constraints) {
   GPUCodeVerifier verifier;
 
   int64_t max_local_memory_per_block = INT64_MAX;
index 25e7258..899e9bc 100644 (file)
@@ -65,7 +65,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
  protected:
   /// Visitor implementation
   //@{
-  void VisitExpr(const Expr &n) final {
+  void VisitExpr(const PrimExpr &n) final {
     if (Failed()) return;
     StmtExprVisitor::VisitExpr(n);
   }
@@ -130,7 +130,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   }
 
   /// Handle memory access to a Variable
-  void HandleLoadStoreToVariable(const VarExpr &var) {
+  void HandleLoadStoreToVariable(const Var &var) {
     // We skip the access within thread env.
     if (InThreadEnv()) return;
 
@@ -181,7 +181,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   //@}
   LoweredFunc func_{nullptr};  ///< Function to be verified.
   int dev_type_{kDLCPU};       ///< Device type
-  std::unordered_map<const VarNode *, Expr> defs_;  ///< Variable definitions
+  std::unordered_map<const VarNode *, PrimExpr> defs_;  ///< Variable definitions
 };
 }  // namespace
 
index 69731ea..1883558 100644 (file)
@@ -83,7 +83,7 @@ struct GraphCodegen {
 
   std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
     std::unordered_map<std::string, tvm::runtime::NDArray> ret;
-    auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
+    auto names = CallFunc<Array<tvm::PrimExpr> >("list_params_name", nullptr);
     for (auto expr : names) {
       auto key = expr.as<ir::StringImmNode>()->value;
       ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
@@ -190,8 +190,8 @@ class RelayBuildModule : public runtime::ModuleNode {
    *
    * \return Array<StringImm> names of params
    */
-  Array<tvm::Expr> ListParamNames() {
-    Array<tvm::Expr> ret;
+  Array<tvm::PrimExpr> ListParamNames() {
+    Array<tvm::PrimExpr> ret;
     for (const auto& kv : params_) {
       ret.push_back(ir::StringImmNode::make(kv.first));
     }
index 6c511ae..62de1c3 100644 (file)
@@ -203,7 +203,7 @@ class ScheduleGetter :
           return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
         } else {
           LOG(FATAL) << "not handled";
-          return tvm::Expr();
+          return tvm::PrimExpr();
         }
       }, "compile_engine_const", topi::kBroadcast);
     scalars_.push_back(value->op);
@@ -479,7 +479,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
             return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
           } else {
             LOG(FATAL) << "not handled";
-            return tvm::Expr();
+            return tvm::PrimExpr();
           }
       }, "data_const", topi::kBroadcast);
       scalars_.push_back(value);
index 618e135..3ff72b3 100644 (file)
@@ -641,9 +641,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
       });
     } else if (name == "list_params_name") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        Array<tvm::Expr> ret;
+        Array<tvm::PrimExpr> ret;
         for (const auto &kv : this->output_.params) {
-          tvm::Expr name = ir::StringImmNode::make(kv.first);
+          tvm::PrimExpr name = ir::StringImmNode::make(kv.first);
           ret.push_back(name);
         }
         *rv = ret;
index 7946ee6..5d262a0 100644 (file)
@@ -50,7 +50,7 @@ namespace transform {
 
 Pass LambdaLift();
 Pass InlinePrimitives();
-Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);
+Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
 
 Pass ManifestAlloc(Target target_host) {
   auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
@@ -875,7 +875,7 @@ void VMCompiler::Lower(Module mod,
 
 Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
   Array<Pass> pass_seqs;
-  Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
+  Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
   pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
   // Run all dialect legalization passes.
   pass_seqs.push_back(relay::qnn::transform::Legalize());
index cea6115..5de2e92 100644 (file)
@@ -100,7 +100,7 @@ struct CallTracer : ExprVisitor {
  * \return The module with dead functions removed.
  */
 Module RemoveUnusedFunctions(const Module& module,
-                             Array<tvm::Expr> entry_funcs) {
+                             Array<tvm::PrimExpr> entry_funcs) {
   std::unordered_set<std::string> called_funcs{};
   for (auto entry : entry_funcs) {
     auto* str_name = entry.as<ir::StringImmNode>();
@@ -121,7 +121,7 @@ Module RemoveUnusedFunctions(const Module& module,
 
 namespace transform {
 
-Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions) {
+Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
   runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
     [=](Module m, PassContext pc) {
     return relay::vm::RemoveUnusedFunctions(m, entry_functions);
index a6f44ce..f6ebadf 100644 (file)
@@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 
 TensorType ConstantNode::tensor_type() const {
   auto dtype = DataType(data->dtype);
-  Array<tvm::Expr> shape;
+  Array<tvm::PrimExpr> shape;
   for (int i = 0; i < data->ndim; i++) {
     CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
     CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
index d179d7e..3bc72fd 100644 (file)
@@ -127,7 +127,7 @@ class RelayHashHandler:
   using AttrsHashHandler::VisitAttr_;
   size_t VisitAttr_(const tvm::VarNode* var) final {
     size_t hash = std::hash<std::string>()(VarNode::_type_key);
-    auto it = hash_map_.find(GetRef<VarExpr>(var));
+    auto it = hash_map_.find(GetRef<tvm::Var>(var));
     if (it != hash_map_.end()) {
       return it->second;
     }
index b83750f..b888ecb 100644 (file)
@@ -137,10 +137,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
 // Frontend APIs
 TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
 .set_body_typed([]() {
-    Array<tvm::Expr> ret;
+    Array<tvm::PrimExpr> ret;
     for (const std::string& name :
              dmlc::Registry<OpRegistry>::ListAllNames()) {
-      ret.push_back(tvm::Expr(name));
+      ret.push_back(tvm::PrimExpr(name));
     }
     return ret;
   });
index a8f2e86..68aa77b 100644 (file)
@@ -211,7 +211,7 @@ bool BinaryDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   CHECK(static_cast<int>(data->shape.size()) != 0);
   CHECK(param->units.defined());
 
-  Array<tvm::Expr> oshape = data->shape;
+  Array<tvm::PrimExpr> oshape = data->shape;
   oshape.Set((oshape.size() - 1), param->units);
 
   DataType out_dtype = param->out_dtype;
index aeb40fd..35d6cba 100644 (file)
@@ -136,7 +136,7 @@ bool FIFOBufferRel(const Array<Type>& types,
   }
   reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);
 
-  Array<tvm::Expr> oshape = buffer->shape;
+  Array<tvm::PrimExpr> oshape = buffer->shape;
 
   reporter->Assign(types[2], TensorTypeNode::make(oshape, buffer->dtype));
   return true;
@@ -877,7 +877,7 @@ bool BatchMatmulRel(const Array<Type>& types,
       << " x shape=" << x->shape
       << ", y shape=" << y->shape;
 
-  Array<tvm::Expr> oshape = x->shape;
+  Array<tvm::PrimExpr> oshape = x->shape;
   oshape.Set(2, y->shape[1]);
 
   // assign output type
index 110e435..1b27dea 100644 (file)
@@ -42,9 +42,9 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   CHECK(static_cast<int>(data->shape.size()) != 0);
 
-  Array<tvm::Expr> oshape = data->shape;
+  Array<tvm::PrimExpr> oshape = data->shape;
   if (param->units.defined()) {
-    Array<tvm::Expr> dshape = data->shape;
+    Array<tvm::PrimExpr> dshape = data->shape;
     // validate the weight shape is proper if defined
     // Assign weight type
     Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
@@ -56,7 +56,7 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     oshape.Set((oshape.size() - 1), param->units);
   } else {
     if (weight == nullptr) return false;
-    Array<tvm::Expr> wshape = weight->shape;
+    Array<tvm::PrimExpr> wshape = weight->shape;
     oshape.Set((oshape.size() - 1), wshape[0]);
   }
 
index f9d753f..72ea70f 100644 (file)
@@ -52,7 +52,7 @@ Array<Array<Layout> > PadInferCorrectLayout(
     // split.
 
     // 1) Create a map from axis to param_width using old layout.
-    std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
+    std::map<std::string, tvm::Array<tvm::PrimExpr>> axis_pad_width;
     int index_counter = 0;
     CHECK_EQ(new_in_layouts.size(), 1);
     CHECK_EQ(old_in_layouts.size(), 1);
@@ -63,7 +63,7 @@ Array<Array<Layout> > PadInferCorrectLayout(
     }
 
     // 2) Create new pad width by walking over the new layout and using the map.
-    tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
+    tvm::Array<tvm::Array<tvm::PrimExpr>> new_pad_width;
     for (auto iter_var : new_in_layouts[0]->axes) {
       const auto& new_layout_axis = LayoutAxis::Get(iter_var);
       auto axis_name = new_layout_axis.name();
index b8ee7e7..c9d824d 100644 (file)
@@ -1223,7 +1223,7 @@ inline Tensor DynamicArange(const tvm::Tensor& start,
                             tvm::DataType dtype,
                             std::string name = "tensor",
                             std::string tag = topi::kInjective) {
-  tvm::Expr num_elem = tvm::Var("num_elem");
+  tvm::PrimExpr num_elem = tvm::Var("num_elem");
   return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
     return tvm::cast(dtype, start[0] + step[0] * indices[0]);
   }, name, tag);
@@ -1237,7 +1237,7 @@ Array<Tensor> ArangeCompute(const Attrs& attrs,
   Tensor start = inputs[0];
   Tensor stop =  inputs[1];
   Tensor step = inputs[2];
-  Array<tvm::Expr> empty = {0};
+  Array<tvm::PrimExpr> empty = {0};
   return { DynamicArange(start, stop, step, param->dtype) };
 }
 
index c5202b5..b2cef6c 100644 (file)
@@ -44,7 +44,7 @@ namespace relay {
  * \return The adjusted Layout.
  */
 inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
-                                       const Array<tvm::Expr>& old_shape) {
+                                       const Array<tvm::PrimExpr>& old_shape) {
   // For each subordinate axis
   //   1) Find the corresponding dual axis.
   //   2) Find the Index of this dual axis in old_layout.
index fea463d..f5c65e5 100644 (file)
@@ -269,7 +269,7 @@ class SequentialNode : public PassNode {
 
 PassInfo PassInfoNode::make(int opt_level,
                             std::string name,
-                            tvm::Array<tvm::Expr> required) {
+                            tvm::Array<tvm::PrimExpr> required) {
   auto pass_info = make_object<PassInfoNode>();
   pass_info->opt_level = opt_level;
   pass_info->name = std::move(name);
@@ -370,7 +370,7 @@ void SequentialNode::ResolveDependency(const Module& mod) {
 }
 
 // linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
+inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
                               const std::string& pass_name) {
   for (auto x : pass_array) {
     auto* str_name = x.as<ir::StringImmNode>();
@@ -428,7 +428,7 @@ Pass CreateModulePass(
     const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<tvm::Expr>& required) {
+    const tvm::Array<tvm::PrimExpr>& required) {
   PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
   return ModulePassNode::make(pass_func, pass_info);
 }
@@ -437,7 +437,7 @@ Pass CreateFunctionPass(
     const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<tvm::Expr>& required) {
+    const tvm::Array<tvm::PrimExpr>& required) {
   PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
   return FunctionPassNode::make(pass_func, pass_info);
 }
@@ -507,7 +507,7 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential")
   tvm::Array<Pass> passes = args[0];
   int opt_level = args[1];
   std::string name = args[2];
-  tvm::Array<tvm::Expr> required = args[3];
+  tvm::Array<tvm::PrimExpr> required = args[3];
   PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
   *ret = Sequential(passes, pass_info);
 });
@@ -533,8 +533,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
   auto pctx = PassContext::Create();
   int opt_level = args[0];
   int fallback_device = args[1];
-  tvm::Array<tvm::Expr> required = args[2];
-  tvm::Array<tvm::Expr> disabled = args[3];
+  tvm::Array<tvm::PrimExpr> required = args[2];
+  tvm::Array<tvm::PrimExpr> disabled = args[3];
   pctx->opt_level = opt_level;
   pctx->fallback_device = fallback_device;
   pctx->required_pass = std::move(required);
index 1f47b20..c62520a 100644 (file)
@@ -202,7 +202,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       return ulhs;
     }
 
-    return tvm::Expr();
+    return tvm::PrimExpr();
   }
 
   Type VisitType_(const TensorTypeNode* op, const Type& tn) final {
@@ -243,8 +243,8 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       if (!dim.defined()) {
         // NB: We push an arbitrary dimension here so we can continue error propogation.
         shape.push_back(tt1->shape[i]);
-        tvm::Expr shape1 = tt1->shape[i];
-        tvm::Expr shape2 = tt2->shape[i];
+        tvm::PrimExpr shape1 = tt1->shape[i];
+        tvm::PrimExpr shape2 = tt2->shape[i];
         std::tuple<int, IndexExpr, IndexExpr> tuple = std::make_tuple(i, shape1, shape2);
         mismatches.push_back(tuple);
       } else {
index 6579c3d..e01a47d 100644 (file)
@@ -50,7 +50,7 @@ bool DequantizeRel(const Array<Type>& types,
   CHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
   CHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
 
-  const Array<tvm::Expr> oshape = data->shape;
+  const Array<tvm::PrimExpr> oshape = data->shape;
   // assign output type, output will always be float 32.
   reporter->Assign(types[3], TensorTypeNode::make(oshape, DataType::Float(32)));
   return true;
index f6133b8..f53d2c5 100644 (file)
@@ -57,7 +57,7 @@ bool QuantizeRel(const Array<Type>& types,
   AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // scale
   AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);    // zero point
 
-  const Array<tvm::Expr> oshape = data->shape;
+  const Array<tvm::PrimExpr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
   CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
         out_dtype == DataType::Int(32))
index a8a9195..2686965 100644 (file)
@@ -190,7 +190,7 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
   CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
 
-  const Array<tvm::Expr> oshape = data->shape;
+  const Array<tvm::PrimExpr> oshape = data->shape;
   // assign output type
   auto out_dtype = requantize_attrs->out_dtype;
   CHECK(out_dtype == DataType::Int(8) ||
index e4b5cf1..378a5e3 100644 (file)
@@ -94,7 +94,7 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
                          attrs.operator->(), input_shape, out_dtype);
 }
 
-static inline int64_t get_const_int(const tvm::Expr& x) {
+static inline int64_t get_const_int(const tvm::PrimExpr& x) {
   auto* value_ptr = as_const_int(x);
   CHECK(value_ptr) << "Expr is not a constant int";
   return value_ptr[0];
index 6c8df8b..3e32923 100644 (file)
@@ -33,13 +33,13 @@ class ElemWiseDetector : public ir::ExprVisitor {
  public:
   explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}
 
-  void VisitExpr(const Expr& e) final {
+  void VisitExpr(const PrimExpr& e) final {
     if (!is_elem_wise_) return;
     ExprVisitor::VisitExpr(e);
   }
 
   void VisitExpr_(const CallNode* op) final {
-    Array<Expr> axis = op->args;
+    Array<PrimExpr> axis = op->args;
     if (axis_.size() != axis.size()) {
       is_elem_wise_ = false;
       return;
index 82ee8ff..3cf30f4 100644 (file)
@@ -300,7 +300,7 @@ Array<Operation> ScanGetBody(const Operation& scan_op) {
   return GetSubGraph(scan->update, inputs, false);
 }
 
-Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
   const ScanOpNode* scan = scan_op.as<ScanOpNode>();
   Array<Operation> body = ScanGetBody(scan_op);
 
@@ -377,7 +377,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
     }
   }
   ReachGraph reach;
-  Map<IterVar, Expr> ret;
+  Map<IterVar, PrimExpr> ret;
   std::unordered_set<TensorDimKey> place_holder_ref;
   for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
     for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
index d596335..99ba6e3 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -123,7 +123,7 @@ Array<Operation> ScanGetBody(const Operation& scan_op);
  * \param scan The scan node.
  * \return Map of spatial_axis -> IntImm
  */
-Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan);
+Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
 
 }  // namespace schedule
 }  // namespace tvm
index d08b4be..869e305 100644 (file)
@@ -55,7 +55,7 @@ void PassDownDomain(const Stage& stage,
                     std::unordered_map<IterVar, Range>* p_state,
                     arith::Analyzer* actx,
                     bool allow_missing) {
-  auto ceil_div = [actx](Expr a, Expr b) {
+  auto ceil_div = [actx](PrimExpr a, PrimExpr b) {
     if (actx->CanProve(indexmod(a, b) == 0)) {
       return actx->Simplify(indexdiv(a, b));
     }
@@ -118,7 +118,7 @@ void PassDownDomain(const Stage& stage,
 
 void PassUpIndex(const Stage& stage,
                  const Map<IterVar, Range>& dom_map,
-                 std::unordered_map<IterVar, Expr>* p_state,
+                 std::unordered_map<IterVar, PrimExpr>* p_state,
                  bool allow_missing) {
   auto& state = *p_state;
   for (size_t i = stage->relations.size(); i != 0; --i) {
@@ -128,10 +128,10 @@ void PassUpIndex(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      Expr outer = state.at(s->outer);
-      Expr inner = state.at(s->inner);
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr parent_min = dom_map.at(s->parent)->min;
+      PrimExpr outer = state.at(s->outer);
+      PrimExpr inner = state.at(s->inner);
+      PrimExpr factor = dom_map.at(s->inner)->extent;
+      PrimExpr parent_min = dom_map.at(s->parent)->min;
       state[s->parent] = inner + outer * factor;
       // add min if they exist
       if (!is_zero(parent_min)) {
@@ -142,10 +142,10 @@ void PassUpIndex(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      Expr value = state.at(s->fused);
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr outer_min = dom_map.at(s->outer)->min;
-      Expr inner_min = dom_map.at(s->inner)->min;
+      PrimExpr value = state.at(s->fused);
+      PrimExpr factor = dom_map.at(s->inner)->extent;
+      PrimExpr outer_min = dom_map.at(s->outer)->min;
+      PrimExpr inner_min = dom_map.at(s->inner)->min;
       state[s->outer] = indexdiv(value, factor);
       state[s->inner] = indexmod(value, factor);
       // add min if they exist
@@ -160,8 +160,8 @@ void PassUpIndex(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      Expr value = state.at(s->rebased);
-      Expr parent_min = dom_map.at(s->parent)->min;
+      PrimExpr value = state.at(s->rebased);
+      PrimExpr parent_min = dom_map.at(s->parent)->min;
       // add min if they exist
       if (!is_zero(parent_min)) {
         state[s->parent] = value + parent_min;
@@ -177,7 +177,7 @@ void PassUpIndex(const Stage& stage,
 
 void PassDownIndex(const Stage& stage,
                    const Map<IterVar, Range>& dom_map,
-                   std::unordered_map<IterVar, Expr>* p_state,
+                   std::unordered_map<IterVar, PrimExpr>* p_state,
                    bool allow_missing) {
   auto& state = *p_state;
   for (IterVarRelation rel : stage->relations) {
@@ -188,8 +188,8 @@ void PassDownIndex(const Stage& stage,
       }
       Range r = dom_map.at(s->inner);
       CHECK(is_zero(r->min));
-      Expr parent = state.at(s->parent);
-      Expr factor = r->extent;
+      PrimExpr parent = state.at(s->parent);
+      PrimExpr factor = r->extent;
       state[s->outer] = indexdiv(parent, factor);
       state[s->inner] = indexmod(parent, factor);
     } else if (const FuseNode* s = rel.as<FuseNode>()) {
@@ -197,11 +197,11 @@ void PassDownIndex(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr outer_min = dom_map.at(s->outer)->min;
-      Expr inner_min = dom_map.at(s->inner)->min;
-      Expr inner = state.at(s->inner);
-      Expr outer = state.at(s->outer);
+      PrimExpr factor = dom_map.at(s->inner)->extent;
+      PrimExpr outer_min = dom_map.at(s->outer)->min;
+      PrimExpr inner_min = dom_map.at(s->inner)->min;
+      PrimExpr inner = state.at(s->inner);
+      PrimExpr outer = state.at(s->outer);
       CHECK(is_zero(outer_min));
       CHECK(is_zero(inner_min));
       state[s->fused] = outer * factor + inner;
@@ -210,8 +210,8 @@ void PassDownIndex(const Stage& stage,
         CHECK(allow_missing);
         continue;
       }
-      Expr value = state.at(s->parent);
-      Expr parent_min = dom_map.at(s->parent)->min;
+      PrimExpr value = state.at(s->parent);
+      PrimExpr parent_min = dom_map.at(s->parent)->min;
       CHECK(is_zero(parent_min));
       state[s->rebased] = value;
     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
@@ -236,8 +236,8 @@ void PassUpDomain(const SplitNode* s,
     *parent = IntSet::range(dom_map.at(s->parent));
     return;
   }
-  Expr factor = dom_map.at(s->inner)->extent;
-  Expr parent_min = dom_map.at(s->parent)->min;
+  PrimExpr factor = dom_map.at(s->inner)->extent;
+  PrimExpr parent_min = dom_map.at(s->parent)->min;
   CHECK(outer.defined());
   CHECK(inner.defined());
   CHECK(factor.defined());
@@ -260,21 +260,21 @@ void PassUpDomain(const FuseNode* s,
     *inner = IntSet::range(dom_map.at(s->inner));
     return;
   }
-  Expr outer_min = dom_map.at(s->outer)->min;
-  Expr inner_min = dom_map.at(s->inner)->min;
+  PrimExpr outer_min = dom_map.at(s->outer)->min;
+  PrimExpr inner_min = dom_map.at(s->inner)->min;
 
   if (fused.is_single_point()) {
-    Expr value = fused.point_value();
-    Expr factor = dom_map.at(s->inner)->extent;
-    Expr v_outer  = indexdiv(value, factor);
-    Expr v_inner  = indexmod(value, factor);
+    PrimExpr value = fused.point_value();
+    PrimExpr factor = dom_map.at(s->inner)->extent;
+    PrimExpr v_outer  = indexdiv(value, factor);
+    PrimExpr v_inner  = indexmod(value, factor);
     if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
     if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
     *outer = IntSet::single_point(v_outer);
     *inner = IntSet::single_point(v_inner);
   } else {
-    Expr fused_extent = (fused.max() - fused.min() + 1);
-    Expr inner_extent = dom_map.at(s->inner)->extent;
+    PrimExpr fused_extent = (fused.max() - fused.min() + 1);
+    PrimExpr inner_extent = dom_map.at(s->inner)->extent;
     *outer = IntSet::interval(
         outer_min + indexdiv(fused.min(), inner_extent),
         outer_min + indexdiv(fused.max(), inner_extent));
@@ -305,7 +305,7 @@ void PassUpDomain(const RebaseNode* s,
     *parent = IntSet::range(dom_map.at(s->parent));
     return;
   }
-  Expr parent_min = dom_map.at(s->parent)->min;
+  PrimExpr parent_min = dom_map.at(s->parent)->min;
   *parent = arith::EvalSet(s->rebased->var + parent_min,
                            {{s->rebased, rebased}});
 }
@@ -458,8 +458,8 @@ void PassUpBoundCheck(const Stage& s,
       bool inner = state.at(s->inner);
 
       if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
-        Expr factor = dom_map.at(s->inner)->extent;
-        Expr step = dom_map.at(s->outer)->extent;
+        PrimExpr factor = dom_map.at(s->inner)->extent;
+        PrimExpr step = dom_map.at(s->outer)->extent;
         if (outer || inner) {
           state[s->parent] = true;
         } else {
@@ -486,10 +486,10 @@ void PassUpBoundCheck(const Stage& s,
   }
 }
 
-std::vector<Expr> MakeBoundCheck(
+std::vector<PrimExpr> MakeBoundCheck(
     const Stage& stage,
     const Map<IterVar, Range>& dom_map,
-    const std::unordered_map<IterVar, Expr>& value_map,
+    const std::unordered_map<IterVar, PrimExpr>& value_map,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter) {
   arith::Analyzer analyzer;
@@ -500,7 +500,7 @@ std::vector<Expr> MakeBoundCheck(
   }
   PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
 
-  std::vector<Expr> preds;
+  std::vector<PrimExpr> preds;
   std::unordered_map<const VarNode*, IntSet> iset_dmap;
 
   // setup domain map for set analysis
@@ -512,8 +512,8 @@ std::vector<Expr> MakeBoundCheck(
     if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
     if (bound_state.at(iv)) {
       Range dom = dom_map.at(iv);
-      Expr value = value_map.at(iv) - dom->min;
-      Expr vmax = EvalSet(value, iset_dmap).max();
+      PrimExpr value = value_map.at(iv) - dom->min;
+      PrimExpr vmax = EvalSet(value, iset_dmap).max();
       if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
         preds.emplace_back(value < dom->extent);
       }
@@ -524,10 +524,10 @@ std::vector<Expr> MakeBoundCheck(
     Range dom = dom_map.at(iv);
     CHECK(iv->dom.defined());
     if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      Expr value = value_map.at(iv) - iv->dom->min;
+      PrimExpr value = value_map.at(iv) - iv->dom->min;
       IntSet s = EvalSet(value, iset_dmap);
-      Expr vmin = s.min();
-      Expr vmax = s.max();
+      PrimExpr vmin = s.min();
+      PrimExpr vmax = s.max();
       // The range of `value` resides in [vmin, vmax]
       if (vmin.dtype() != value.dtype() || !analyzer.CanProve(vmin >= 0)) {
         preds.emplace_back(value >= 0);
index f46f305..6661589 100644 (file)
@@ -62,7 +62,7 @@ void PassDownDomain(
  */
 void PassUpIndex(const Stage& stage,
                  const Map<IterVar, Range>& dom_map,
-                 std::unordered_map<IterVar, Expr>* p_state,
+                 std::unordered_map<IterVar, PrimExpr>* p_state,
                  bool allow_missing = false);
 
 /*!
@@ -76,7 +76,7 @@ void PassUpIndex(const Stage& stage,
  */
 void PassDownIndex(const Stage& stage,
                    const Map<IterVar, Range>& dom_map,
-                   std::unordered_map<IterVar, Expr>* p_state,
+                   std::unordered_map<IterVar, PrimExpr>* p_state,
                    bool allow_missing = false);
 
 /*!
@@ -120,11 +120,11 @@ void PassDownBitMaskOr(const Stage& stage,
  * \param skip_iter The set of variables to skip bound condition.
  * \return List of predicates that we need to check.
  */
-std::vector<Expr>
+std::vector<PrimExpr>
 MakeBoundCheck(
     const Stage& stage,
     const Map<IterVar, Range>& dom_map,
-    const std::unordered_map<IterVar, Expr>& value_map,
+    const std::unordered_map<IterVar, PrimExpr>& value_map,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter);
 
index a6500ca..3bad338 100644 (file)
@@ -45,20 +45,20 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) {
 class VarReplacer : public ir::StmtExprMutator {
  public:
   explicit VarReplacer(
-      const std::unordered_map<const VarNode*, Expr>& vsub)
+      const std::unordered_map<const VarNode*, PrimExpr>& vsub)
       : vsub_(vsub) {}
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = vsub_.find(op);
     if (it != vsub_.end()) return it->second;
-    return GetRef<Expr>(op);
+    return GetRef<PrimExpr>(op);
   }
 
   ir::CommReducer MutateCommReducer(ir::CommReducer combiner) {
     // Replace free variables in combiner
-    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) {
+    auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) {
       return this->VisitExpr(e);
       });
-    auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) {
+    auto new_result = ir::UpdateArray(combiner->result, [this] (const PrimExpr& e) {
       return this->VisitExpr(e);
       });
 
@@ -71,8 +71,8 @@ class VarReplacer : public ir::StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const ir::ReduceNode* op) final {
-    Expr new_e = StmtExprMutator::VisitExpr_(op);
+  PrimExpr VisitExpr_(const ir::ReduceNode* op) final {
+    PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
     const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
     ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
     if (op->combiner.same_as(new_combiner)) {
@@ -88,21 +88,21 @@ class VarReplacer : public ir::StmtExprMutator {
   }
 
  private:
-  const std::unordered_map<const VarNode*, Expr>& vsub_;
+  const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
 };
 
-Expr InjectPredicate(const Array<Expr>& predicates,
-                     Expr body) {
+PrimExpr InjectPredicate(const Array<PrimExpr>& predicates,
+                     PrimExpr body) {
   using ir::ReduceNode;
   using ir::SelectNode;
   if (predicates.size() == 0) return body;
   const ReduceNode* reduce = body.as<ReduceNode>();
   if (reduce) {
     auto n = make_object<ReduceNode>(*reduce);
-    n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, Expr());
-    return Expr(n);
+    n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr());
+    return PrimExpr(n);
   }
-  return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, Expr()),
+  return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()),
                       body,
                       make_zero(body.dtype()));
 }
@@ -153,7 +153,7 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   Stage s = operator[](tensor->op);
   Tensor sugar_tensor = s->op.output(tensor->value_index);
   Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
-      return sugar_tensor(Array<Expr>(i.begin(), i.end()));
+      return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
     }, os.str());
   vsub[sugar_tensor] = cache;
 
@@ -193,9 +193,9 @@ void PrepareAxisMapping(Stage orig_stage,
                         std::unordered_set<IterVar>* p_red_axis,
                         Array<IterVar>* p_new_axis,
                         std::unordered_map<IterVar, Range>* p_dom_map,
-                        std::unordered_map<const VarNode*, Expr>* p_vsub,
-                        std::unordered_map<const VarNode*, Expr>* p_vsub2newvar,
-                        std::vector<Expr>* p_predicates) {
+                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
+                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
+                        std::vector<PrimExpr>* p_predicates) {
   auto& red_axis = *p_red_axis;
   auto& new_axis = *p_new_axis;
   auto& dom_map = *p_dom_map;
@@ -214,7 +214,7 @@ void PrepareAxisMapping(Stage orig_stage,
   schedule::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
   {
     // The source->cache
-    std::unordered_map<IterVar, Expr> value_map;
+    std::unordered_map<IterVar, PrimExpr> value_map;
     for (IterVar iv : orig_stage->leaf_iter_vars) {
       if (red_axis.count(iv)) continue;
       CHECK_EQ(iv->iter_type, kDataPar)
@@ -305,15 +305,15 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   Array<IterVar> new_axis;
   std::unordered_map<IterVar, Range> dom_map;
 
-  std::unordered_map<const VarNode*, Expr> vsub;
-  std::unordered_map<const VarNode*, Expr> vsub2newvar;
-  std::vector<Expr> predicates;
+  std::unordered_map<const VarNode*, PrimExpr> vsub;
+  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+  std::vector<PrimExpr> predicates;
 
   PrepareAxisMapping(orig_stage, compute,
     &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
 
-  Expr body;
-  Array<Expr> body_list;
+  PrimExpr body;
+  Array<PrimExpr> body_list;
   const ir::ReduceNode* first_reduce = nullptr;
   for (auto cbody : compute->body) {
     body = VarReplacer(vsub)(cbody);
@@ -340,10 +340,10 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
     body_list.push_back(body);
   }
   // The reader args
-  Array<Expr> args;
+  Array<PrimExpr> args;
   {
     // cache->compute
-    std::unordered_map<IterVar, Expr> value_map;
+    std::unordered_map<IterVar, PrimExpr> value_map;
     for (IterVar iv : compute->axis) {
       value_map[iv] = iv->var;
     }
@@ -357,7 +357,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
       compute->name + "." + scope, compute->tag, compute->attrs,
       new_axis, body_list);
 
-  Array<Expr> cache_expr_list;
+  Array<PrimExpr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
     cache_expr_list.push_back(cache_tensor(args));
@@ -386,9 +386,9 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
   Array<IterVar> new_axis;
   std::unordered_map<IterVar, Range> dom_map;
 
-  std::unordered_map<const VarNode*, Expr> vsub;
-  std::unordered_map<const VarNode*, Expr> vsub2newvar;
-  std::vector<Expr> predicates;
+  std::unordered_map<const VarNode*, PrimExpr> vsub;
+  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
+  std::vector<PrimExpr> predicates;
 
   PrepareAxisMapping(orig_stage, tensor_op,
     &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates);
@@ -404,15 +404,15 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
   for (Region old_region : tensor_op->input_regions) {
     Region region;
     for (Range r : old_region) {
-      Expr min = VarReplacer(vsub2newvar)(r->min);
-      Expr extent = VarReplacer(vsub2newvar)(r->extent);
+      PrimExpr min = VarReplacer(vsub2newvar)(r->min);
+      PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
       region.push_back(Range::make_by_min_extent(min, extent));
     }
     new_regions.push_back(region);
   }
 
-  Array<Expr> new_scalar_inputs;
-  for (Expr old_input : tensor_op->scalar_inputs) {
+  Array<PrimExpr> new_scalar_inputs;
+  for (PrimExpr old_input : tensor_op->scalar_inputs) {
     new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
   }
 
@@ -430,10 +430,10 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
   }
 
   // The reader args
-  Array<Expr> args;
+  Array<PrimExpr> args;
   {
     // cache->compute
-    std::unordered_map<IterVar, Expr> value_map;
+    std::unordered_map<IterVar, PrimExpr> value_map;
     for (IterVar iv : compute_axis) {
       value_map[iv] = iv->var;
     }
@@ -449,7 +449,7 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
     }
   }
 
-  Array<Expr> cache_expr_list;
+  Array<PrimExpr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
     cache_expr_list.push_back(cache_tensor(args));
@@ -542,7 +542,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
 
-  std::vector<Array<Expr> > new_body(sch->stages.size());
+  std::vector<Array<PrimExpr> > new_body(sch->stages.size());
   std::vector<bool> changed(sch->stages.size(), false);
   std::vector<Stmt> new_hybrid_body(sch->stages.size());
   std::vector<bool> hybrid_changed(sch->stages.size(), false);
@@ -552,7 +552,7 @@ void InjectInline(ScheduleNode* sch) {
     if (stage->attach_type == kInline) {
       stage->attach_type = kInlinedAlready;
       Array<Var> args;
-      Expr body;
+      PrimExpr body;
       {
         // setup args
         const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -583,7 +583,7 @@ void InjectInline(ScheduleNode* sch) {
                   << "The Reduce inputs of ComputeOp should "
                   << "have the same attribute except value_index";
             }
-            Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
+            PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
                                         stage->op, args, body).as<ir::EvaluateNode>()->value;
             if (!new_value.same_as(new_body[j][0])) {
               changed[j] = true;
@@ -594,12 +594,12 @@ void InjectInline(ScheduleNode* sch) {
                 auto n = make_object<ir::ReduceNode>(*r);
                 n->value_index = static_cast<int>(k);
                 n->dtype = r->source[k].dtype();
-                new_body[j].Set(k, Expr(n));
+                new_body[j].Set(k, PrimExpr(n));
               }
             }
           } else {
             for (size_t k = 0; k < new_body[j].size(); ++k) {
-              Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
+              PrimExpr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
                                           stage->op, args, body).as<ir::EvaluateNode>()->value;
               if (!new_value.same_as(new_body[j][k])) {
                 new_body[j].Set(k, new_value);
@@ -706,7 +706,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   arith::Analyzer analyzer;
   // Get the replace index
   std::unordered_map<IterVar, Range> dom_map;
-  std::unordered_map<IterVar, Expr> value_map;
+  std::unordered_map<IterVar, PrimExpr> value_map;
   for (IterVar iv : compute_op->reduce_axis) {
     if (touch_map.count(iv)) {
       dom_map[iv] = iv->dom;
@@ -727,7 +727,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
     }
   }
   schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
-  std::vector<Expr> predicates = schedule::MakeBoundCheck(
+  std::vector<PrimExpr> predicates = schedule::MakeBoundCheck(
       reduce_stage, dom_map, value_map, true, skip_bound_check);
 
   // Get the factored op node.
@@ -761,16 +761,16 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   predicates.push_back(reduce->condition);
-  Expr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, Expr()));
+  PrimExpr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, PrimExpr()));
 
-  std::unordered_map<const VarNode*, Expr> vsub;
+  std::unordered_map<const VarNode*, PrimExpr> vsub;
 
   for (IterVar iv : compute_op->reduce_axis) {
     if (!touch_map.count(iv)) {
       n->reduce_axis.push_back(iv);
     } else {
       CHECK(value_map.count(iv));
-      Expr index = value_map.at(iv);
+      PrimExpr index = value_map.at(iv);
       vsub[iv->var.get()] = index;
     }
   }
@@ -785,12 +785,12 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
     }
   }
   VarReplacer replacer(vsub);
-  Array<Expr> new_source = ir::UpdateArray(reduce->source,
-    [&replacer] (const Expr& e) { return replacer(e); });
+  Array<PrimExpr> new_source = ir::UpdateArray(reduce->source,
+    [&replacer] (const PrimExpr& e) { return replacer(e); });
 
-  Expr new_pred = replacer(predicate);
+  PrimExpr new_pred = replacer(predicate);
 
-  std::vector<Expr> body;
+  std::vector<PrimExpr> body;
   for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
     body.emplace_back(ReduceNode::make(reduce->combiner,
                                    new_source,
@@ -798,7 +798,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
                                    new_pred,
                                    idx));
   }
-  n->body = Array<Expr>(body);
+  n->body = Array<PrimExpr>(body);
   // refresh relations, keep the un-touched relations.
   Array<IterVarRelation> rels;
   for (IterVarRelation rel : reduce_stage->relations) {
@@ -842,7 +842,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   }
   Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
     [&](const Array<Var>& i) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       const int idx_size = static_cast<int>(i.size());
       for (int idx = 0; idx < idx_size; ++idx) {
         if (factor_axis_pos == idx) {
@@ -853,13 +853,13 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
       if (factor_axis_pos == idx_size) {
           indices.push_back(repl_red_axis->var);
       }
-      Array<Expr> factor_exprs;
+      Array<PrimExpr> factor_exprs;
       for (int idx = 0; idx < size; ++idx) {
         factor_exprs.push_back(factor_tensors[idx](indices));
       }
-      Array<Expr> reductions;
+      Array<PrimExpr> reductions;
       Array<IterVar> axis = {repl_red_axis};
-      Expr cond = const_true();
+      PrimExpr cond = const_true();
       for (int idx = 0; idx < size; ++idx) {
         reductions.push_back(ReduceNode::make(reduce->combiner,
           factor_exprs, axis, cond, idx));
index a53c1ae..fe56b66 100644 (file)
@@ -55,8 +55,8 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
 
 void Split(StageNode* self,
            IterVar parent,
-           Expr factor,
-           Expr nparts,
+           PrimExpr factor,
+           PrimExpr nparts,
            IterVar* p_outer,
            IterVar* p_inner) {
   // Check if split is valid.
@@ -217,21 +217,21 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
   return *this;
 }
 
-Stage& Stage::set_store_predicate(Expr predicate) {
+Stage& Stage::set_store_predicate(PrimExpr predicate) {
   StageNode* self = operator->();
   self->store_predicate = predicate;
   return *this;
 }
 
 Stage& Stage::split(
-    IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner) {  // NOLINT(*)
-  Split(operator->(), parent, factor, Expr(), p_outer, p_inner);
+    IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) {  // NOLINT(*)
+  Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
   return *this;
 }
 
 Stage& Stage::split_by_nparts(
-    IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
-  Split(operator->(), parent, Expr(), nparts, p_outer, p_inner);
+    IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*)
+  Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
   return *this;
 }
 
@@ -332,7 +332,7 @@ Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
 }
 
 Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
-                   Expr x_factor, Expr y_factor,
+                   PrimExpr x_factor, PrimExpr y_factor,
                    IterVar* p_x_outer, IterVar* p_y_outer,
                    IterVar* p_x_inner, IterVar* p_y_inner) {
   split(x_parent, x_factor, p_x_outer, p_x_inner);
@@ -400,7 +400,7 @@ Stage& Stage::parallel(IterVar var) {   // NOLINT(*)
 
 Stage& Stage::pragma(IterVar var,
                      const std::string& pragma_type,
-                     const Expr& pragma_value) {   // NOLINT(*)
+                     const PrimExpr& pragma_value) {   // NOLINT(*)
   if (pragma_type == "unroll") {
     this->unroll(var);
   } else if (pragma_type == "vectorize") {
@@ -415,7 +415,7 @@ Stage& Stage::pragma(IterVar var,
   return *this;
 }
 
-Stage& Stage::prefetch(const Tensor &tensor, IterVar var, Expr offset) {
+Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) {
   StageNode *self = operator->();
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
@@ -756,8 +756,8 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
 IterVarRelation SplitNode::make(IterVar parent,
                                 IterVar outer,
                                 IterVar inner,
-                                Expr factor,
-                                Expr nparts) {
+                                PrimExpr factor,
+                                PrimExpr nparts) {
   auto n = make_object<SplitNode>();
   n->parent = parent;
   n->outer = outer;
index 38174df..1eb595c 100644 (file)
@@ -277,13 +277,13 @@ class SchedulePostProc : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
+  PrimExpr VisitExpr_(const CallNode* op) final {
     if (op->call_type == CallNode::Halide) {
       TensorKey key{op->func, op->value_index};
       auto it = replace_buffer_.find(key);
       if (it != replace_buffer_.end()) {
         const Tensor& dst = it->second;
-        Expr ret = CallNode::make(
+        PrimExpr ret = CallNode::make(
             op->dtype, dst->op->name, op->args,
             op->call_type, dst->op, dst->value_index);
         return this->VisitExpr(ret);
@@ -292,12 +292,12 @@ class SchedulePostProc : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const VarNode* op) final {
+  PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = var_value_.find(op);
     if (it != var_value_.end()) {
       return it->second;
     } else {
-      return GetRef<Expr>(op);
+      return GetRef<PrimExpr>(op);
     }
   }
 
@@ -343,9 +343,9 @@ class SchedulePostProc : public StmtExprMutator {
     replace_op_[src->op.get()] = repl_op;
   }
   // The thread extent scope.
-  std::unordered_map<const Object*, Expr> thread_extent_scope_;
+  std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
   // The scan value
-  std::unordered_map<const VarNode*, Expr> var_value_;
+  std::unordered_map<const VarNode*, PrimExpr> var_value_;
   // buffer replacement
   std::unordered_map<TensorKey, Tensor> replace_buffer_;
   // buffere realization to be replaced
index a6010c3..9a24257 100644 (file)
@@ -28,7 +28,7 @@ namespace test {
 struct TestAttrs : public AttrsNode<TestAttrs> {
   int axis;
   std::string name;
-  Expr expr;
+  PrimExpr expr;
   double learning_rate;
 
   TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") {
@@ -70,10 +70,10 @@ TEST(Attrs, Basic) {
     LOG(FATAL) << "bad";
   } catch (const tvm::AttrError& e) {
     std::string what = e.what();
-    CHECK(what.find("expr : Expr, default=1") != std::string::npos);
+    CHECK(what.find("expr : PrimExpr, default=1") != std::string::npos);
     CHECK(what.find("axisx") != std::string::npos);
   }
-  n->InitBySeq("learning_rate", Expr(1), "expr", 128, "name", "xx");
+  n->InitBySeq("learning_rate", PrimExpr(1), "expr", 128, "name", "xx");
   CHECK_EQ(n->learning_rate, 1.0);
 
   n->InitBySeq("name", "xxx", "expr", 128);
@@ -84,7 +84,7 @@ TEST(Attrs, Basic) {
   std::ostringstream os;
   n->PrintDocString(os);
   LOG(INFO) << "docstring\n"<< os.str();
-  CHECK(os.str().find("expr : Expr, default=1") != std::string::npos);
+  CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos);
 }
 
 
index 6998657..15bdd47 100644 (file)
 TEST(BuildModule, Basic) {
   using namespace tvm;
   auto n = var("n");
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   shape.push_back(n);
 
   auto A = placeholder(shape, DataType::Float(32), "A");
   auto B = placeholder(shape, DataType::Float(32), "B");
 
-  auto C = compute(A->shape, [&A, &B](Expr i) {
+  auto C = compute(A->shape, [&A, &B](PrimExpr i) {
     return A[i] + B[i];
   }, "C");
 
@@ -88,18 +88,18 @@ TEST(BuildModule, Heterogeneous) {
 
   // The shape of input tensors.
   const int n = 4;
-  Array<Expr> shape{n};
+  Array<PrimExpr> shape{n};
 
   auto A = placeholder(shape, DataType::Float(32), "A");
   auto B = placeholder(shape, DataType::Float(32), "B");
   auto C = placeholder(shape, DataType::Float(32), "C");
 
-  auto elemwise_add = compute(A->shape, [&A, &B](Expr i) {
+  auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) {
     return A[i] + B[i];
   }, "elemwise_add");
 
   auto copy = placeholder(shape, DataType::Float(32), "__copy");
-  auto elemwise_sub = compute(C->shape, [&copy, &C](Expr i) {
+  auto elemwise_sub = compute(C->shape, [&copy, &C](PrimExpr i) {
     return copy[i] - C[i];
   }, "elemwise_sub");
 
index 3d7c355..d5d8aae 100644 (file)
@@ -144,11 +144,11 @@ TEST(InplaceArrayBase, ExceptionSafety) {
   ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), "");
 }
 
-TEST(Array, Expr) {
+TEST(Array, PrimExpr) {
   using namespace tvm;
   Var x("x");
   auto z = max(x + 1 + 2, 100);
-  Array<Expr> list{x, z, z};
+  Array<PrimExpr> list{x, z, z};
   LOG(INFO) << list.size();
   LOG(INFO) << list[0];
   LOG(INFO) << list[1];
@@ -158,7 +158,7 @@ TEST(Array, Mutate) {
   using namespace tvm;
   Var x("x");
   auto z = max(x + 1 + 2, 100);
-  Array<Expr> list{x, z, z};
+  Array<PrimExpr> list{x, z, z};
   auto list2 = list;
   list.Set(1, x);
   CHECK(list[1].same_as(x));
@@ -167,8 +167,8 @@ TEST(Array, Mutate) {
 
 TEST(Array, Iterator) {
   using namespace tvm;
-  Array<Expr> array{1, 2, 3};
-  std::vector<Expr> vector(array.begin(), array.end());
+  Array<PrimExpr> array{1, 2, 3};
+  std::vector<PrimExpr> vector(array.begin(), array.end());
   CHECK(vector[1].as<IntImmNode>()->value == 2);
 }
 
@@ -177,7 +177,7 @@ TEST(Map, Expr) {
   Var x("x");
   auto z = max(x + 1 + 2, 100);
   auto zz = z + 1;
-  Map<Expr, Expr> dict{{x, z}, {z, 2}};
+  Map<PrimExpr, PrimExpr> dict{{x, z}, {z, 2}};
   CHECK(dict.size() == 2);
   CHECK(dict[x].same_as(z));
   CHECK(dict.count(z));
@@ -188,7 +188,7 @@ TEST(StrMap, Expr) {
   using namespace tvm;
   Var x("x");
   auto z = max(x + 1 + 2, 100);
-  Map<std::string, Expr> dict{{"x", z}, {"z", 2}};
+  Map<std::string, PrimExpr> dict{{"x", z}, {"z", 2}};
   CHECK(dict.size() == 2);
   CHECK(dict["x"].same_as(z));
 }
@@ -197,7 +197,7 @@ TEST(Map, Mutate) {
   using namespace tvm;
   Var x("x");
   auto z = max(x + 1 + 2, 100);
-  Map<Expr, Expr> dict{{x, z}, {z, 2}};
+  Map<PrimExpr, PrimExpr> dict{{x, z}, {z, 2}};
   auto zz = z + 1;
   CHECK(dict[x].same_as(z));
   dict.Set(x, zz);
@@ -218,9 +218,9 @@ TEST(Map, Mutate) {
 
 TEST(Map, Iterator) {
   using namespace tvm;
-  Expr a = 1, b = 2;
-  Map<Expr, Expr> map1{{a, b}};
-  std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>
+  PrimExpr a = 1, b = 2;
+  Map<PrimExpr, PrimExpr> map1{{a, b}};
+  std::unordered_map<PrimExpr, PrimExpr, ObjectHash, ObjectEqual>
       map2(map1.begin(), map1.end());
   CHECK(map2[a].as<IntImmNode>()->value == 2);
 }
index 4b6915f..d9b6200 100644 (file)
@@ -26,7 +26,7 @@ TEST(Expr, Basic) {
   Var x("x");
   auto z = max(x + 1 + 2, 100);
   ObjectRef tmp = z;
-  Expr zz = Downcast<Expr>(tmp);
+  PrimExpr zz = Downcast<PrimExpr>(tmp);
   std::ostringstream os;
   os << z;
   CHECK(zz.same_as(z));
@@ -37,7 +37,7 @@ TEST(Expr, Basic) {
 TEST(ExprNodeRef, Basic) {
   using namespace tvm;
   Var x("x");
-  Expr z = max(x + 1 + 2, 100);
+  PrimExpr z = max(x + 1 + 2, 100);
   const ir::MaxNode* op = z.as<ir::MaxNode>();
   CHECK(GetRef<ObjectRef>(op).same_as(z));
 }
index 23a81b9..178f582 100644 (file)
@@ -61,7 +61,7 @@ TEST(IRF, ExprTransform) {
   auto z = x + 1;
 
   class MyExprFunctor
-      : public ir::ExprFunctor<int(const Expr&, int)> {
+      : public ir::ExprFunctor<int(const PrimExpr&, int)> {
    public:
     int VisitExpr_(const VarNode* op, int b) final {
       return b;
@@ -90,7 +90,7 @@ TEST(IRF, ExprVisit) {
   auto z = x + 1;
 
   class MyVisitor
-      : public ir::ExprFunctor<void(const Expr&)>,
+      : public ir::ExprFunctor<void(const PrimExpr&)>,
         public ir::StmtFunctor<void(const Stmt&)> {
    public:
     int count = 0;
@@ -152,13 +152,13 @@ TEST(IRF, StmtMutator) {
 
    protected:
     // implementation
-    Expr VisitExpr_(const AddNode* op) final {
+    PrimExpr VisitExpr_(const AddNode* op) final {
       return op->a;
     }
     Stmt VisitStmt_(const SeqStmtNode* op) final {
       return StmtMutator::VisitSeqStmt_(op, true);
     }
-    Expr VisitExpr(const Expr& expr) final {
+    PrimExpr VisitExpr(const PrimExpr& expr) final {
       return ExprMutator::VisitExpr(expr);
     }
   };
index 47cd000..d1316de 100644 (file)
@@ -26,7 +26,7 @@ TEST(IRSSA, Convert) {
   using namespace tvm;
   using namespace tvm::ir;
   Var x("x"), y;
-  Expr let = LetNode::make(x, 1, x + 1);
+  PrimExpr let = LetNode::make(x, 1, x + 1);
 
   auto z = EvaluateNode::make(let + let);
   CHECK(!ir::VerifySSA(z));
index 24ed6d8..b0e5f24 100644 (file)
@@ -130,7 +130,7 @@ TEST(PackedFunc, Expr) {
   using namespace tvm::runtime;
   // automatic conversion of int to expr
   PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
-      Expr x = args[0];
+      PrimExpr x = args[0];
       *rv = x.as<tvm::ir::IntImmNode>()->value + 1;
   });
   int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
@@ -218,7 +218,7 @@ TEST(PackedFunc, ObjectConversion) {
   // Check convert back
   CHECK(rv.operator NDArray().same_as(x));
   CHECK(rv.operator ObjectRef().same_as(x));
-  CHECK(!rv.IsObjectRef<Expr>());
+  CHECK(!rv.IsObjectRef<PrimExpr>());
 
   auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
       CHECK_EQ(args[0].type_code(), kNDArrayContainer);
@@ -228,7 +228,7 @@ TEST(PackedFunc, ObjectConversion) {
       CHECK(args[1].operator NDArray().get() == nullptr);
       CHECK(args[1].operator Module().get() == nullptr);
       CHECK(args[1].operator Array<NDArray>().get() == nullptr);
-      CHECK(!args[0].IsObjectRef<Expr>());
+      CHECK(!args[0].IsObjectRef<PrimExpr>());
     });
   pf1(x, ObjectRef());
   pf1(ObjectRef(x), NDArray());
@@ -254,7 +254,7 @@ TEST(PackedFunc, ObjectConversion) {
       CHECK(args[1].operator ObjectRef().get() == nullptr);
       CHECK(args[1].operator NDArray().get() == nullptr);
       CHECK(args[1].operator Module().get() == nullptr);
-      CHECK(!args[0].IsObjectRef<Expr>());
+      CHECK(!args[0].IsObjectRef<PrimExpr>());
     });
   pf2(m, ObjectRef());
   pf2(ObjectRef(m), Module());
index 2b03454..5392eae 100644 (file)
@@ -24,7 +24,7 @@ TEST(Pattern, Basic) {
   using namespace tvm;
   using namespace tvm::arith;
   Var x("x"), y("y"), z("z");
-  arith::PVar<Expr> px, py, pz;
+  arith::PVar<PrimExpr> px, py, pz;
   arith::PVar<DataType> pt;
   arith::PVar<int> planes;
 
@@ -49,7 +49,7 @@ TEST(Pattern, Basic) {
   CHECK((px + min(py, px)).Match(z + min(y, z)));
   CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2)));
   CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2)));
-  CHECK((px - floormod(py, px * PConst<Expr>(2))).Match(x - floormod(2, x * 2)));
+  CHECK((px - floormod(py, px * PConst<PrimExpr>(2))).Match(x - floormod(2, x * 2)));
 
   // logicals
   CHECK((px == pz).Match(x == 1));
@@ -111,10 +111,10 @@ TEST(Pattern, Basic) {
   }
   // ramp pattern
   {
-    CHECK(ramp(px, PConst<Expr>(1), planes).Match(
+    CHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(
         ir::RampNode::make(x, 1, 10)));
     CHECK(planes.Eval() == 10);
-    CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
+    CHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(
         ir::RampNode::make(x, 2, 10)));
   }
   // broadcast pattern
index 19c851b..6333f15 100644 (file)
@@ -25,7 +25,7 @@
 TEST(SimplePasses, HasSideEffect) {
   using namespace tvm;
   auto n = var("n");
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   shape.push_back(n);
 
   auto A = placeholder(shape, DataType::Float(32), "A");
index 3cb6e65..ef5c39d 100644 (file)
@@ -29,8 +29,8 @@ def test_operator_type_and_tags():
     B1 = B[0]
     B2 = B[0,0]
 
-    assert isinstance(k + n, tvm.expr.Expr)
-    assert isinstance(n + n, tvm.expr.Expr)
+    assert isinstance(k + n, tvm.expr.PrimExpr)
+    assert isinstance(n + n, tvm.expr.PrimExpr)
     assert isinstance(k + A, tvm.tensor.Tensor)
     assert isinstance(A + k, tvm.tensor.Tensor)
     assert isinstance(n + A, tvm.tensor.Tensor)
@@ -53,11 +53,11 @@ def test_operator_type_and_tags():
     assert (B + A).op.tag == topi.tag.BROADCAST
     assert (B + B).op.tag == topi.tag.BROADCAST
 
-    assert isinstance(k + B2, tvm.expr.Expr)
-    assert isinstance(B2 + k, tvm.expr.Expr)
-    assert isinstance(n + B2, tvm.expr.Expr)
-    assert isinstance(B2 + n, tvm.expr.Expr)
-    assert isinstance(B2 + B2, tvm.expr.Expr)
+    assert isinstance(k + B2, tvm.expr.PrimExpr)
+    assert isinstance(B2 + k, tvm.expr.PrimExpr)
+    assert isinstance(n + B2, tvm.expr.PrimExpr)
+    assert isinstance(B2 + n, tvm.expr.PrimExpr)
+    assert isinstance(B2 + B2, tvm.expr.PrimExpr)
     assert isinstance(B2 + A, tvm.tensor.Tensor)
     assert isinstance(A + B2, tvm.tensor.Tensor)
     assert isinstance(B2 + B, tvm.tensor.Tensor)
index d2d106d..02f8118 100644 (file)
@@ -19,7 +19,7 @@ import numpy as np
 
 def lower_intrin(stmt):
     """wrapper to call transformation in stmt"""
-    lower_expr = isinstance(stmt, tvm.expr.Expr)
+    lower_expr = isinstance(stmt, tvm.expr.PrimExpr)
     stmt = tvm.stmt.Evaluate(stmt) if lower_expr else stmt
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt  = tvm.ir_pass._LowerIntrinStmt(stmt, "llvm")
index 542d43e..ce16e23 100644 (file)
@@ -44,7 +44,7 @@ namespace topi {
  * \return A Tensor whose op member is a broadcast operation
  */
 inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
-                                const tvm::Array<tvm::Expr>& output_shape,
+                                const tvm::Array<tvm::PrimExpr>& output_shape,
                                 std::string name = "T_broadcast_to",
                                 std::string tag = kBroadcast) {
   CHECK_GE(output_shape.size(), t->shape.size())
@@ -59,38 +59,38 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
     return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
   };
   return tvm::compute(
-      tvm::Array<tvm::Expr>(bh.common_shape.begin(), bh.common_shape.end()),
+      tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
       l,
       name,
       tag);
 }
 
 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                       \
-  inline tvm::Expr Name(const tvm::Expr& a,                           \
-                        const tvm::Expr& b) {                         \
+  inline tvm::PrimExpr Name(const tvm::PrimExpr& a,                   \
+                            const tvm::PrimExpr& b) {                 \
     ComputeRule;                                                      \
   }                                                                   \
   inline tvm::Tensor Name(const tvm::Tensor& A,                       \
                           const tvm::Tensor& B,                       \
                           std::string name = "T_" #Name,              \
                           std::string tag = kBroadcast) {             \
-    auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; };           \
+    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };   \
     return detail::WithBroadcast(l, A, B, name, tag);                 \
   }                                                                   \
   inline tvm::Tensor Name(const tvm::Tensor& A,                       \
-                          const tvm::Expr& B,                         \
+                          const tvm::PrimExpr& B,                     \
                           std::string name = "T_" #Name,              \
                           std::string tag = kElementWise) {           \
-    auto l = [](tvm::Expr a, tvm::Expr b) { ComputeRule; };           \
+    auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \
     return compute(A->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
         return l(A(i), B);                                            \
       }, name, tag);                                                  \
   }                                                                   \
-  inline tvm::Tensor Name(const tvm::Expr& A,                         \
+  inline tvm::Tensor Name(const tvm::PrimExpr& A,                     \
                           const tvm::Tensor& B,                       \
                           std::string name = "T_" #Name,              \
                           std::string tag = kElementWise) {           \
-    auto l = [&](tvm::Expr a, tvm::Expr b) { ComputeRule; };          \
+    auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; };  \
     return compute(B->shape, [&](const ::tvm::Array<::tvm::Var>& i) { \
         return l(A, B(i));                                            \
       }, name, tag);                                                  \
@@ -102,12 +102,12 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
                           const tvm::Tensor& B) {                   \
     return topi::OpName(A, B);                                      \
   }                                                                 \
-  inline tvm::Tensor Name(const tvm::Expr& A,                       \
+  inline tvm::Tensor Name(const tvm::PrimExpr& A,                   \
                           const tvm::Tensor& B) {                   \
     return topi::OpName(A, B);                                      \
   }                                                                 \
   inline tvm::Tensor Name(const tvm::Tensor& A,                     \
-                          const tvm::Expr& B) {                     \
+                          const tvm::PrimExpr& B) {                 \
     return topi::OpName(A, B);                                      \
   }
 
index 4dce9a0..c644903 100644 (file)
@@ -52,7 +52,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        Expr("tvm.contrib.cublas.matmul"),
+        PrimExpr("tvm.contrib.cublas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
@@ -62,7 +62,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
 }
 
 /*!
-* \brief Create an op that multiplies batch matrices 
+* \brief Create an op that multiplies batch matrices
 *        lhs and rhs with cuBLAS
 *
 * \param lhs The left matrix operand
@@ -84,7 +84,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
     { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        Expr("tvm.contrib.cublas.batch_matmul"),
+        PrimExpr("tvm.contrib.cublas.batch_matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
index 6eb6dbc..a61499d 100644 (file)
@@ -51,7 +51,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        Expr("tvm.contrib.rocblas.matmul"),
+        PrimExpr("tvm.contrib.rocblas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
index 3bd3e4a..781258a 100644 (file)
@@ -119,8 +119,8 @@ inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs)
     auto thread_x = tvm::thread_axis(Range(), "threadIdx.x");
     s[dense].bind(tx, thread_x);
     s[dense_f].compute_at(s[dense], tx);
-    s[dense].set_store_predicate(static_cast<Expr>(thread_x) == 0);
-    s[out].set_store_predicate(static_cast<Expr>(thread_x) == 0);
+    s[dense].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
+    s[out].set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
   };
 
   std::function<void(Operation)> traverse;
index 3166d08..a82b363 100644 (file)
@@ -125,7 +125,7 @@ Schedule ScheduleReduce(const Target& target,
     }
   }
 
-  stage_real.set_store_predicate(static_cast<Expr>(thread_x) == 0);
+  stage_real.set_store_predicate(static_cast<PrimExpr>(thread_x) == 0);
   return sch;
 }
 
index 8c5068a..2d326e7 100644 (file)
@@ -37,18 +37,18 @@ namespace topi {
 namespace detail {
 
 struct BroadcastHelper {
-  std::deque<tvm::Expr> common_shape;
+  std::deque<tvm::PrimExpr> common_shape;
   std::deque<tvm::Var> all_vars;
   std::deque<tvm::Var> vars1;
   std::deque<tvm::Var> vars2;
 };
 
-inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
-                                      const tvm::Array<tvm::Expr>& shape2) {
+inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
+                                      const tvm::Array<tvm::PrimExpr>& shape2) {
   BroadcastHelper bh;
   int s1_size = shape1.size();
   int s2_size = shape2.size();
-  tvm::Expr one(1);
+  tvm::PrimExpr one(1);
   int i;
   for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
     // TODO(@icemelon9): Need to revisit this part
@@ -81,9 +81,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
     } else {
       CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
                    << " and " << shape2[s2_size - i] << " in: "
-                   << tvm::Array<tvm::Expr>(shape1.begin(), shape1.end())
+                   << tvm::Array<tvm::PrimExpr>(shape1.begin(), shape1.end())
                    << " and "
-                   << tvm::Array<tvm::Expr>(shape2.begin(), shape2.end());
+                   << tvm::Array<tvm::PrimExpr>(shape2.begin(), shape2.end());
     }
   }
   // Remaining dimensions whether on shape1 or shape2 can always be completed
@@ -98,12 +98,12 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
   return bh;
 }
 
-inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
+inline tvm::Array<tvm::PrimExpr> InputIndexFromBroadcast(
     const tvm::Array<tvm::Var>& ovars,
     const tvm::Tensor& T,
     const std::deque<tvm::Var>& my_vars,
     const std::deque<tvm::Var>& all_vars) {
-  tvm::Array<tvm::Expr> ivars;
+  tvm::Array<tvm::PrimExpr> ivars;
   CHECK_EQ(ovars.size(), all_vars.size());
   // N^2, could use a map but NBD.
   size_t expected_dims = T->shape.size();
@@ -138,7 +138,7 @@ inline tvm::Tensor WithBroadcast(FBinaryExpr op,
               B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
   };
   return tvm::compute(
-      tvm::Array<tvm::Expr>(bh.common_shape.begin(), bh.common_shape.end()),
+      tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), bh.common_shape.end()),
       l,
       name,
       tag);
index 00db1fc..43ac3a2 100644 (file)
@@ -41,7 +41,7 @@ using namespace tvm;
  *
  * \return true if the given expr is a constant int or uint, false otherwise.
  */
-inline bool IsConstInt(Expr expr) {
+inline bool IsConstInt(PrimExpr expr) {
   return
     expr->IsInstance<tvm::ir::IntImmNode>() ||
     expr->IsInstance<tvm::ir::UIntImmNode>();
@@ -55,7 +55,7 @@ inline bool IsConstInt(Expr expr) {
  *
  * \return The integer value.
  */
-inline int64_t GetConstInt(Expr expr) {
+inline int64_t GetConstInt(PrimExpr expr) {
   if (expr->IsInstance<tvm::ir::IntImmNode>()) {
     return expr.as<tvm::ir::IntImmNode>()->value;
   }
@@ -75,11 +75,13 @@ inline int64_t GetConstInt(Expr expr) {
  *
  * \return A vector of the integer values
  */
-inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& var_name) {
+inline std::vector<int> GetConstIntValues(
+    Array<PrimExpr> exprs, const std::string& var_name) {
   std::vector<int> result;
   if (!exprs.defined()) return result;
   for (auto expr : exprs) {
-    CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
+    CHECK(IsConstInt(expr)) << "All elements of "
+                            << var_name << " must be constant integers";
     result.push_back(GetConstInt(expr));
   }
   return result;
@@ -94,7 +96,8 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
  *
  * \return A vector of the int64_t values
  */
-inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::string& var_name) {
+inline std::vector<int64_t> GetConstInt64Values(
+    Array<PrimExpr> exprs, const std::string& var_name) {
   std::vector<int64_t> result;
   if (!exprs.defined()) return result;
   for (auto expr : exprs) {
@@ -113,10 +116,10 @@ inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::st
  *
  * \return result True if both expressions are equal, else false
  */
-inline bool EqualCheck(Expr lhs, Expr rhs) {
+inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
   bool result = tvm::ir::Equal(lhs, rhs);
   if (!result) {
-    Expr zero(0);
+    PrimExpr zero(0);
     result = tvm::ir::Equal(tvm::ir::CanonicalSimplify(lhs-rhs), zero);
   }
   return result;
index 643b44b..8bdda80 100644 (file)
@@ -42,12 +42,12 @@ using namespace tvm;
  *
  * \return The Buffer object
  */
-inline Buffer DeclExternBuffer(Array<Expr> shape,
+inline Buffer DeclExternBuffer(Array<PrimExpr> shape,
                                DataType dtype,
                                std::string name) {
   auto data = var(name, DataType::Handle());
-  auto elem_offset = Expr();
-  return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
+  auto elem_offset = PrimExpr();
+  return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
                           -1, 0, kDefault);
 }
 
@@ -56,7 +56,7 @@ inline Buffer DeclExternBuffer(Array<Expr> shape,
  * function. The function expects two arguments: an array of Buffers holding the input
  * tensor values, and a pre-allocated array of Buffers to be filled with the outputs.
  */
-using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
+using FExtern = std::function<PrimExpr(Array<Buffer>, Array<Buffer>)>;
 
 /*!
  * \brief Create tensors representing the result of invoking an external function.
@@ -75,7 +75,7 @@ using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
  * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
  * element of out_types.
  */
-inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
+inline Array<Tensor> make_extern(const Array< Array<PrimExpr> >& out_shapes,
                                  const std::vector<DataType>& out_types,
                                  const Array<Tensor>& inputs,
                                  FExtern fextern,
@@ -116,18 +116,18 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
  *
  * \return An expression representing the pack operation
  */
-inline Expr pack_buffer(Buffer buf) {
+inline PrimExpr pack_buffer(Buffer buf) {
   CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
   auto shape = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
                                    buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
-  Expr strides;
+  PrimExpr strides;
   if (buf->strides.size() > 0) {
     strides = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
                                   buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
   } else {
     strides = 0;
   }
-  Array<Expr> pack_args{
+  Array<PrimExpr> pack_args{
     buf->data,
     shape,
     strides,
@@ -148,7 +148,7 @@ inline Expr pack_buffer(Buffer buf) {
  *
  * \return An expression representing the invocation
  */
-inline Expr call_packed(Array<Expr> args) {
+inline PrimExpr call_packed(Array<PrimExpr> args) {
   return tvm::ir::CallNode::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
                              args, tvm::ir::CallNode::CallType::Intrinsic);
 }
index 50e0f95..ec757e9 100644 (file)
@@ -42,7 +42,7 @@ using namespace tvm;
  * \return An array of 4 elements, representing padding sizes for
  * each individual side. The array is in the order { top, left, bottom, right }
  */
-inline Array<Expr> GetPadTuple(Expr pad_h, Expr pad_w) {
+inline Array<PrimExpr> GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) {
   pad_h *= 2;
   pad_w *= 2;
 
index 6cd4707..5526a7d 100644 (file)
@@ -41,10 +41,10 @@ using namespace tvm;
 *
 * \return The index after flattening
 */
-inline Expr RavelIndex(Array<Expr> indices, Array<Expr> shape) {
+inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
   CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
   CHECK_GT(indices.size(), 0) << "indices must not be empty";
-  Expr idx;
+  PrimExpr idx;
   for (size_t i = 0; i < indices.size(); ++i) {
     if (i == 0) {
       idx = indices[i];
@@ -63,8 +63,8 @@ inline Expr RavelIndex(Array<Expr> indices, Array<Expr> shape) {
 *
 * \return The coordinate corresponding to the 1D index
 */
-inline Array<Expr> UnravelIndex(Expr idx, Array<Expr> shape) {
-  std::vector<Expr> indices;
+inline Array<PrimExpr> UnravelIndex(PrimExpr idx, Array<PrimExpr> shape) {
+  std::vector<PrimExpr> indices;
 
   for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
     indices.push_back(indexmod(idx, shape[i]));
index fe23836..e52452e 100644 (file)
@@ -36,7 +36,7 @@ using namespace tvm;
  *
  * \return True if the input shape is empty.
  */
-inline bool is_empty_shape(const Array<Expr>& x) {
+inline bool is_empty_shape(const Array<PrimExpr>& x) {
   bool is_empty = false;
   for (const auto& dim : x) {
     if (auto int_dim = dim.as<IntImmNode>()) {
index dec94f3..df7cff0 100644 (file)
@@ -191,9 +191,9 @@ inline Tensor sign(const Tensor& x,
                    std::string name = "T_sign",
                    std::string tag = kElementWise) {
   return compute(x->shape, [&](const Array<Var>& i) {
-    Expr zero = make_zero(x->dtype);
-    Expr one = make_const(x->dtype, 1);
-    Expr minus_one = make_const(x->dtype, -1);
+    PrimExpr zero = make_zero(x->dtype);
+    PrimExpr one = make_const(x->dtype, 1);
+    PrimExpr minus_one = make_const(x->dtype, -1);
     auto s1 = tvm::ir::SelectNode::make((x(i) < zero), minus_one, zero);
     auto s2 = tvm::ir::SelectNode::make((x(i) > zero), one, s1);
     return s2;
@@ -213,7 +213,7 @@ inline Tensor rsqrt(const Tensor& x,
                        std::string name = "tensor",
                        std::string tag = kElementWise) {
   return compute(x->shape, [&](const Array<Var>& i) {
-    Expr one = make_const(x->dtype, 1);
+    PrimExpr one = make_const(x->dtype, 1);
     return one/tvm::sqrt(x(i));
   }, name, tag);
 }
@@ -231,8 +231,8 @@ inline Tensor rsqrt(const Tensor& x,
 * \return A Tensor whose op member is the clip operation
 */
 inline Tensor clip(const Tensor& x,
-                   const Expr& a_min,
-                   const Expr& a_max,
+                   const PrimExpr& a_min,
+                   const PrimExpr& a_max,
                    std::string name = "T_clip",
                    std::string tag = kElementWise) {
   return compute(x->shape, [&](const Array<Var>& i) {
@@ -325,12 +325,12 @@ inline Tensor elemwise_sum(const Array<Tensor>& xs,
 *
 * \return A Tensor whose op member is the full operation
 */
-inline Tensor full(const Array<Expr>& shape,
+inline Tensor full(const Array<PrimExpr>& shape,
                    DataType dtype,
-                   const Expr fill_value,
+                   const PrimExpr fill_value,
                    std::string name = "T_full",
                    std::string tag = kElementWise) {
-  Expr ev = cast(dtype, fill_value);
+  PrimExpr ev = cast(dtype, fill_value);
   if (!ev.defined()) {
     LOG(ERROR) << "Can't cast fill_value to " << dtype;
   }
@@ -351,10 +351,10 @@ inline Tensor full(const Array<Expr>& shape,
 * \return A Tensor whose op memeber is the full_like operation
 */
 inline Tensor full_like(const Tensor& x,
-                        const Expr fill_value,
+                        const PrimExpr fill_value,
                         std::string name = "T_full_like",
                         std::string tag = kElementWise) {
-  Expr ev = cast(x->dtype, fill_value);
+  PrimExpr ev = cast(x->dtype, fill_value);
   return compute(x->shape, [&](const Array<Var>& i) {
       return ev;
   }, name, tag);
index 5f58562..f169ec9 100644 (file)
@@ -50,8 +50,8 @@ using namespace tvm;
  *
  * \return The interpolated value in the given index.
  */
-inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices,
-                                 const Expr max_y, const Expr max_x) {
+inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
+                                 const PrimExpr max_y, const PrimExpr max_x) {
   auto in_y = indices[2];
   auto yf = tvm::floor(in_y);
   auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y));
@@ -91,11 +91,11 @@ inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
-                                           const Array<Expr>& shape,
+                                           const Array<PrimExpr>& shape,
                                            bool align_corners = false,
                                            std::string name = "tensor",
                                            std::string tag = kInjective) {
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(cast(DataType::Int(32), shape[0]));
   out_shape.push_back(cast(DataType::Int(32), shape[1]));
@@ -103,7 +103,7 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
-    Array<Expr> idx;
+    Array<PrimExpr> idx;
     idx.push_back(indices[0]);
     idx.push_back(indices[1] * input->shape[1] / shape[0]);
     idx.push_back(indices[2] * input->shape[2] / shape[1]);
@@ -125,11 +125,11 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
-                                           const Array<Expr>& shape,
+                                           const Array<PrimExpr>& shape,
                                            bool align_corners = false,
                                            std::string name = "tensor",
                                            std::string tag = kInjective) {
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
   out_shape.push_back(cast(DataType::Int(32), shape[0]));
@@ -137,7 +137,7 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
-    Array<Expr> idx;
+    Array<PrimExpr> idx;
     idx.push_back(indices[0]);
     idx.push_back(indices[1]);
     idx.push_back(indices[2] * input->shape[2] / shape[0]);
@@ -159,11 +159,11 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
-                                            const Array<Expr>& shape,
+                                            const Array<PrimExpr>& shape,
                                             bool align_corners = false,
                                             std::string name = "tensor",
                                             std::string tag = kInjective) {
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
   out_shape.push_back(cast(DataType::Int(32), shape[0]));
@@ -172,7 +172,7 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
-    Array<Expr> idx;
+    Array<PrimExpr> idx;
     idx.push_back(indices[0]);
     idx.push_back(indices[1]);
     idx.push_back(indices[2] * input->shape[2] / shape[0]);
@@ -196,7 +196,7 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_nearest_neighbor(const Tensor& input,
-                                      const Array<Expr>& shape,
+                                      const Array<PrimExpr>& shape,
                                       std::string layout = "NCHW",
                                       bool align_corners = false,
                                       std::string name = "tensor",
@@ -227,25 +227,25 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_bilinear_nhwc(const Tensor& input,
-                                   const Array<Expr>& shape,
+                                   const Array<PrimExpr>& shape,
                                    bool align_corners = false,
                                    std::string name = "tensor",
                                    std::string tag = kInjective) {
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(cast(DataType::Int(32), shape[0]));
   out_shape.push_back(cast(DataType::Int(32), shape[1]));
   out_shape.push_back(input->shape[3]);
 
-  Expr cone = make_const(DataType::Int(32), 1);
+  PrimExpr cone = make_const(DataType::Int(32), 1);
 
   auto in_height = as_const_int(input->shape[1]);
   auto in_width = as_const_int(input->shape[2]);
   auto out_height = as_const_int(shape[0]);
   auto out_width = as_const_int(shape[1]);
 
-  Expr y_ratio;
-  Expr x_ratio;
+  PrimExpr y_ratio;
+  PrimExpr x_ratio;
 
   if (!align_corners) {
     y_ratio = make_const(DataType::Float(32), (static_cast<float>(*in_height) /
@@ -259,8 +259,8 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
                                      static_cast<float>(*out_width - 1)));
   }
 
-  Expr other_y = tvm::ir::Simplify(input->shape[1] - cone);
-  Expr other_x = tvm::ir::Simplify(input->shape[2] - cone);
+  PrimExpr other_y = tvm::ir::Simplify(input->shape[1] - cone);
+  PrimExpr other_x = tvm::ir::Simplify(input->shape[2] - cone);
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
@@ -304,25 +304,25 @@ inline Tensor resize_bilinear_nhwc(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_bilinear_nchw(const Tensor& input,
-                                   const Array<Expr>& shape,
+                                   const Array<PrimExpr>& shape,
                                    bool align_corners = false,
                                    std::string name = "tensor",
                                    std::string tag = kInjective) {
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   out_shape.push_back(input->shape[0]);
   out_shape.push_back(input->shape[1]);
   out_shape.push_back(cast(DataType::Int(32), shape[0]));
   out_shape.push_back(cast(DataType::Int(32), shape[1]));
 
-  Expr cone = make_const(DataType::Int(32), 1);
+  PrimExpr cone = make_const(DataType::Int(32), 1);
 
   auto in_height = as_const_int(input->shape[2]);
   auto in_width = as_const_int(input->shape[3]);
   auto out_height = as_const_int(shape[0]);
   auto out_width = as_const_int(shape[1]);
 
-  Expr y_ratio;
-  Expr x_ratio;
+  PrimExpr y_ratio;
+  PrimExpr x_ratio;
 
   if (!align_corners) {
     y_ratio = make_const(DataType::Float(32), (static_cast<float>(*in_height) /
@@ -336,8 +336,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
                                      static_cast<float>(*out_width - 1)));
   }
 
-  Expr other_y = tvm::ir::Simplify(input->shape[2] - cone);
-  Expr other_x = tvm::ir::Simplify(input->shape[3] - cone);
+  PrimExpr other_y = tvm::ir::Simplify(input->shape[2] - cone);
+  PrimExpr other_x = tvm::ir::Simplify(input->shape[3] - cone);
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
@@ -360,7 +360,7 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize_bilinear(const Tensor& input,
-                              const Array<tvm::Expr>& shape,
+                              const Array<tvm::PrimExpr>& shape,
                               std::string layout = "NCHW",
                               bool align_corners = false,
                               std::string name = "tensor",
@@ -390,7 +390,7 @@ inline Tensor resize_bilinear(const Tensor& input,
 * \return A Tensor resized to given shape
 */
 inline Tensor resize(const Tensor& input,
-                     const Array<Expr>& shape,
+                     const Array<PrimExpr>& shape,
                      std::string layout = "NCHW",
                      bool align_corners = false,
                      std::string mode = "BILINEAR",
index 5920c0b..3f65c75 100644 (file)
@@ -39,9 +39,9 @@ using namespace tvm;
 namespace detail {
 
 template <typename T>
-tvm::Expr Map(const tvm::Array<tvm::Expr>& exprs, T op) {
+tvm::PrimExpr Map(const tvm::Array<tvm::PrimExpr>& exprs, T op) {
   CHECK_GE(exprs.size(), 1);
-  tvm::Expr res = exprs[0];
+  tvm::PrimExpr res = exprs[0];
   for (size_t i = 1; i < exprs.size(); ++i) {
     res = op(res, exprs[i]);
   }
@@ -172,9 +172,9 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
  *
  */
 inline tvm::Tensor pad(const tvm::Tensor& t,
-                       const tvm::Array<tvm::Expr>& pad_before,
-                       tvm::Array<tvm::Expr> pad_after = tvm::Array<tvm::Expr>(),
-                       Expr pad_value = Expr(),
+                       const tvm::Array<tvm::PrimExpr>& pad_before,
+                       tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
+                       PrimExpr pad_value = PrimExpr(),
                        std::string name = "T_pad",
                        std::string tag = kElementWise,
                        std::string pad_mode = "constant") {
@@ -185,9 +185,9 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
   }
   CHECK_GE(pad_before.size(), 1);
   CHECK_EQ(pad_before.size(), pad_after.size());
-  tvm::Array<tvm::Expr> output_shape;
-  tvm::Array<tvm::Expr> pad_before_int32;
-  tvm::Array<tvm::Expr> pad_after_int32;
+  tvm::Array<tvm::PrimExpr> output_shape;
+  tvm::Array<tvm::PrimExpr> pad_before_int32;
+  tvm::Array<tvm::PrimExpr> pad_after_int32;
   for (const auto &ele : pad_before) {
     pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
   }
@@ -207,9 +207,9 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
     pad_value = tvm::make_const(t->dtype, 0);
   }
   auto l = [&](tvm::Array<tvm::Var> ovars) {
-    tvm::Array<tvm::Expr> indices;
-    tvm::Array<tvm::Expr> sel;
-    tvm::Array<tvm::Expr> pad_idx;
+    tvm::Array<tvm::PrimExpr> indices;
+    tvm::Array<tvm::PrimExpr> sel;
+    tvm::Array<tvm::PrimExpr> pad_idx;
     for (size_t i = 0; i < t->shape.size(); ++i) {
       if (i >= pad_before_int32.size()) {
         indices.push_back(ovars[i]);
@@ -286,7 +286,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
   CHECK_EQ(4, W->shape.size());
   auto pH = I->shape[2];
   auto pW = I->shape[3];
-  tvm::Array<tvm::Expr> output_shape{
+  tvm::Array<tvm::PrimExpr> output_shape{
     I->shape[0],                                            // B
     W->shape[0],                                            // O
     indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
@@ -297,7 +297,7 @@ inline tvm::Tensor conv2d_nchw(const tvm::Tensor& I,
   auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
-               : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
+               : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
   auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
     return tvm::sum(
         T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw),
@@ -337,7 +337,7 @@ inline tvm::Tensor conv2d_hwcn(const tvm::Tensor& I,
   CHECK_EQ(4, W->shape.size());
   auto pH = I->shape[2];
   auto pW = I->shape[3];
-  tvm::Array<tvm::Expr> output_shape{
+  tvm::Array<tvm::PrimExpr> output_shape{
       indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
       indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1,  // W
       I->shape[2],                                             // B
@@ -389,7 +389,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
   auto pH = I->shape[2];
   auto pW = I->shape[3];
   auto pCM = W->shape[1];  // channel_multiplier
-  tvm::Array<tvm::Expr> output_shape{
+  tvm::Array<tvm::PrimExpr> output_shape{
       I->shape[0],                                            // B
       W->shape[1],                                            // O
       indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
@@ -400,7 +400,7 @@ inline tvm::Tensor depthwise_conv2d_nchw(const tvm::Tensor& I,
   auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
-               : pad(I, {tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
+               : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
   auto l = [&](tvm::Var b, tvm::Var o, tvm::Var h, tvm::Var w) {
     return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
                     W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
@@ -422,7 +422,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
   auto pH = I->shape[1];
   auto pW = I->shape[2];
   auto pCM = W->shape[1];  // channel_multiplier
-  tvm::Array<tvm::Expr> output_shape{
+  tvm::Array<tvm::PrimExpr> output_shape{
       I->shape[0],                                            // B
       indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,  // H
       indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,   // W
@@ -433,7 +433,7 @@ inline tvm::Tensor depthwise_conv2d_nhwc(const tvm::Tensor& I,
   auto kw = tvm::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
-               : pad(I, {tvm::Expr(0), pad_h, pad_w, tvm::Expr(0)});
+               : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
   auto l = [&](tvm::Var b, tvm::Var h, tvm::Var w, tvm::Var o) {
     return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
                     W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
@@ -474,7 +474,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
   CHECK_EQ(5, W->shape.size());
   auto pH = I->shape[2];
   auto pW = I->shape[3];
-  tvm::Array<tvm::Expr> output_shape{
+  tvm::Array<tvm::PrimExpr> output_shape{
       I->shape[0],                                            // B
       I->shape[1],                                            // G
       W->shape[2],                                            // O
@@ -487,7 +487,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
 
   auto T = (pad_h == 0 && pad_w == 0)
                ? I
-               : pad(I, {tvm::Expr(0), tvm::Expr(0), tvm::Expr(0), pad_h, pad_w});
+               : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
   auto l = [&](tvm::Array<tvm::Var> args) {
     tvm::Var b = args[0];
     tvm::Var g = args[1];
index 6b79db3..e2af3ae 100644 (file)
@@ -55,7 +55,7 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
     << "binarize_pack: axis size must be a multiple of 32";
 
   auto n = ishape.size();
-  Array<Expr> oshape;
+  Array<PrimExpr> oshape;
   for (size_t i = 0; i < n; ++i) {
     oshape.push_back(i == static_cast<size_t>(axis) ?
                      tvm::ir::Simplify(indexdiv(ishape[i], 32)) :
@@ -65,15 +65,15 @@ inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
   return tvm::compute(
     oshape,
     [&](const Array<Var>& indices) {
-      Array<Expr> start_idx;
+      Array<PrimExpr> start_idx;
       for (size_t i = 0; i < n; ++i) {
         start_idx.push_back(i == static_cast<size_t>(axis) ?
                             indices[i] * 32 :
-                            static_cast<Expr>(indices[i]));
+                            static_cast<PrimExpr>(indices[i]));
       }
       auto packed = make_const(DataType::UInt(32), 0);
       for (size_t j = 0; j < 32; ++j) {
-        Array<Expr> idx;
+        Array<PrimExpr> idx;
         for (size_t i = 0; i < n; ++i) {
           idx.push_back(i == static_cast<size_t>(axis) ?
                         start_idx[i] + static_cast<int>(j) :
index 1dc2c8d..334b170 100644 (file)
@@ -42,10 +42,10 @@ using namespace tvm;
 *
 * \return The logical conjunction expression
 */
-Expr all(Array<Expr> args) {
+PrimExpr all(Array<PrimExpr> args) {
   CHECK_GT(args.size(), 0) << "all requires at least one argument";
 
-  Expr ret = args[0];
+  PrimExpr ret = args[0];
   for (size_t i = 1; i < args.size(); ++i) {
     ret = ret && args[i];
   }
@@ -65,7 +65,7 @@ Expr all(Array<Expr> args) {
 * \return The output tensor.
 */
 inline Tensor dilate(const Tensor& x,
-                     Array<Expr> strides,
+                     Array<PrimExpr> strides,
                      std::string name = "tensor",
                      std::string tag = kInjective) {
   auto n = x->shape.size();
@@ -73,7 +73,7 @@ inline Tensor dilate(const Tensor& x,
     << "strides size (" << strides.size()
     << ") must match dimension of x (" << n << ")";
 
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   for (size_t i = 0; i < n; ++i) {
     out_shape.push_back(tvm::ir::Simplify(
       (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
@@ -82,8 +82,8 @@ inline Tensor dilate(const Tensor& x,
   return tvm::compute(
     out_shape,
     [&](const Array<Var>& indices) {
-      Array<Expr> not_zero;
-      Array<Expr> index_tuple;
+      Array<PrimExpr> not_zero;
+      Array<PrimExpr> index_tuple;
       for (size_t i = 0; i < n; ++i) {
         if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
           index_tuple.push_back(indices[i]);
index e331b97..6b542f7 100644 (file)
@@ -51,14 +51,14 @@ inline Tensor flatten(const Tensor& x,
                       std::string name = "tensor",
                       std::string tag = kInjective) {
   auto ishape = x->shape;
-  Expr dim = 1;
+  PrimExpr dim = 1;
   for (size_t i = 1; i < ishape.size(); ++i) {
     dim = dim * ishape[i];
   }
 
-  Array<Expr> oshape({ ishape[0], dim });
+  Array<PrimExpr> oshape({ ishape[0], dim });
 
-  std::vector<Expr> extra_shape;
+  std::vector<PrimExpr> extra_shape;
   for (size_t i = 1; i < ishape.size(); ++i) {
     extra_shape.push_back(ishape[i]);
   }
@@ -66,8 +66,8 @@ inline Tensor flatten(const Tensor& x,
 
   return tvm::compute(
     oshape, [&](Var i, Var j) {
-      Expr idx = j;
-      std::vector<Expr> index;
+      PrimExpr idx = j;
+      std::vector<PrimExpr> index;
       for (auto s : extra_shape) {
         index.push_back(indexmod(idx, s));
         idx = indexdiv(idx, s);
index 2490b37..0cce997 100644 (file)
@@ -59,10 +59,10 @@ inline Tensor lrn(const Tensor& data,
   CHECK_EQ(size % 2, 1) << "size should be odd number";
   CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
   auto input_shape = data->shape;
-  Array<Expr> pad_before{ 0, 0, 0, 0};
-  Array<Expr> pad_after{ 0, 0, 0, 0};
-  pad_before.Set(axis, static_cast<Expr>(size/2));
-  pad_after.Set(axis, static_cast<Expr>(size/2));
+  Array<PrimExpr> pad_before{ 0, 0, 0, 0};
+  Array<PrimExpr> pad_after{ 0, 0, 0, 0};
+  pad_before.Set(axis, static_cast<PrimExpr>(size/2));
+  pad_after.Set(axis, static_cast<PrimExpr>(size/2));
   auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
   auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
   Tensor sqr_sum;
index 35bcd53..a074ee1 100644 (file)
@@ -61,9 +61,9 @@ enum PoolType : int {
 * \return The output tensor in same layout order
 */
 inline Tensor pool_impl(const Tensor& x,
-                        const Array<Expr>& kernel_size,
-                        const Array<Expr>& stride_size,
-                        const Array<Expr>& padding_size,
+                        const Array<PrimExpr>& kernel_size,
+                        const Array<PrimExpr>& stride_size,
+                        const Array<PrimExpr>& padding_size,
                         PoolType pool_type,
                         bool ceil_mode,
                         const size_t height_axis,
@@ -94,11 +94,11 @@ inline Tensor pool_impl(const Tensor& x,
     pad_right += stride_width - 1;
   }
 
-  Array<Expr> pad_before(std::vector<Expr>(x->shape.size(), 0));
+  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
   pad_before.Set(height_axis, pad_top);
   pad_before.Set(width_axis, pad_left);
 
-  Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
+  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
   pad_after.Set(height_axis, pad_bottom);
   pad_after.Set(width_axis, pad_right);
 
@@ -110,7 +110,7 @@ inline Tensor pool_impl(const Tensor& x,
   auto dheight = tvm::reduce_axis(Range(0, kernel_height));
   auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
 
-  Array<Expr> out_shape = x->shape;
+  Array<PrimExpr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
 
@@ -125,7 +125,7 @@ inline Tensor pool_impl(const Tensor& x,
     auto temp = do_pad ? pad(
         x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
     return tvm::compute(out_shape, [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       indices.Set(height_axis, output[height_axis] * stride_height + dheight);
       indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
@@ -138,7 +138,7 @@ inline Tensor pool_impl(const Tensor& x,
     // TVM compute for summing the pooling window.
     auto pool_sum = tvm::compute(out_shape,
     [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       indices.Set(height_axis, output[height_axis] * stride_height + dheight);
       indices.Set(width_axis, output[width_axis] * stride_width + dwidth);
@@ -148,18 +148,18 @@ inline Tensor pool_impl(const Tensor& x,
     // TVM compute for dividing the reduced window sum by kernel size.
     return tvm::compute(out_shape,
     [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       if (count_include_pad) {
         return div(pool_sum(indices), (kernel_height * kernel_width));
       } else {
-        Expr h_start = output[height_axis] * stride_height - pad_top;
-        Expr w_start = output[width_axis] * stride_width - pad_left;
-        Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
-        Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+        PrimExpr h_start = output[height_axis] * stride_height - pad_top;
+        PrimExpr w_start = output[width_axis] * stride_width - pad_left;
+        PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height);
+        PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width);
         h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
         w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
-        Expr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+        PrimExpr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
                                            make_const(DataType::DataType::Int(32), 1));
         return div(pool_sum(indices), divide_factor);
       }
@@ -170,9 +170,12 @@ inline Tensor pool_impl(const Tensor& x,
   }
 }
 
-inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
-                             const Array<Expr>& kernel_size, const Array<Expr>& stride_size,
-                             const Array<Expr>& padding_size, PoolType pool_type, bool ceil_mode,
+inline Tensor pool_grad_impl(const Tensor& out_grad,
+                             const Tensor& x,
+                             const Array<PrimExpr>& kernel_size,
+                             const Array<PrimExpr>& stride_size,
+                             const Array<PrimExpr>& padding_size,
+                             PoolType pool_type, bool ceil_mode,
                              const size_t height_axis, const size_t width_axis,
                              bool count_include_pad) {
   CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
@@ -201,11 +204,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
     pad_right += stride_width - 1;
   }
 
-  Array<Expr> pad_before(std::vector<Expr>(x->shape.size(), 0));
+  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
   pad_before.Set(height_axis, pad_top);
   pad_before.Set(width_axis, pad_left);
 
-  Array<Expr> pad_after(std::vector<Expr>(x->shape.size(), 0));
+  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
   pad_after.Set(height_axis, pad_bottom);
   pad_after.Set(width_axis, pad_right);
 
@@ -217,7 +220,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto dheight = tvm::reduce_axis(Range(0, kernel_height));
   auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
 
-  Array<Expr> out_shape = x->shape;
+  Array<PrimExpr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
 
@@ -229,7 +232,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                       ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
 
   if (pool_type == kMaxPool) {
-    Array<Expr> ravel_shape{x->shape.begin(), x->shape.end()};
+    Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
     ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
     ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
 
@@ -243,7 +246,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
     auto mp_argmax =
         tvm::compute(out_shape,
                      [&](const Array<Var>& inds) {
-                       Array<Expr> window_inds{inds.begin(), inds.end()};
+                       Array<PrimExpr> window_inds{inds.begin(), inds.end()};
                        window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
                        window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
                        auto idx = detail::RavelIndex(window_inds, ravel_shape);
@@ -256,19 +259,19 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
     return tvm::compute(
         x->shape,
         [&](const Array<Var>& inds) {
-          Array<Expr> pad_inds {inds.begin(), inds.end()};
+          Array<PrimExpr> pad_inds {inds.begin(), inds.end()};
           pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
           pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
           auto idx = detail::RavelIndex(pad_inds, ravel_shape);
 
-          Array<Expr> out_idx {inds.begin(), inds.end()};
+          Array<PrimExpr> out_idx {inds.begin(), inds.end()};
           out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
           out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
 
-          Expr out_idx_lower_h = ir::SelectNode::make(
+          PrimExpr out_idx_lower_h = ir::SelectNode::make(
               pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
               (pad_inds[height_axis] - kernel_height) / stride_height + 1);
-          Expr out_idx_lower_w = ir::SelectNode::make(
+          PrimExpr out_idx_lower_w = ir::SelectNode::make(
               pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
               (pad_inds[width_axis] - kernel_width) / stride_width + 1);
 
@@ -287,29 +290,29 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
     return tvm::compute(
         x->shape,
         [&](const Array<Var>& inds) {
-          Expr pad_h_idx = inds[height_axis] + pad_top;
-          Expr pad_w_idx = inds[width_axis] + pad_left;
+          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
+          PrimExpr pad_w_idx = inds[width_axis] + pad_left;
 
           // output indices whose pooling windows cover current input element (can be out-of-bound)
-          Array<Expr> out_idx{inds.begin(), inds.end()};
+          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
           out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
           out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
 
-          Expr out_idx_lower_h = ir::SelectNode::make(
+          PrimExpr out_idx_lower_h = ir::SelectNode::make(
               pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
               (pad_h_idx - kernel_height) / stride_height + 1);
-          Expr out_idx_lower_w = ir::SelectNode::make(
+          PrimExpr out_idx_lower_w = ir::SelectNode::make(
               pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
               (pad_w_idx - kernel_width) / stride_width + 1);
 
-          Expr divide_factor;  // number of pooled elements
+          PrimExpr divide_factor;  // number of pooled elements
           if (count_include_pad) {
             divide_factor = kernel_height * kernel_width;
           } else {
-            Expr h_start = out_idx[height_axis] * stride_height - pad_top;
-            Expr w_start = out_idx[width_axis] * stride_width - pad_left;
-            Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
-            Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
+            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
+            PrimExpr h_end = ir::MinNode::make(h_start + kernel_height, height);
+            PrimExpr w_end = ir::MinNode::make(w_start + kernel_width, width);
             h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
             w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
             divide_factor =
@@ -412,9 +415,9 @@ inline bool find_width(const std::string& layout,
 * \return The output tensor in the same layout
 */
 inline Tensor pool(const Tensor& x,
-                   const Array<Expr>& kernel_size,
-                   const Array<Expr>& stride_size,
-                   const Array<Expr>& padding_size,
+                   const Array<PrimExpr>& kernel_size,
+                   const Array<PrimExpr>& stride_size,
+                   const Array<PrimExpr>& padding_size,
                    PoolType pool_type,
                    bool ceil_mode,
                    const std::string& layout = "NCHW",
@@ -457,8 +460,8 @@ inline Tensor pool(const Tensor& x,
  *
  * \return The output tensor in the same layout
  */
-inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<Expr>& kernel_size,
-                        const Array<Expr>& stride_size, const Array<Expr>& padding_size,
+inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
+                        const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                         PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
                         bool count_include_pad = true) {
   int height_axis = -1, width_axis = -1;
@@ -467,16 +470,16 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<Exp
                         height_axis, width_axis, count_include_pad);
 }
 
-inline Expr start_index(const Var& out_index,
-                        const Expr& odim,
-                        const Expr& idim) {
+inline PrimExpr start_index(const Var& out_index,
+                        const PrimExpr& odim,
+                        const PrimExpr& idim) {
   return indexdiv(out_index * idim, odim);
 }
 
-inline Expr end_index(const Var& out_index,
-                      const Expr& odim,
-                      const Expr& idim) {
-  Expr tmp = indexdiv((out_index + 1) * idim, odim);
+inline PrimExpr end_index(const Var& out_index,
+                      const PrimExpr& odim,
+                      const PrimExpr& idim) {
+  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
   return tvm::ir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
                                tmp, tmp + 1);
 }
@@ -493,7 +496,7 @@ inline Expr end_index(const Var& out_index,
 * \return The output tensor in same layout order
 */
 inline Tensor adaptive_pool_impl(const Tensor& x,
-                                 const Array<Expr>& output_size,
+                                 const Array<PrimExpr>& output_size,
                                  PoolType pool_type,
                                  const size_t height_axis,
                                  const size_t width_axis) {
@@ -504,13 +507,13 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
 
   auto out_height = cast(DataType::Int(32), output_size[0]);
   auto out_width = cast(DataType::Int(32), output_size[1]);
-  Array<Expr> out_shape = x->shape;
+  Array<PrimExpr> out_shape = x->shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
 
   if (pool_type == kMaxPool) {
     return tvm::compute(out_shape, [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
       auto i_end_h = end_index(output[height_axis], out_height, height);
@@ -524,7 +527,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
     }, "tensor", "adaptive_pool_max");
   } else if (pool_type == kAvgPool) {
     auto pool_sum = tvm::compute(out_shape, [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
       auto i_end_h = end_index(output[height_axis], out_height, height);
@@ -540,7 +543,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
     }, "tensor", "adaptive_pool_sum");
 
     return tvm::compute(out_shape, [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       auto i_start_h = start_index(output[height_axis], out_height, height);
       auto i_end_h = end_index(output[height_axis], out_height, height);
@@ -583,7 +586,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x,
 * \return The output tensor in same layout order
 */
 inline Tensor adaptive_pool(const Tensor& x,
-                            const Array<Expr>& output_size,
+                            const Array<PrimExpr>& output_size,
                             PoolType pool_type,
                             const std::string& layout = "NCHW") {
   int height_axis = -1, width_axis = -1;
@@ -620,7 +623,7 @@ inline Tensor adaptive_pool(const Tensor& x,
 inline Tensor global_pool(const Tensor& x,
                           PoolType pool_type,
                           const std::string& layout = "NCHW") {
-  return adaptive_pool(x, Array<Expr>{1, 1}, pool_type, layout);
+  return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
 }
 
 /*!
@@ -639,9 +642,9 @@ inline Tensor global_pool(const Tensor& x,
 * \return The output tensor in same layout order
 */
 inline Tensor pool_impl_nd(const Tensor& x,
-                           const Array<Expr>& kernel_size,
-                           const Array<Expr>& stride_size,
-                           const Array<Expr>& padding_size,
+                           const Array<PrimExpr>& kernel_size,
+                           const Array<PrimExpr>& stride_size,
+                           const Array<PrimExpr>& padding_size,
                            PoolType pool_type,
                            bool ceil_mode,
                            const std::vector<int>& axis,
@@ -654,13 +657,13 @@ inline Tensor pool_impl_nd(const Tensor& x,
   CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";
 
   Array<IterVar> daxis;
-  std::vector<Expr> kernel(k_size);
-  std::vector<Expr> stride(k_size);
-  std::vector<Expr> pad_head(k_size);
-  std::vector<Expr> pad_tail(k_size);
-  Array<Expr> pad_before(std::vector<Expr>(x_size, 0));
-  Array<Expr> pad_after(std::vector<Expr>(x_size, 0));
-  Array<Expr> out_shape = x->shape;
+  std::vector<PrimExpr> kernel(k_size);
+  std::vector<PrimExpr> stride(k_size);
+  std::vector<PrimExpr> pad_head(k_size);
+  std::vector<PrimExpr> pad_tail(k_size);
+  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
+  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
+  Array<PrimExpr> out_shape = x->shape;
 
   bool do_pad = false;
   for (int i = 0; i < k_size; i++) {
@@ -694,7 +697,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
     auto temp = do_pad ? pad(
         x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
     return tvm::compute(out_shape, [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
 
       for (int i = 0; i < k_size; i++) {
@@ -711,7 +714,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
     // TVM compute for summing the pooling window.
     auto pool_sum = tvm::compute(out_shape,
     [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
 
       for (int i = 0; i < k_size; i++) {
@@ -724,7 +727,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
     // TVM compute for dividing the reduced window sum by kernel size.
     return tvm::compute(out_shape,
     [&](const Array<Var>& output) {
-      Array<Expr> indices;
+      Array<PrimExpr> indices;
       for (const Var& var : output) indices.push_back(var);
       if (count_include_pad) {
         auto kernel_size = make_const(DataType::Int(32), 1);
@@ -733,8 +736,8 @@ inline Tensor pool_impl_nd(const Tensor& x,
         }
         return div(pool_sum(indices), kernel_size);
       } else {
-        std::vector<Expr> start(k_size);
-        std::vector<Expr> end(k_size);
+        std::vector<PrimExpr> start(k_size);
+        std::vector<PrimExpr> end(k_size);
         auto kernel_size = make_const(DataType::Int(32), 1);
         for (int i = 0; i < k_size; i++) {
           int ii = axis[i];
@@ -744,7 +747,7 @@ inline Tensor pool_impl_nd(const Tensor& x,
           kernel_size *= (end[i] - start[i]);
         }
 
-        Expr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
+        PrimExpr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
         return div(pool_sum(indices), divide_factor);
       }
     }, "tensor", kElementWise);
@@ -784,9 +787,9 @@ inline Tensor pool_impl_nd(const Tensor& x,
 * \return The output tensor in the same layout
 */
 inline Tensor pool1d(const Tensor& x,
-                     const Array<Expr>& kernel_size,
-                     const Array<Expr>& stride_size,
-                     const Array<Expr>& padding_size,
+                     const Array<PrimExpr>& kernel_size,
+                     const Array<PrimExpr>& stride_size,
+                     const Array<PrimExpr>& padding_size,
                      PoolType pool_type,
                      bool ceil_mode,
                      const std::string& layout = "NCW",
@@ -830,9 +833,9 @@ inline Tensor pool1d(const Tensor& x,
 * \return The output tensor in the same layout
 */
 inline Tensor pool3d(const Tensor& x,
-                     const Array<Expr>& kernel_size,
-                     const Array<Expr>& stride_size,
-                     const Array<Expr>& padding_size,
+                     const Array<PrimExpr>& kernel_size,
+                     const Array<PrimExpr>& stride_size,
+                     const Array<PrimExpr>& padding_size,
                      PoolType pool_type,
                      bool ceil_mode,
                      const std::string& layout = "NCDHW",
index c3124bb..58ecc95 100644 (file)
@@ -66,7 +66,7 @@ inline Tensor softmax(const Tensor &x,
 
   auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
                                           const IterVar &reduce_index) {
-    Array<Expr> eval_range;
+    Array<PrimExpr> eval_range;
     int arg_counter = 0;
     for (size_t i = 0; i < ndim; ++i) {
       if (static_cast<int>(i) == axis)
@@ -78,7 +78,7 @@ inline Tensor softmax(const Tensor &x,
   };
 
   auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
-    Array<Expr> non_reduce_indices;
+    Array<PrimExpr> non_reduce_indices;
     for (size_t i = 0; i < ndim; ++i) {
       if (static_cast<int>(i) != axis)
         non_reduce_indices.push_back(indices[i]);
@@ -135,8 +135,8 @@ inline Tensor log_softmax(const Tensor& x,
                           std::string tag = "log_softmax_output") {
   CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";
 
-  Expr m = x->shape[0];
-  Expr n = x->shape[1];
+  PrimExpr m = x->shape[0];
+  PrimExpr n = x->shape[1];
 
   auto k = tvm::reduce_axis(Range(0, n), "k");
   auto max_elem = tvm::compute(
index 3ca0854..b6230c7 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -49,7 +49,7 @@ using namespace topi::image;
 * \return A Tensor upsampled to given shape
 */
 inline Tensor upsampling(const Tensor& input,
-                         const Array<Expr> shape,
+                         const Array<PrimExpr> shape,
                          std::string layout = "NCHW",
                          std::string mode = "NEAREST_NEIGHBOR",
                          std::string name = "tensor",
index 2d3d7d3..ac843b1 100644 (file)
@@ -43,11 +43,11 @@ namespace topi {
 using namespace tvm;
 
 /*! \brief The operation to use for CommReduce */
-using FReduce = std::function<Expr(Expr source, const Array<IterVar>& axis)>;
+using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
 
 /*! \brief The operation to use for CommReduceIdx */
 using FCommReduce = std::function<
-  Array<Expr>(Array<Expr> exprs, const Array<IterVar>& axis, Expr* condition)>;
+  Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis, PrimExpr* condition)>;
 
 /*!
 * \brief Convert a reduction axis which could be empty or have negative
@@ -97,12 +97,12 @@ inline Array<IterVar> MakeReduceAxes(const std::vector<int>& real_axis, const Te
 }
 
 /*! \brief Calculate the target shape for a reduce op */
-inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
+inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis,
                                          const Tensor& data,
                                          bool keepdims,
                                          bool atleast1d) {
   auto ndim = data->shape.size();
-  Array<Expr> target_shape;
+  Array<PrimExpr> target_shape;
   if (keepdims) {
     for (size_t i = 0; i < ndim; ++i) {
       if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
@@ -140,12 +140,12 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
  */
 inline Tensor DoCommReduce(const Tensor& data,
                            FReduce func,
-                           const Array<Expr>& target_shape,
+                           const Array<PrimExpr>& target_shape,
                            const std::vector<int>& reduce_axes,
                            const std::vector<int>& squeeze_axes) {
   auto r_axes = MakeReduceAxes(reduce_axes, data);
   auto compute = [&](const Array<Var>& indices) {
-    Array<Expr> eval_range;
+    Array<PrimExpr> eval_range;
     Array<Var> eval_indices;
     int arg_counter = 0;
     int red_counter = 0;
@@ -222,8 +222,8 @@ inline Tensor CommReduceIdx(const Tensor& data,
 
   auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data]
   (const Array<Var>& indices) {
-    Array<Expr> eval_range;
-    Array<Expr> eval_indices;
+    Array<PrimExpr> eval_range;
+    Array<PrimExpr> eval_indices;
     int arg_counter = 0;
     int red_counter = 0;
 
@@ -243,7 +243,7 @@ inline Tensor CommReduceIdx(const Tensor& data,
       }
     }
 
-    Array<Expr> ravel_shape;
+    Array<PrimExpr> ravel_shape;
     for (auto i : real_axis) {
       ravel_shape.push_back(data->shape[i]);
     }
@@ -263,10 +263,10 @@ inline Tensor CommReduceIdx(const Tensor& data,
 }
 
 /*! \brief A combiner function for a reduction */
-using FCombine = std::function<Array<Expr>(Array<Var> lhs, Array<Var> rhs)>;
+using FCombine = std::function<Array<PrimExpr>(Array<Var> lhs, Array<Var> rhs)>;
 
 /*! \brief An initializer function for a reduction */
-using FIdentity = std::function<Array<Expr>(std::vector<DataType> types)>;
+using FIdentity = std::function<Array<PrimExpr>(std::vector<DataType> types)>;
 
 /*!
  * \brief Create a commutative reducer for a reduction
@@ -281,7 +281,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine,
                                    FIdentity fidentity,
                                    std::string name = "reduce") {
   return [fcombine, fidentity, name]
-  (Array<Expr> exprs, const Array<IterVar>& axis, Expr* condition) {
+  (Array<PrimExpr> exprs, const Array<IterVar>& axis, PrimExpr* condition) {
     Array<Var> lhs, rhs;
     std::vector<DataType> dtypes;
 
@@ -297,7 +297,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine,
     auto cond = condition != nullptr ? *condition : tvm::const_true();
 
     auto combiner = tvm::ir::CommReducerNode::make(lhs, rhs, result, id_elem);
-    Array<Expr> outputs;
+    Array<PrimExpr> outputs;
     for (size_t i = 0; i < exprs.size(); ++i) {
       outputs.push_back(
         tvm::ir::ReduceNode::make(combiner, exprs, axis, cond, static_cast<int>(i)));
@@ -307,17 +307,17 @@ inline FCommReduce MakeCommReducer(FCombine fcombine,
 }
 
 /*! \brief Wrap tvm::min to ensure we get the correct overload */
-inline Expr MinOp(Expr source, Array<IterVar> axis) {
+inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) {
   return tvm::min(source, axis);
 }
 
 /*! \brief Wrap tvm::max to ensure we get the correct overload */
-inline Expr MaxOp(Expr source, Array<IterVar> axis) {
+inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis) {
   return tvm::max(source, axis);  // NOLINT(*)
 }
 
 /*! \brief Wrap tvm::prod to ensure we get the correct overload */
-inline Expr ProdOp(Expr source, Array<IterVar> axis) {
+inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis) {
   return tvm::prod(source, axis);  // NOLINT(*)
 }
 
@@ -341,7 +341,7 @@ inline Tensor sum(const Tensor& data,
   return CommReduce(data, axis, tvm::sum, keepdims, atleast1d);
 }
 
-inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
+inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
   CHECK_GE(data->shape.size(), target_shape.size());
   auto ishape = detail::GetConstIntValues(data->shape, "ishape");
   auto oshape = detail::GetConstIntValues(target_shape, "oshape");
@@ -472,13 +472,13 @@ inline Tensor argmin(const Tensor& data,
                      bool keepdims = false,
                      bool atleast1d = false) {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
-    Array<Expr> result;
+    Array<PrimExpr> result;
     result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
     result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
     return result;
   };
   auto fidentity = [](std::vector<DataType> types) {
-    Array<Expr> result;
+    Array<PrimExpr> result;
     result.push_back(tvm::make_const(types[0], -1));  // idx
     result.push_back(tvm::max_value(types[1]));  // val
     return result;
@@ -489,13 +489,13 @@ inline Tensor argmin(const Tensor& data,
 
 inline FCommReduce MakeArgmaxReducer() {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
-    Array<Expr> result;
+    Array<PrimExpr> result;
     result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
     result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
     return result;
   };
   auto fidentity = [](std::vector<DataType> types) {
-    Array<Expr> result;
+    Array<PrimExpr> result;
     result.push_back(tvm::make_const(types[0], -1));  // idx
     result.push_back(tvm::min_value(types[1]));  // val
     return result;
index 00106c1..66e2773 100644 (file)
@@ -72,7 +72,7 @@ inline Tensor expand_dims(const Tensor& x,
     // Calculate offset from last dimension
     axis = ndim + axis + 1;
   }
-  Array<Expr> new_shape;
+  Array<PrimExpr> new_shape;
   for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
     new_shape.push_back(x->shape[i]);
   }
@@ -85,7 +85,7 @@ inline Tensor expand_dims(const Tensor& x,
 
   return compute(
     new_shape, [&](const Array<Var>& indices) {
-      Array<Expr> idx;
+      Array<PrimExpr> idx;
       for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
         idx.push_back(indices[i]);
       }
@@ -118,7 +118,7 @@ inline Tensor transpose(const Tensor& x,
     }
   }
 
-  Array<Expr> new_shape;
+  Array<PrimExpr> new_shape;
   for (size_t i = 0; i < axes.size(); ++i) {
     int axis = static_cast<int>(axes[i]->value);
     int new_axis = axis;
@@ -140,7 +140,7 @@ inline Tensor transpose(const Tensor& x,
 
   return compute(
     new_shape, [&](const Array<Var>& indices) {
-      std::vector<Expr> idx;
+      std::vector<PrimExpr> idx;
       for (size_t i = 0; i < axes.size(); ++i) {
         idx.push_back(1);
       }
@@ -181,7 +181,7 @@ inline Tensor flip(const Tensor& x,
   // Reverse the Input Tensor in the axis specified
   return compute(
     x->shape, [&](const Array<Var>& indices) {
-      Array<Expr> real_indices;
+      Array<PrimExpr> real_indices;
       for (size_t i = 0; i < src_tensor_dim; ++i) {
         if (i == static_cast<size_t>(axis)) {
           real_indices.push_back(x->shape[i] - indices[i] - 1);
@@ -204,11 +204,11 @@ inline Tensor flip(const Tensor& x,
 * \return A Tensor whose op member is the reshape operation
 */
 inline Tensor reshape(const Tensor& x,
-                      Array<Expr> newshape,
+                      Array<PrimExpr> newshape,
                       std::string name = "T_reshape",
                       std::string tag = kInjective) {
   auto x_shape = x->shape;
-  Array<Expr> target_shape;
+  Array<PrimExpr> target_shape;
 
   for (const auto &ele : newshape) {
     if (ele.as<IntImmNode>()) {
@@ -226,7 +226,7 @@ inline Tensor reshape(const Tensor& x,
     return compute(
       target_shape, [&](const Array<Var>& indices) {
         return x(UnravelIndex(
-          RavelIndex(Array<Expr>{indices.begin(), indices.end()}, target_shape),
+          RavelIndex(Array<PrimExpr>{indices.begin(), indices.end()}, target_shape),
           x_shape));
       }, name, tag);
   }
@@ -272,7 +272,7 @@ inline Tensor squeeze(const Tensor& x,
 
   std::unordered_set<int> axis_set(axis_val.begin(), axis_val.end());
 
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   for (size_t i = 0; i < ndim; ++i) {
     if (axis_set.count(static_cast<int>(i)) == 0) {
       out_shape.push_back(x->shape[i]);
@@ -284,7 +284,7 @@ inline Tensor squeeze(const Tensor& x,
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
-      Array<Expr> real_indices;
+      Array<PrimExpr> real_indices;
       int flag = 0;
       for (size_t i = 0; i < ndim; ++i) {
         if (axis_set.count(static_cast<int>(i)) == 0) {
@@ -323,17 +323,17 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
   CHECK_LT(axis, inputs[0]->shape.size()) <<
     "axis out of bounds";
 
-  Array<Expr> axis_sizes;
+  Array<PrimExpr> axis_sizes;
   for (auto t : inputs) {
     axis_sizes.push_back(t->shape[axis]);
   }
 
-  Expr join_size = axis_sizes[0];
+  PrimExpr join_size = axis_sizes[0];
   for (size_t i = 1; i < axis_sizes.size(); ++i) {
     join_size += axis_sizes[i];
   }
   join_size = tvm::ir::Simplify(join_size);
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   for (size_t i = 0; i < inputs[0]->shape.size(); ++i) {
     out_shape.push_back(i == static_cast<size_t>(axis) ? join_size : inputs[0]->shape[i]);
   }
@@ -345,7 +345,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
       for (size_t i = 0; i < inputs.size() - 1; ++i) {
         ind -= axis_sizes[i];
 
-        Array<Expr> idx;
+        Array<PrimExpr> idx;
         for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
           idx.push_back(indices[i]);
         }
@@ -388,7 +388,7 @@ inline Tensor stack(const Array<Tensor>& inputs,
     "axis out of bounds";
 
   const int stack_size = static_cast<int>(inputs.size());
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   for (size_t i = 0; i < static_cast<size_t>(axis); ++i)
     out_shape.push_back(inputs[0]->shape[i]);
   out_shape.push_back(stack_size);
@@ -397,7 +397,7 @@ inline Tensor stack(const Array<Tensor>& inputs,
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
-      Array<Expr> idx;
+      Array<PrimExpr> idx;
       for (size_t i = 0; i < indices.size(); ++i)
         if (i != static_cast<size_t>(axis))
           idx.push_back(indices[i]);
@@ -445,7 +445,7 @@ inline Array<Tensor> split(const Tensor& x,
     begin_ids.push_back(val);
   }
 
-  Array< Array<Expr> > out_shapes;
+  Array< Array<PrimExpr> > out_shapes;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
     int out_axis_size;
     if (i == begin_ids.size() - 1) {
@@ -454,7 +454,7 @@ inline Array<Tensor> split(const Tensor& x,
       out_axis_size = begin_ids[i + 1] - begin_ids[i];
     }
 
-    Array<Expr> shape;
+    Array<PrimExpr> shape;
     for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
       shape.push_back(x->shape[i]);
     }
@@ -472,7 +472,7 @@ inline Array<Tensor> split(const Tensor& x,
       compute(
         out_shapes[i], [&](const Array<Var>& indices) {
           auto begin = begin_ids[i];
-          Array<Expr> real_indices;
+          Array<PrimExpr> real_indices;
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(indices[j]);
           }
@@ -547,9 +547,9 @@ inline Tensor strided_slice(const Tensor& x,
     end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
   }
   // Compute
-  Array<Expr> out_shape;
-  Array<Expr> begin_expr;
-  Array<Expr> strides_expr;
+  Array<PrimExpr> out_shape;
+  Array<PrimExpr> begin_expr;
+  Array<PrimExpr> strides_expr;
 
   for (size_t i = 0; i < src_tensor_dim; ++i) {
     int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
@@ -581,7 +581,7 @@ inline Tensor strided_slice(const Tensor& x,
 
   return compute(
     out_shape, [&](const Array<Var>& indices) {
-      Array<Expr> real_indices;
+      Array<PrimExpr> real_indices;
       for (size_t i = 0; i < src_tensor_dim; ++i) {
         real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
       }
@@ -647,9 +647,9 @@ inline Tensor take(const Tensor& a,
                    std::string mode = "clip",
                    std::string name = "T_take",
                    std::string tag = kInjective) {
-  Array<Expr> a_shape = a->shape;
-  Array<Expr> out_shape = indices->shape;
-  Expr a_size = 1;
+  Array<PrimExpr> a_shape = a->shape;
+  Array<PrimExpr> out_shape = indices->shape;
+  PrimExpr a_size = 1;
   for (size_t i = 0; i < a_shape.size(); ++i) {
     a_size = a_size * a_shape[i];
   }
@@ -699,15 +699,16 @@ inline Tensor sequence_mask(const Tensor& data,
   CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,).";
   auto length_dim = data->shape[axis];
   auto batch_dim = data->shape[1 - axis];
-  Array<Expr> out_shape = data->shape;
+  Array<PrimExpr> out_shape = data->shape;
   Tensor out = compute(
       out_shape, [&](const Array<Var>& out_index) {
-        Array<Expr> len_index;
+        Array<PrimExpr> len_index;
         auto tid = out_index[axis];
         auto bid = out_index[1 - axis];
         len_index.push_back(bid);
-        Expr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
-                                     tvm::make_const(data->dtype, mask_value), data(out_index));
+        PrimExpr ret = tvm::if_then_else(
+            tvm::cast(valid_length->dtype, tid) >= valid_length(len_index),
+            tvm::make_const(data->dtype, mask_value), data(out_index));
         return ret;
       }, name, tag);
   return out;
@@ -740,7 +741,7 @@ inline Tensor take(const Tensor& a,
   auto axis_dim = a->shape[axis];
 
   int indices_len = static_cast<int>(indices->shape.size());
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   for (size_t i = 0; i < a->shape.size(); ++i) {
     if (axis == static_cast<int>(i)) {
       for (size_t j = 0; j < indices->shape.size(); ++j) {
@@ -753,11 +754,11 @@ inline Tensor take(const Tensor& a,
   if (mode == "clip") {
     return compute(
         out_shape, [&](const Array<Var>& out_index) {
-          Array<Expr> indices_position;
+          Array<PrimExpr> indices_position;
           for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
             indices_position.push_back(out_index[j]);
           }
-          Array<Expr> real_indices;
+          Array<PrimExpr> real_indices;
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
@@ -774,11 +775,11 @@ inline Tensor take(const Tensor& a,
                     "Make sure input indices are in bound";
     return compute(
         out_shape, [&](const Array<Var>& out_index) {
-          Array<Expr> indices_position;
+          Array<PrimExpr> indices_position;
           for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
             indices_position.push_back(out_index[j]);
           }
-          Array<Expr> real_indices;
+          Array<PrimExpr> real_indices;
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
@@ -791,11 +792,11 @@ inline Tensor take(const Tensor& a,
   } else {  // mode == "wrap"
     return compute(
         out_shape, [&](const Array<Var>& out_index) {
-          Array<Expr> indices_position;
+          Array<PrimExpr> indices_position;
           for (size_t j = axis; j < static_cast<size_t>(axis+indices_len); ++j) {
             indices_position.push_back(out_index[j]);
           }
-          Array<Expr> real_indices;
+          Array<PrimExpr> real_indices;
           for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
             real_indices.push_back(out_index[j]);
           }
@@ -830,7 +831,7 @@ inline Tensor where(const Tensor& condition,
     << x->shape.size() << " vs " << y->shape.size();
   CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: "
                                << x->dtype << " vs " << y->dtype;
-  Array<Expr> oshape = x->shape;
+  Array<PrimExpr> oshape = x->shape;
   Tensor out;
 
   if (condition->shape.size() != 1) {
@@ -848,7 +849,7 @@ inline Tensor where(const Tensor& condition,
       << condition->shape[0] << " vs " << x->shape[0];
     out = compute(
       oshape, [&](const Array<Var>& indices) {
-        Array<Expr> condition_idx{indices[0]};
+        Array<PrimExpr> condition_idx{indices[0]};
         return tvm::ir::SelectNode::make(condition(condition_idx) != 0,
                                      x(indices), y(indices));
       }, name, tag);
@@ -885,7 +886,7 @@ inline Tensor repeat(const Tensor& x,
     // Calculate offset from last dimension
     axis += ndim;
   }
-  Array<Expr> new_shape;
+  Array<PrimExpr> new_shape;
   for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
     new_shape.push_back(x->shape[i]);
   }
@@ -896,7 +897,7 @@ inline Tensor repeat(const Tensor& x,
 
   return compute(
     new_shape, [&](const Array<Var>& indices) {
-      Array<Expr> idx;
+      Array<PrimExpr> idx;
       for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
         idx.push_back(indices[i]);
       }
@@ -925,9 +926,9 @@ inline Tensor tile(const Tensor& x,
   size_t ndim = x->shape.size();
   size_t rdim = reps.size();
   size_t tdim = (ndim > rdim) ? ndim : rdim;
-  Array<Expr> data_shape;
-  Array<Expr> reps_shape;
-  Array<Expr> new_shape;
+  Array<PrimExpr> data_shape;
+  Array<PrimExpr> reps_shape;
+  Array<PrimExpr> new_shape;
   if (ndim == rdim) {
     for (size_t i = 0; i < ndim; ++i) {
       data_shape.push_back(x->shape[i]);
@@ -958,7 +959,7 @@ inline Tensor tile(const Tensor& x,
   } else {
     return compute(
       new_shape, [&](const Array<Var>& indices) {
-        Array<Expr> idx;
+        Array<PrimExpr> idx;
         if (ndim >= rdim) {
           for (size_t i = 0; i < ndim; ++i)
             idx.push_back(indexmod(indices[i], x->shape[i]));
@@ -991,7 +992,7 @@ inline Tensor gather_nd(const Tensor& data,
   size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
   CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
                                  << "than dimensions of data tensor";
-  Array<Expr> out_shape;
+  Array<PrimExpr> out_shape;
   for (size_t i = 1; i < ndim_i; ++i) {
     out_shape.push_back(indices->shape[i]);
   }
@@ -1003,12 +1004,12 @@ inline Tensor gather_nd(const Tensor& data,
   }
   return compute(
         out_shape, [&](const Array<Var>& out_index) {
-          Array<Expr> indices_position;
+          Array<PrimExpr> indices_position;
           indices_position.push_back(0);
           for (size_t i = 0; i < ndim_i - 1; ++i) {
             indices_position.push_back(out_index[i]);
           }
-          Array<Expr> real_indices;
+          Array<PrimExpr> real_indices;
           for (size_t i = 0; i < indices_dim0; ++i) {
             indices_position.Set(0, make_const(DataType::Int(32), i));
             if (indices->dtype.is_int()) {
@@ -1046,7 +1047,7 @@ inline tvm::Tensor matmul(const tvm::Tensor& A,
                            bool trans_b = false,
                            std::string name = "T_matmul",
                            std::string tag = kMatMul) {
-  tvm::Array<tvm::Expr> output_shape{A->shape[trans_a ? 1 : 0],
+  tvm::Array<tvm::PrimExpr> output_shape{A->shape[trans_a ? 1 : 0],
                                      B->shape[trans_b ? 0 : 1]};
   auto k = tvm::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k");
   auto l = [&](tvm::Var i, tvm::Var j) {
@@ -1075,7 +1076,7 @@ inline Tensor tensordot(const Tensor& A,
   CHECK_GE(A->shape.size(), axes);
   CHECK_GE(B->shape.size(), axes);
 
-  Array<Expr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
+  Array<PrimExpr> output_shape(A->shape.begin(), A->shape.end() + (-axes));
   for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it)
     output_shape.push_back(*it);
 
@@ -1086,13 +1087,13 @@ inline Tensor tensordot(const Tensor& A,
   auto func =
     [&A, &B, &iter_vars, axes]
     (const Array<Var>& input_indices) {
-      Array<Expr> A_indices(
+      Array<PrimExpr> A_indices(
           input_indices.begin(),
           input_indices.begin() + (A->shape.size() - axes));
       for (auto& v : iter_vars)
         A_indices.push_back(v);
 
-      Array<Expr> B_indices;
+      Array<PrimExpr> B_indices;
       for (auto& v : iter_vars)
         B_indices.push_back(v);
 
@@ -1124,8 +1125,8 @@ inline Tensor tensordot(const Tensor& A,
  */
 inline Tensor tensordot(const Tensor& A,
                         const tvm::Tensor& B,
-                        Array<Expr> A_axes,
-                        Array<Expr> B_axes,
+                        Array<PrimExpr> A_axes,
+                        Array<PrimExpr> B_axes,
                         std::string name = "T_tensordot",
                         std::string tag = kMatMul) {
   CHECK_EQ(A_axes.size(), B_axes.size());
@@ -1133,7 +1134,7 @@ inline Tensor tensordot(const Tensor& A,
   auto A_axes_val = GetConstIntValues(A_axes, "A_axes");
   auto B_axes_val = GetConstIntValues(B_axes, "B_axes");
 
-  Array<Expr> output_shape;
+  Array<PrimExpr> output_shape;
   for (unsigned i = 0; i < A->shape.size(); ++i)
     if (std::find(A_axes_val.begin(), A_axes_val.end(), i) == A_axes_val.end())
       output_shape.push_back(A->shape[i]);
@@ -1149,7 +1150,7 @@ inline Tensor tensordot(const Tensor& A,
     [&A, &B, &iter_vars, A_axes_val, B_axes_val]
     (const Array<Var>& input_indices) {
       int idx_input = 0;
-      Array<Expr> A_indices;
+      Array<PrimExpr> A_indices;
       for (unsigned i = 0; i < A->shape.size(); ++i) {
         auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i);
         if (axes_pos == A_axes_val.end())
@@ -1158,7 +1159,7 @@ inline Tensor tensordot(const Tensor& A,
           A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]);
       }
 
-      Array<Expr> B_indices;
+      Array<PrimExpr> B_indices;
       for (unsigned i = 0; i < B->shape.size(); ++i) {
         auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i);
         if (axes_pos == B_axes_val.end())
@@ -1171,15 +1172,15 @@ inline Tensor tensordot(const Tensor& A,
   return compute(output_shape, func, name, tag);
 }
 
-inline Tensor arange(const Expr& start,
-                     const Expr& stop,
-                     const Expr& step,
+inline Tensor arange(const PrimExpr& start,
+                     const PrimExpr& stop,
+                     const PrimExpr& step,
                      DataType dtype,
                      std::string name = "T_arange",
                      std::string tag = kInjective) {
-  Expr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil(
+  PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil(
       tvm::cast(tvm::DataType::Float(32), stop - start) / step));
-  Array<Expr> shape;
+  Array<PrimExpr> shape;
   return compute({num_elem}, [&](const Array<Var>& indices) {
     return tvm::cast(dtype, start + step * indices[0]);
   }, name, tag);
@@ -1213,12 +1214,12 @@ inline Tensor layout_transform(const Tensor& src,
   CHECK(layout_converter.defined())
     << "cannot convert from " << src_layout << " to " << dst_layout;
 
-  Array<Expr> dst_shape = layout_converter.ForwardShape(src->shape);
+  Array<PrimExpr> dst_shape = layout_converter.ForwardShape(src->shape);
 
   return compute(
     dst_shape, [&](const Array<Var>& dst_indices) {
-      Array<Expr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
-      Array<Expr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
+      Array<PrimExpr> dst_indices_expr(dst_indices.begin(), dst_indices.end());
+      Array<PrimExpr> src_indices = layout_converter.BackwardIndex(dst_indices_expr);
       return src(src_indices);
   }, name, tag);
 }
@@ -1236,10 +1237,10 @@ inline Tensor shape(const Tensor& src,
                     const std::string name = "T_shape",
                     const std::string tag = kInjective) {
   int ndim = static_cast<int>(src->shape.size());
-  Array<Expr> out_shape{ndim};
+  Array<PrimExpr> out_shape{ndim};
   return compute(out_shape, [&](const Array<Var>& indices) {
     auto idx = indices[0];
-    Expr ret = 0;
+    PrimExpr ret = 0;
     for (int i = 0; i < ndim; ++i) {
       ret = tvm::if_then_else(idx == i, src->shape[i], ret);
     }
@@ -1260,9 +1261,9 @@ inline Tensor ndarray_size(const Tensor& src,
                            const std::string& name = "ndarray_size",
                            const std::string& tag = kInjective) {
   int ndim = static_cast<int>(src->shape.size());
-  Array<Expr> out_ndarray_size = {1};
+  Array<PrimExpr> out_ndarray_size = {1};
   return compute(out_ndarray_size, [&](const Array<Var>& indices) {
-    Expr ret = 1;
+    PrimExpr ret = 1;
     for (int i = 0; i < ndim; ++i) {
       ret *= src->shape[i];
     }
@@ -1284,14 +1285,14 @@ inline Tensor ndarray_size(const Tensor& src,
  * \return one-hot tensor.
  */
 inline Tensor one_hot(const Tensor& indices,
-                      const Expr on_value,
-                      const Expr off_value,
+                      const PrimExpr on_value,
+                      const PrimExpr off_value,
                       int depth,
                       int axis,
                       const DataType& dtype,
                       const std::string name = "T_one_hot",
                       const std::string tag = kInjective) {
-  Array<Expr> oshape;
+  Array<PrimExpr> oshape;
   int ndim = indices->shape.size() + 1;
   int indices_index = 0;
   int true_axis = (axis == -1) ? indices->shape.size() : axis;
@@ -1303,8 +1304,8 @@ inline Tensor one_hot(const Tensor& indices,
     }
   }
 
-  Expr on_value_cast = cast(dtype, on_value);
-  Expr off_value_cast = cast(dtype, off_value);
+  PrimExpr on_value_cast = cast(dtype, on_value);
+  PrimExpr off_value_cast = cast(dtype, off_value);
   return compute(oshape, [&](const Array<Var>& iter_vars) {
     Array<Var> indices_indices;
     for (size_t i = 0; i < iter_vars.size(); i++) {
index 08dd136..df3fade 100644 (file)
@@ -74,7 +74,7 @@ inline Tensor reorg(const Tensor &data,
   int out_h = h_in / stride;
   int out_w = w_in / stride;
 
-  Array<Expr> out_shape = {batch, out_c, out_h, out_w};
+  Array<PrimExpr> out_shape = {batch, out_c, out_h, out_w};
   return reshape(out, out_shape);
 }
 }  // namespace vision
index acffcb5..13f8e72 100644 (file)
@@ -57,7 +57,7 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
     out_shape = tuple(
         tvm.ir_pass.Simplify(
             (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n))
-    pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr)
+    pad_value = (pad_value if isinstance(pad_value, tvm.expr.PrimExpr)
                  else tvm.const(pad_value, data.dtype))
     def _pad(*indices):
         not_zero = []
index 079dda5..8f32a29 100644 (file)
@@ -198,7 +198,7 @@ def simplify(expr):
     out : Expr or int
         The simplified output
     """
-    return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr
+    return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.PrimExpr) else expr
 
 
 def ravel_index(indices, shape):
index eeb1249..e9c9bc0 100644 (file)
@@ -112,14 +112,14 @@ TVM_REGISTER_GLOBAL("topi.TEST_create_target")
         *rv = Op(args[0].operator tvm::Tensor(),                        \
                  args[1].operator tvm::Tensor());                       \
       } else if (!lhs_is_tensor && rhs_is_tensor) {                     \
-        *rv = Op(args[0].operator tvm::Expr(),                          \
+        *rv = Op(args[0].operator tvm::PrimExpr(),                          \
                  args[1].operator tvm::Tensor());                       \
       } else if (lhs_is_tensor && !rhs_is_tensor) {                     \
         *rv = Op(args[0].operator tvm::Tensor(),                        \
-                 args[1].operator tvm::Expr());                         \
+                 args[1].operator tvm::PrimExpr());                         \
       } else if (!lhs_is_tensor && !rhs_is_tensor) {                    \
-        *rv = Op(args[0].operator tvm::Expr(),                          \
-                 args[1].operator tvm::Expr());                         \
+        *rv = Op(args[0].operator tvm::PrimExpr(),                          \
+                 args[1].operator tvm::PrimExpr());                         \
       }                                                                 \
     });                                                                 \
 
@@ -433,7 +433,7 @@ TVM_REGISTER_GLOBAL("topi.tensordot")
   } else if (args.size() == 3) {
     *rv = tensordot(args[0], args[1], args[2]);
   } else {
-    Array<Expr> axes = args[3];
+    Array<PrimExpr> axes = args[3];
     *rv = tensordot(args[0], args[1], args[2], axes);
   }
   });
index 8cf5efe..4361f8f 100644 (file)
@@ -58,8 +58,8 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
     B = (tvm.var("B", dtype=dtype) if rhs_shape is None
          else tvm.placeholder(shape=rhs_shape, name="B", dtype=dtype))
     C = ftopi(A, B)
-    if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
-        assert(isinstance(C, tvm.expr.Expr))
+    if isinstance(A, tvm.expr.PrimExpr) and isinstance(B, tvm.expr.PrimExpr):
+        assert(isinstance(C, tvm.expr.PrimExpr))
         return
 
     def gen_operand(shape, low, high, ctx):
@@ -241,8 +241,8 @@ def test_logical_single_ele():
         # Build the logic and compile the function
         A = tvm.placeholder(shape=indata.shape, name="A", dtype=dtype)
         B = func(A)
-        if isinstance(A, tvm.expr.Expr):
-            assert (isinstance(B, tvm.expr.Expr))
+        if isinstance(A, tvm.expr.PrimExpr):
+            assert (isinstance(B, tvm.expr.PrimExpr))
             return
 
         def check_device(device):
@@ -283,8 +283,8 @@ def test_logical_binary_ele():
         A = (tvm.var("A", dtype=dtype))
         B = (tvm.var("B", dtype=dtype))
         C = func(A, B)
-        if isinstance(A, tvm.expr.Expr) and isinstance(B, tvm.expr.Expr):
-            assert (isinstance(C, tvm.expr.Expr))
+        if isinstance(A, tvm.expr.PrimExpr) and isinstance(B, tvm.expr.PrimExpr):
+            assert (isinstance(C, tvm.expr.PrimExpr))
             return
 
         def check_device(device):