[REFACTOR][IR] Add Node suffix to low-level IR nodes (#4649)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 8 Jan 2020 17:01:00 +0000 (09:01 -0800)
committerGitHub <noreply@github.com>
Wed, 8 Jan 2020 17:01:00 +0000 (09:01 -0800)
* [REFACTOR][IR] Variable -> VarNode

* [REFACTOR][IR] Add/Sub/Mul/Div -> AddNode/SubNode etc.

* [REFACTOR][IR] Min/Max/FloorDiv/FloorMod -> MinNode/MaxNode etc.

* [REFACTOR][IR] EQ/NE/LT/LE/GT/GE/Select -> EQNode/NENode etc.

* [REFACTOR][IR] Add Node suffix to Select/Call/Load/Ramp/Shuffle/Let

* [REFACTOR][IR] Add node suffix to IntImm/UIntImm/FloatImm/StringImm

* [REFACTOR][IR] Add Node suffix to Any, AttrStmt, AssertStmt

* [REFACTOR][IR] Add Node suffix to Store/Provide/Allocate/Free

* [REFACTOR][IR] Add Node suffix to ProducerConsumer

* Fix lint

* style updates, test fixes

194 files changed:
include/tvm/arithmetic.h
include/tvm/attrs.h
include/tvm/expr.h
include/tvm/expr_operator.h
include/tvm/ir.h
include/tvm/ir_functor_ext.h
include/tvm/ir_pass.h
include/tvm/operation.h
include/tvm/relay/type.h
src/api/api_ir.cc
src/api/api_lang.cc
src/arithmetic/analyzer.cc
src/arithmetic/bound_deducer.cc
src/arithmetic/canonical_simplify.cc
src/arithmetic/compute_expr.h
src/arithmetic/const_fold.h
src/arithmetic/const_int_bound.cc
src/arithmetic/detect_linear_equation.cc
src/arithmetic/domain_touched.cc
src/arithmetic/int_operator.h
src/arithmetic/int_set.cc
src/arithmetic/ir_mutator_with_analyzer.cc
src/arithmetic/ir_mutator_with_analyzer.h
src/arithmetic/ir_visitor_with_analyzer.h
src/arithmetic/modular_set.cc
src/arithmetic/pattern_match.h
src/arithmetic/rewrite_simplify.cc
src/arithmetic/rewrite_simplify.h
src/arithmetic/stmt_simplify.cc
src/autotvm/feature_visitor.cc
src/autotvm/feature_visitor.h
src/autotvm/touch_extractor.cc
src/autotvm/touch_extractor.h
src/codegen/build_module.cc
src/codegen/codegen_c.cc
src/codegen/codegen_c.h
src/codegen/codegen_c_host.cc
src/codegen/codegen_c_host.h
src/codegen/codegen_cuda.cc
src/codegen/codegen_cuda.h
src/codegen/codegen_metal.cc
src/codegen/codegen_metal.h
src/codegen/codegen_opencl.cc
src/codegen/codegen_opencl.h
src/codegen/codegen_opengl.cc
src/codegen/codegen_opengl.h
src/codegen/codegen_source_base.cc
src/codegen/codegen_source_base.h
src/codegen/codegen_vhls.cc
src/codegen/codegen_vhls.h
src/codegen/intrin_rule.cc
src/codegen/intrin_rule.h
src/codegen/llvm/codegen_amdgpu.cc
src/codegen/llvm/codegen_arm.cc
src/codegen/llvm/codegen_cpu.cc
src/codegen/llvm/codegen_cpu.h
src/codegen/llvm/codegen_llvm.cc
src/codegen/llvm/codegen_llvm.h
src/codegen/llvm/codegen_nvptx.cc
src/codegen/llvm/codegen_x86_64.cc
src/codegen/llvm/intrin_rule_llvm.cc
src/codegen/llvm/intrin_rule_llvm.h
src/codegen/llvm/intrin_rule_nvptx.cc
src/codegen/llvm/intrin_rule_rocm.cc
src/codegen/spirv/codegen_spirv.cc
src/codegen/spirv/codegen_spirv.h
src/codegen/spirv/intrin_rule_spirv.cc
src/codegen/stackvm/codegen_stackvm.cc
src/codegen/stackvm/codegen_stackvm.h
src/contrib/hybrid/codegen_hybrid.cc
src/contrib/hybrid/codegen_hybrid.h
src/lang/attr_functor.h
src/lang/attrs.cc
src/lang/buffer.cc
src/lang/data_layout.cc
src/lang/expr.cc
src/lang/expr_operator.cc
src/lang/ir.cc
src/lang/tensor.cc
src/op/compute_op.cc
src/op/cross_thread_reduction.cc
src/op/extern_op.cc
src/op/hybrid_op.cc
src/op/op_util.cc
src/op/placeholder_op.cc
src/op/scan_op.cc
src/op/tensor_compute_op.cc
src/op/tensorize.cc
src/pass/arg_binder.cc
src/pass/arg_binder.h
src/pass/bound_checker.cc
src/pass/combine_context_call.cc
src/pass/coproc_sync.cc
src/pass/detect_device.cc
src/pass/hoist_if_then_else.cc
src/pass/infer_fragment.cc
src/pass/inject_copy_intrin.cc
src/pass/inject_double_buffer.cc
src/pass/inject_prefetch.cc
src/pass/inject_virtual_thread.cc
src/pass/inline.cc
src/pass/ir_deep_compare.cc
src/pass/ir_functor.cc
src/pass/ir_util.cc
src/pass/ir_util.h
src/pass/lift_attr_scope.cc
src/pass/loop_partition.cc
src/pass/lower_custom_datatypes.cc
src/pass/lower_intrin.cc
src/pass/lower_thread_allreduce.cc
src/pass/lower_tvm_builtin.cc
src/pass/lower_warp_memory.cc
src/pass/make_api.cc
src/pass/remap_thread_axis.cc
src/pass/remove_no_op.cc
src/pass/rewrite_unsafe_select.cc
src/pass/simple_passes.cc
src/pass/skip_assert.cc
src/pass/split_host_device.cc
src/pass/ssa.cc
src/pass/storage_access.cc
src/pass/storage_access.h
src/pass/storage_flatten.cc
src/pass/storage_rewrite.cc
src/pass/storage_sync.cc
src/pass/tensor_core.cc
src/pass/unroll_loop.cc
src/pass/vectorize_loop.cc
src/pass/verify_compact_buffer.cc
src/pass/verify_gpu_code.cc
src/pass/verify_memory.cc
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/contrib/codegen_c/codegen_c.h
src/relay/backend/contrib/dnnl/codegen.cc
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/vm/lambda_lift.cc
src/relay/backend/vm/removed_unused_funcs.cc
src/relay/ir/alpha_equal.cc
src/relay/ir/expr.cc
src/relay/ir/hash.cc
src/relay/ir/pretty_printer.cc
src/relay/op/memory/memory.cc
src/relay/op/nn/convolution.cc
src/relay/op/nn/convolution.h
src/relay/op/nn/nn.cc
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/nn/upsampling.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
src/relay/pass/alter_op_layout.cc
src/relay/pass/canonicalize_cast.cc
src/relay/pass/canonicalize_ops.cc
src/relay/pass/combine_parallel_conv2d.cc
src/relay/pass/combine_parallel_dense.cc
src/relay/pass/combine_parallel_op_batch.cc
src/relay/pass/convert_layout.cc
src/relay/pass/device_annotation.cc
src/relay/pass/eliminate_common_subexpr.cc
src/relay/pass/fold_constant.cc
src/relay/pass/fold_scale_axis.cc
src/relay/pass/fuse_ops.cc
src/relay/pass/infer_layout_util.h
src/relay/pass/legalize.cc
src/relay/pass/mac_count.cc
src/relay/pass/pass_manager.cc
src/relay/pass/pattern_util.h
src/relay/pass/simplify_inference.cc
src/relay/pass/type_solver.cc
src/schedule/auto_inline_elem_wise.cc
src/schedule/bound.cc
src/schedule/graph.cc
src/schedule/message_passing.cc
src/schedule/schedule_dataflow_rewrite.cc
src/schedule/schedule_lang.cc
src/schedule/schedule_ops.cc
tests/cpp/attrs_test.cc
tests/cpp/container_test.cc
tests/cpp/expr_test.cc
tests/cpp/ir_functor_test.cc
tests/cpp/ir_simplify_test.cc
tests/cpp/ir_ssa_test.cc
tests/cpp/packed_func_test.cc
tests/cpp/pattern_match_test.cc
topi/include/topi/detail/broadcast.h
topi/include/topi/detail/constant_utils.h
topi/include/topi/detail/extern.h
topi/include/topi/detail/tensor_utils.h
topi/include/topi/elemwise.h
topi/include/topi/nn.h
topi/include/topi/nn/pooling.h
topi/include/topi/reduction.h
topi/include/topi/transform.h

index e5f7567..d135d30 100644 (file)
@@ -564,7 +564,7 @@ IntSet EvalSet(Expr e,
  * \return An integer set that can cover all the possible values of e.
  */
 IntSet EvalSet(Expr e,
-               const std::unordered_map<const Variable*, IntSet>& dom_map);
+               const std::unordered_map<const VarNode*, IntSet>& dom_map);
 
 /*!
  * \brief Find an symbolic integer set that contains is union over
@@ -586,7 +586,7 @@ IntSet EvalSet(Range r,
  * \return An integer set that can cover all the possible values.
  */
 IntSet EvalSet(IntSet s,
-               const std::unordered_map<const Variable*, IntSet>& dom_map);
+               const std::unordered_map<const VarNode*, IntSet>& dom_map);
 /*!
  * \brief Same as EvalSet, but takes unordered_map
  *
@@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s,
  * \return An integer set that can cover all the possible values of e.
  */
 IntSet EvalSet(Range r,
-               const std::unordered_map<const Variable*, IntSet>& dom_map);
+               const std::unordered_map<const VarNode*, IntSet>& dom_map);
 
 /*! \brief Map from Expr to IntSet */
 using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
@@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
  */
 ExprIntSetMap EvalSetForEachSubExpr(
     Expr e,
-    const std::unordered_map<const Variable*, IntSet>& dom_map);
+    const std::unordered_map<const VarNode*, IntSet>& dom_map);
 
 /*!
  * \brief Create an union set of all sets
@@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond,
  * \return An integer set that always satisfies the condition.
  */
 IntSet DeduceBound(Expr v, Expr cond,
-                   const std::unordered_map<const Variable*, IntSet>& hint_map,
-                   const std::unordered_map<const Variable*, IntSet>& relax_map);
+                   const std::unordered_map<const VarNode*, IntSet>& hint_map,
+                   const std::unordered_map<const VarNode*, IntSet>& relax_map);
 
 /*!
  * \brief Infer a regular domain that covers all the calls or provides within the given statement.
index 0178eab..13c8b30 100644 (file)
@@ -488,9 +488,9 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
   } else {
     Expr expr = val;
     CHECK(expr.defined());
-    if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
+    if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<T>(op->value);
-    } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
+    } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
       *ptr = static_cast<T>(op->value);
     } else {
       LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
@@ -503,7 +503,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
     *ptr = val.operator std::string();
   } else {
     Expr expr = val;
-    const ir::StringImm* op = expr.as<ir::StringImm>();
+    const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
     CHECK(op != nullptr);
     *ptr = op->value;
   }
@@ -519,11 +519,11 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
   } else {
     Expr expr = val;
     CHECK(expr.defined());
-    if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
+    if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<double>(op->value);
-    } else if (const ir::IntImm* op = expr.as<ir::IntImm>()) {
+    } else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
       *ptr = static_cast<double>(op->value);
-    } else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
+    } else if (const ir::UIntImmNode* op = expr.as<ir::UIntImmNode>()) {
       *ptr = static_cast<double>(op->value);
     } else {
       LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
index aee565d..64d7547 100644 (file)
@@ -102,7 +102,7 @@ class Var;
  * - Let
  * - LetStmt
  */
-class Variable : public ExprNode {
+class VarNode : public ExprNode {
  public:
   /*!
    * \brief The hint to the variable name.
@@ -118,7 +118,7 @@ class Variable : public ExprNode {
   }
 
   static constexpr const char* _type_key = "Variable";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
 };
 
 /*! \brief a named variable in TVM */
@@ -139,18 +139,18 @@ class Var : public Expr {
    * \brief Get pointer to the internal value.
    * \return the corresponding Variable.
    */
-  const Variable* operator->() const {
+  const VarNode* operator->() const {
     return get();
   }
   /*!
    * \brief Get pointer to the internal value.
    * \return the corresponding Variable.
    */
-  const Variable* get() const {
-    return static_cast<const Variable*>(data_.get());
+  const VarNode* get() const {
+    return static_cast<const VarNode*>(data_.get());
   }
   /*! \brief type indicate the container type */
-  using ContainerType = Variable;
+  using ContainerType = VarNode;
 };
 
 // Backward compatibility, will be removed later.
@@ -161,7 +161,7 @@ using ExprEqual = ObjectEqual;
 
 class Integer;
 /*! \brief ExprNode: constant integer. */
-class IntImm : public ExprNode {
+class IntImmNode : public ExprNode {
  public:
   /*! \brief the Internal value. */
   int64_t value;
@@ -174,7 +174,7 @@ class IntImm : public ExprNode {
   TVM_DLL static Integer make(DataType t, int64_t value);
 
   static constexpr const char* _type_key = "IntImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IntImm, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, ExprNode);
 };
 
 /*!
@@ -206,8 +206,8 @@ class Integer : public Expr {
    * \brief Get pointer to the internal value.
    * \return the content of the integer.
    */
-  const IntImm* operator->() const {
-    return static_cast<const IntImm*>(get());
+  const IntImmNode* operator->() const {
+    return static_cast<const IntImmNode*>(get());
   }
   /*!
    * \brief convert to int64_t
@@ -218,7 +218,7 @@ class Integer : public Expr {
     return (*this)->value;
   }
   /*! \brief type indicate the container type */
-  using ContainerType = IntImm;
+  using ContainerType = IntImmNode;
 };
 
 /*! \brief range over one dimension */
index a73edb4..bf8b1a3 100644 (file)
@@ -75,7 +75,7 @@ inline Expr const_false(int lanes = 1) {
  */
 inline const int64_t* as_const_int(const Expr& x) {
   if (!x.defined()) return nullptr;
-  if (const ir::IntImm* op = x.as<ir::IntImm>()) {
+  if (const ir::IntImmNode* op = x.as<ir::IntImmNode>()) {
     return &(op->value);
   } else {
     return nullptr;
@@ -90,7 +90,7 @@ inline const int64_t* as_const_int(const Expr& x) {
  */
 inline const uint64_t* as_const_uint(const Expr& x) {
   if (!x.defined()) return nullptr;
-  if (const ir::UIntImm* op = x.as<ir::UIntImm>()) {
+  if (const ir::UIntImmNode* op = x.as<ir::UIntImmNode>()) {
     return &(op->value);
   } else {
     return nullptr;
@@ -600,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline Expr OpName(Expr x) {                                          \
-    return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
+    return ir::CallNode::make(x.dtype(), #OpName, {x}, ir::CallNode::PureIntrinsic); \
   }                                                                     \
 
 TVM_DECLARE_INTRIN_UNARY(exp);
@@ -617,11 +617,11 @@ TVM_DECLARE_INTRIN_UNARY(atan);
 
 // Implementation details after this
 inline bool is_const(const Expr& x) {
-  if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
+  if (x.as<ir::IntImmNode>() || x.as<ir::UIntImmNode>()) {
     return true;
-  } else if (const auto* op = x.as<ir::Broadcast>()) {
+  } else if (const auto* op = x.as<ir::BroadcastNode>()) {
     const Expr& val = op->value;
-    if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) {
+    if (val.as<ir::IntImmNode>() || val.as<ir::UIntImmNode>()) {
       return true;
     }
   }
@@ -629,9 +629,9 @@ inline bool is_const(const Expr& x) {
 }
 
 inline bool is_positive_const(const Expr& a) {
-  if (const ir::IntImm* op = a.as<ir::IntImm>()) {
+  if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
     return op->value > 0;
-  } else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) {
+  } else if (const ir::UIntImmNode* op = a.as<ir::UIntImmNode>()) {
     return op->value > 0;
   } else {
     return false;
@@ -639,7 +639,7 @@ inline bool is_positive_const(const Expr& a) {
 }
 
 inline bool is_negative_const(const Expr& a) {
-  if (const ir::IntImm* op = a.as<ir::IntImm>()) {
+  if (const ir::IntImmNode* op = a.as<ir::IntImmNode>()) {
     return op->value < 0;
   } else {
     return false;
@@ -647,15 +647,15 @@ inline bool is_negative_const(const Expr& a) {
 }
 
 inline bool is_const_int(const Expr& x, int64_t value) {
-  if (const auto* op = x.as<ir::IntImm>()) {
+  if (const auto* op = x.as<ir::IntImmNode>()) {
     return op->value == value;
-  } else if (const auto* op = x.as<ir::UIntImm>()) {
+  } else if (const auto* op = x.as<ir::UIntImmNode>()) {
     return op->value == static_cast<uint64_t>(value);
-  } else if (const auto* op = x.as<ir::Broadcast>()) {
+  } else if (const auto* op = x.as<ir::BroadcastNode>()) {
     const Expr& val = op->value;
-    if (const auto* opv = val.as<ir::IntImm>()) {
+    if (const auto* opv = val.as<ir::IntImmNode>()) {
       return opv->value == value;
-    } else if (const auto* opv = val.as<ir::UIntImm>()) {
+    } else if (const auto* opv = val.as<ir::UIntImmNode>()) {
       return opv->value == static_cast<uint64_t>(value);
     }
   }
@@ -664,7 +664,7 @@ inline bool is_const_int(const Expr& x, int64_t value) {
 
 inline bool is_no_op(const Stmt& stmt) {
   if (!stmt.defined()) return true;
-  if (const auto* op = stmt.as<ir::Evaluate>()) {
+  if (const auto* op = stmt.as<ir::EvaluateNode>()) {
     return is_const(op->value);
   }
   if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
@@ -675,15 +675,15 @@ inline bool is_no_op(const Stmt& stmt) {
 
 template<typename ValueType>
 inline Expr MakeConstScalar(DataType t, ValueType value) {
-  if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
-  if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
-  if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
+  if (t.is_int()) return ir::IntImmNode::make(t, static_cast<int64_t>(value));
+  if (t.is_uint()) return ir::UIntImmNode::make(t, static_cast<uint64_t>(value));
+  if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
   // For now, we store const scalar values of custom datatypes within doubles; later, during the
   // datatypes lowering pass, we will lower the value to its true representation in the format
   // specified by the datatype.
   // TODO(gus) when do we need to start worrying about doubles not being precise enough?
   if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin))
-    return ir::FloatImm::make(t, static_cast<double>(value));
+    return ir::FloatImmNode::make(t, static_cast<double>(value));
   LOG(FATAL) << "cannot make const for type " << t;
   return Expr();
 }
@@ -693,7 +693,7 @@ inline Expr make_const(DataType t, ValueType value) {
   if (t.lanes() == 1) {
     return MakeConstScalar(t, value);
   } else {
-    return ir::Broadcast::make(
+    return ir::BroadcastNode::make(
         MakeConstScalar(t.element_of(), value), t.lanes());
   }
 }
index b1cefff..11ce09d 100644 (file)
 namespace tvm {
 namespace ir {
 
-using IntImm = tvm::IntImm;
-using Variable = tvm::Variable;
+using IntImmNode = tvm::IntImmNode;
+using VarNode = tvm::VarNode;
 
 /*! \brief constant unsigned integer. */
-class UIntImm : public ExprNode {
+class UIntImmNode : public ExprNode {
  public:
   /*! \brief The constant value content. */
   uint64_t value;
@@ -53,11 +53,11 @@ class UIntImm : public ExprNode {
   TVM_DLL static Expr make(DataType t, uint64_t value);
 
   static constexpr const char* _type_key = "UIntImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(UIntImm, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(UIntImmNode, ExprNode);
 };
 
 /*! \brief Floating point constants. */
-class FloatImm : public ExprNode {
+class FloatImmNode : public ExprNode {
  public:
   /*! \brief The constant value content. */
   double value;
@@ -70,11 +70,11 @@ class FloatImm : public ExprNode {
   TVM_DLL static Expr make(DataType t, double value);
 
   static constexpr const char* _type_key = "FloatImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImm, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, ExprNode);
 };
 
 /*! \brief String constants, only used in asserts. */
-class StringImm : public ExprNode {
+class StringImmNode : public ExprNode {
  public:
   /*! \brief The constant value content. */
   std::string value;
@@ -87,14 +87,14 @@ class StringImm : public ExprNode {
   TVM_DLL Expr static make(std::string value);
 
   static constexpr const char* _type_key = "StringImm";
-  TVM_DECLARE_FINAL_OBJECT_INFO(StringImm, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, ExprNode);
 };
 
 /*!
  * \brief Cast value from one data type to another.
  * \note The lanes of value should keep fixed.
  */
-class Cast : public ExprNode {
+class CastNode : public ExprNode {
  public:
   /*! \brief Original data type. */
   Expr value;
@@ -107,7 +107,7 @@ class Cast : public ExprNode {
   TVM_DLL static Expr make(DataType t, Expr v);
 
   static constexpr const char* _type_key = "Cast";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Cast, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode);
 };
 
 /*!
@@ -143,19 +143,19 @@ class BinaryOpNode : public ExprNode {
 };
 
 /*! \brief a + b */
-class Add : public BinaryOpNode<Add> {
+class AddNode : public BinaryOpNode<AddNode> {
  public:
   static constexpr const char* _type_key = "Add";
 };
 
 /*! \brief a - b */
-class Sub : public BinaryOpNode<Sub> {
+class SubNode : public BinaryOpNode<SubNode> {
  public:
   static constexpr const char* _type_key = "Sub";
 };
 
 /*! \brief a * b */
-class Mul : public BinaryOpNode<Mul> {
+class MulNode : public BinaryOpNode<MulNode> {
  public:
   static constexpr const char* _type_key = "Mul";
 };
@@ -164,7 +164,7 @@ class Mul : public BinaryOpNode<Mul> {
  * \brief a / b in the C semnatics.
  * \note For integer division, C standard uses trunc div.
  */
-class Div : public BinaryOpNode<Div> {
+class DivNode : public BinaryOpNode<DivNode> {
  public:
   static constexpr const char* _type_key = "Div";
 };
@@ -173,31 +173,31 @@ class Div : public BinaryOpNode<Div> {
  * \brief a % b in the C semnatics.
  * \note For integer division, C standard uses trunc div.
  */
-class Mod : public BinaryOpNode<Mod> {
+class ModNode : public BinaryOpNode<ModNode> {
  public:
   static constexpr const char* _type_key = "Mod";
 };
 
 /*! \brief Floor division, floor(a/b) */
-class FloorDiv : public BinaryOpNode<FloorDiv> {
+class FloorDivNode : public BinaryOpNode<FloorDivNode> {
  public:
   static constexpr const char* _type_key = "FloorDiv";
 };
 
 /*! \brief The remainder of the floordiv */
-class FloorMod : public BinaryOpNode<FloorMod> {
+class FloorModNode : public BinaryOpNode<FloorModNode> {
  public:
   static constexpr const char* _type_key = "FloorMod";
 };
 
 /*! \brief min(a, b) */
-class Min : public BinaryOpNode<Min> {
+class MinNode : public BinaryOpNode<MinNode> {
  public:
   static constexpr const char* _type_key = "Min";
 };
 
 /*! \brief max(a, b) */
-class Max : public BinaryOpNode<Max> {
+class MaxNode : public BinaryOpNode<MaxNode> {
  public:
   static constexpr const char* _type_key = "Max";
 };
@@ -235,43 +235,43 @@ class CmpOpNode : public ExprNode {
 };
 
 /*! \brief a == b */
-class EQ : public CmpOpNode<EQ> {
+class EQNode : public CmpOpNode<EQNode> {
  public:
   static constexpr const char* _type_key = "EQ";
 };
 
 /*! \brief a != b */
-class NE : public CmpOpNode<NE> {
+class NENode : public CmpOpNode<NENode> {
  public:
   static constexpr const char* _type_key = "NE";
 };
 
 /*! \brief a < b */
-class LT : public CmpOpNode<LT> {
+class LTNode : public CmpOpNode<LTNode> {
  public:
   static constexpr const char* _type_key = "LT";
 };
 
 /*! \brief a <= b */
-struct LE : public CmpOpNode<LE> {
+struct LENode : public CmpOpNode<LENode> {
  public:
   static constexpr const char* _type_key = "LE";
 };
 
 /*! \brief a > b */
-class GT : public CmpOpNode<GT> {
+class GTNode : public CmpOpNode<GTNode> {
  public:
   static constexpr const char* _type_key = "GT";
 };
 
 /*! \brief a >= b */
-class GE : public CmpOpNode<GE> {
+class GENode : public CmpOpNode<GENode> {
  public:
   static constexpr const char* _type_key = "GE";
 };
 
 /*! \brief a && b */
-class And : public ExprNode {
+class AndNode : public ExprNode {
  public:
   /*! \brief The left operand. */
   Expr a;
@@ -287,11 +287,11 @@ class And : public ExprNode {
   TVM_DLL static Expr make(Expr a, Expr b);
 
   static constexpr const char* _type_key = "And";
-  TVM_DECLARE_FINAL_OBJECT_INFO(And, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode);
 };
 
 /*! \brief a || b */
-class Or : public ExprNode {
+class OrNode : public ExprNode {
  public:
   /*! \brief The left operand. */
   Expr a;
@@ -307,11 +307,11 @@ class Or : public ExprNode {
   TVM_DLL static Expr make(Expr a, Expr b);
 
   static constexpr const char* _type_key = "Or";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Or, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode);
 };
 
 /*! \brief !a */
-class Not : public ExprNode {
+class NotNode : public ExprNode {
  public:
   /*! \brief The input operand. */
   Expr a;
@@ -324,7 +324,7 @@ class Not : public ExprNode {
   TVM_DLL static Expr make(Expr a);
 
   static constexpr const char* _type_key = "Not";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Not, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode);
 };
 
 /*!
@@ -334,7 +334,7 @@ class Not : public ExprNode {
  *       Do not use it to guard against out of bound access,
  *       please use if_then_else instead.
  */
-class Select : public ExprNode {
+class SelectNode : public ExprNode {
  public:
   /*! \brief The condition */
   Expr condition;
@@ -353,7 +353,7 @@ class Select : public ExprNode {
   TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value);
 
   static constexpr const char* _type_key = "Select";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Select, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode);
 };
 
 /*!
@@ -371,7 +371,7 @@ class Select : public ExprNode {
  *
  * \endcode
  */
-class Load : public ExprNode {
+class LoadNode : public ExprNode {
  public:
   /*! \brief The buffer variable. */
   Var buffer_var;
@@ -390,7 +390,7 @@ class Load : public ExprNode {
   TVM_DLL static Expr make(DataType dtype, Var buffer_var, Expr index, Expr predicate);
 
   static constexpr const char* _type_key = "Load";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Load, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, ExprNode);
 };
 
 /*!
@@ -402,7 +402,7 @@ class Load : public ExprNode {
  *  - ramp(0, 1, 3) = [0, 1, 2]
  *  - ramp(1, 2, 4) = [1, 3, 5, 7]
  */
-class Ramp : public ExprNode {
+class RampNode : public ExprNode {
  public:
   /*! \brief The base value. */
   Expr base;
@@ -421,11 +421,11 @@ class Ramp : public ExprNode {
   TVM_DLL static Expr make(Expr base, Expr stride, int lanes);
 
   static constexpr const char* _type_key = "Ramp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Ramp, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, ExprNode);
 };
 
 /*! \brief Create a vector where all the elements are value. */
-class Broadcast : public ExprNode {
+class BroadcastNode : public ExprNode {
  public:
   /*! \brief The base value. */
   Expr value;
@@ -441,13 +441,13 @@ class Broadcast : public ExprNode {
   TVM_DLL static Expr make(Expr value, int lanes);
 
   static constexpr const char* _type_key = "Broadcast";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Broadcast, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, ExprNode);
 };
 
 /*!
  * \brief Let binding. Bind var to value then evaluate body.
  */
-class Let : public ExprNode {
+class LetNode : public ExprNode {
  public:
   /*! \brief The variable. */
   Var var;
@@ -466,7 +466,7 @@ class Let : public ExprNode {
   TVM_DLL static Expr make(Var var, Expr value, Expr body);
 
   static constexpr const char* _type_key = "Let";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Let, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
 };
 
 // Call node, represent a function call or a multi-dimensional array load.
@@ -494,7 +494,7 @@ class FunctionRef : public ObjectRef {
 /*!
  * \brief Call node.
  */
-class Call : public ExprNode {
+class CallNode : public ExprNode {
  public:
   /*! \brief Possible types of calls. */
   enum CallType : int {
@@ -560,7 +560,7 @@ class Call : public ExprNode {
   bool is_vectorizable() const;
 
   static constexpr const char* _type_key = "Call";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Call, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
 
   // Build-in intrinsics
   static constexpr const char* reinterpret = "reinterpret";
@@ -585,7 +585,7 @@ class Call : public ExprNode {
  *  vec = concat(vectors)
  *  result = (vec[indices[0]], vec[indices[1]] ...)
  */
-class Shuffle : public ExprNode {
+class ShuffleNode : public ExprNode {
  public:
   /*! \brief the input vectors. */
   Array<Expr> vectors;
@@ -602,7 +602,7 @@ class Shuffle : public ExprNode {
   TVM_DLL static Expr make_extract_element(Expr vector, int index);
 
   static constexpr const char* _type_key = "Shuffle";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Shuffle, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, ExprNode);
 };
 
 // Reduce operator
@@ -671,7 +671,7 @@ inline const CommReducerNode* CommReducer::operator->() const {
 }
 
 /*! \brief Reduction operator operator */
-class Reduce : public ExprNode {
+class ReduceNode : public ExprNode {
  public:
   /*! \brief The commutative combiner */
   CommReducer combiner;
@@ -704,29 +704,29 @@ class Reduce : public ExprNode {
   }
 
   static constexpr const char* _type_key = "Reduce";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Reduce, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, ExprNode);
 };
 
 /*! \brief Any shape. */
-class Any : public ExprNode {
+class AnyNode : public ExprNode {
  public:
   void VisitAttrs(AttrVisitor* v) {}
   /*! \brief Convert to var. */
   Var ToVar() const {
-    return Variable::make(DataType::Int(32), "any_dim");
+    return VarNode::make(DataType::Int(32), "any_dim");
   }
 
   TVM_DLL static Expr make();
 
   static constexpr const char* _type_key = "Any";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Any, ExprNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, ExprNode);
 };
 
 // Statements
 /*!
  * \brief Let binding, bind var to value, then run body.
  */
-class LetStmt : public StmtNode {
+class LetStmtNode : public StmtNode {
  public:
   /*! \brief The variable. */
   Var var;
@@ -744,7 +744,7 @@ class LetStmt : public StmtNode {
   TVM_DLL static Stmt make(Var var, Expr value, Stmt body);
 
   static constexpr const char* _type_key = "LetStmt";
-  TVM_DECLARE_FINAL_OBJECT_INFO(LetStmt, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
 };
 
 /*!
@@ -757,7 +757,7 @@ class LetStmt : public StmtNode {
  *    - Bound of function, variables.
  *    - Hint which block corresponds to a parallel region.
  */
-class AttrStmt : public StmtNode {
+class AttrStmtNode : public StmtNode {
  public:
   /*! \brief this is attribute about certain node */
   ObjectRef node;
@@ -781,13 +781,13 @@ class AttrStmt : public StmtNode {
                            Stmt body);
 
   static constexpr const char* _type_key = "AttrStmt";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmt, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
 };
 
 /*!
  * \brief Assert condition, if an error occurs, return the error message.
  */
-class AssertStmt : public StmtNode {
+class AssertStmtNode : public StmtNode {
  public:
   /*! \brief Condition to be checked. */
   Expr condition;
@@ -808,12 +808,12 @@ class AssertStmt : public StmtNode {
   TVM_DLL static Stmt make(Expr condition, Expr message, Stmt body);
 
   static constexpr const char* _type_key = "AssertStmt";
-  TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmt, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
 };
 
 // TODO(tvm-team): consider consolidate with AttrStmt.
 /*! \brief annotation node of producer/consumer relation. */
-class ProducerConsumer : public StmtNode {
+class ProducerConsumerNode : public StmtNode {
  public:
   /*! \brief The corresponding tensor. */
   FunctionRef func;
@@ -831,7 +831,7 @@ class ProducerConsumer : public StmtNode {
   TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
 
   static constexpr const char* _type_key = "ProducerConsumer";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumer, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode);
 };
 
 /*!
@@ -850,9 +850,9 @@ class ProducerConsumer : public StmtNode {
  *  buffer[index.v2] = value.v2;
  *
  * \endcode
- * \sa Load
+ * \sa LoadNode
  */
-class Store : public StmtNode {
+class StoreNode : public StmtNode {
  public:
   /*! \brief The buffer variable. */
   Var buffer_var;
@@ -876,13 +876,13 @@ class Store : public StmtNode {
                            Expr predicate);
 
   static constexpr const char* _type_key = "Store";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Store, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
 };
 
 /*!
  * \brief Store value into mult-dimensional array defined by func.
  */
-class Provide : public StmtNode {
+class ProvideNode : public StmtNode {
  public:
   /*! \brief The function to be updated. */
   FunctionRef func;
@@ -906,13 +906,13 @@ class Provide : public StmtNode {
                            Array<Expr> args);
 
   static constexpr const char* _type_key = "Provide";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Provide, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode);
 };
 
 /*!
  * \brief Allocate a buffer that can be used in body.
  */
-class Allocate : public StmtNode {
+class AllocateNode : public StmtNode {
  public:
   /*! \brief The buffer variable. */
   Var buffer_var;
@@ -963,11 +963,11 @@ class Allocate : public StmtNode {
       const Array<Expr>& extents);
 
   static constexpr const char* _type_key = "Allocate";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Allocate, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
 };
 
 /*! \brief Free the resources in the buffer before the scope ends. */
-class Free : public StmtNode {
+class FreeNode : public StmtNode {
  public:
   /*! \brief The buffer variable. */
   Var buffer_var;
@@ -979,14 +979,14 @@ class Free : public StmtNode {
   TVM_DLL static Stmt make(Var buffer_var);
 
   static constexpr const char* _type_key = "Free";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Free, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
 };
 
 /*!
  * \brief Annotate the bounds where func need to be written and read in body.
  *  We will need to allocate space for the corresponding regions.
  */
-class Realize : public StmtNode {
+class RealizeNode : public StmtNode {
  public:
   /*! \brief The function to be realized. */
   FunctionRef func;
@@ -1018,7 +1018,7 @@ class Realize : public StmtNode {
                            Stmt body);
 
   static constexpr const char* _type_key = "Realize";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Realize, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
 };
 
 /*!
@@ -1104,7 +1104,7 @@ class SeqStmt : public Stmt {
       if (!stmt.defined()) return;
       if (auto* op = stmt.as<SeqStmtNode>()) {
         operator()(0, op->seq);
-      } else if (auto* op = stmt.as<ProducerConsumer>()) {
+      } else if (auto* op = stmt.as<ProducerConsumerNode>()) {
         // NOTE: The consumer block annotation was not as useful and can be safely dropped.
         if (!op->is_producer) {
           operator()(0, op->body);
@@ -1133,7 +1133,7 @@ class SeqStmt : public Stmt {
 /*!
  * \brief IfThenElse statment.
  */
-class IfThenElse : public StmtNode {
+class IfThenElseNode : public StmtNode {
  public:
   /*! \brief The condition. */
   Expr condition;
@@ -1151,7 +1151,7 @@ class IfThenElse : public StmtNode {
   TVM_DLL static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());
 
   static constexpr const char* _type_key = "IfThenElse";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElse, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
 };
 
 /*!
@@ -1160,7 +1160,7 @@ class IfThenElse : public StmtNode {
  *
  *  If value do not have side-effect, this node can be safely removed.
  */
-class Evaluate : public StmtNode {
+class EvaluateNode : public StmtNode {
  public:
   /*! \brief The expression to be evaluated. */
   Expr value;
@@ -1172,7 +1172,7 @@ class Evaluate : public StmtNode {
   TVM_DLL static Stmt make(Expr v);
 
   static constexpr const char* _type_key = "Evaluate";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Evaluate, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
 };
 
 /*! \brief Additional annotation of for loop. */
@@ -1204,7 +1204,7 @@ enum class DeviceAPI: int {
  *  }
  * \endcode
  */
-class For : public StmtNode {
+class ForNode : public StmtNode {
  public:
   /*! \brief The loop variable. */
   Var loop_var;
@@ -1239,13 +1239,13 @@ class For : public StmtNode {
   }
 
   static constexpr const char* _type_key = "For";
-  TVM_DECLARE_FINAL_OBJECT_INFO(For, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
 };
 
 /*!
  * \brief A prefetch hint of func.
  */
-class Prefetch : public StmtNode {
+class PrefetchNode : public StmtNode {
  public:
   /*! \brief The function to be prefetched. */
   FunctionRef func;
@@ -1269,7 +1269,7 @@ class Prefetch : public StmtNode {
                            Region bounds);
 
   static constexpr const char* _type_key = "Prefetch";
-  TVM_DECLARE_FINAL_OBJECT_INFO(Prefetch, StmtNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
 };
 
 /*!
@@ -1708,9 +1708,9 @@ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
  * \return Expr a expression with dtype.
  */
 inline Expr TypeAnnotation(DataType dtype) {
-  return ir::Call::make(dtype,
+  return ir::CallNode::make(dtype,
                         "type_annotation", {},
-                        ir::Call::PureIntrinsic);
+                        ir::CallNode::PureIntrinsic);
 }
 
 // overload printing of for type.
index 6cc6d70..d70c8de 100644 (file)
@@ -132,38 +132,38 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     return vtable(n, this, std::forward<Args>(args)...);
   }
   // Functions that can be overriden by subclass
-  virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
-  virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const UIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Object* op, Args ...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     return R();
@@ -174,38 +174,38 @@ class ExprFunctor<R(const Expr& n, Args...)> {
   static FType InitVTable() {
     FType vtable;
     // Set dispatch
-    IR_EXPR_FUNCTOR_DISPATCH(Variable);
-    IR_EXPR_FUNCTOR_DISPATCH(Load);
-    IR_EXPR_FUNCTOR_DISPATCH(Let);
-    IR_EXPR_FUNCTOR_DISPATCH(Call);
-    IR_EXPR_FUNCTOR_DISPATCH(Add);
-    IR_EXPR_FUNCTOR_DISPATCH(Sub);
-    IR_EXPR_FUNCTOR_DISPATCH(Mul);
-    IR_EXPR_FUNCTOR_DISPATCH(Div);
-    IR_EXPR_FUNCTOR_DISPATCH(Mod);
-    IR_EXPR_FUNCTOR_DISPATCH(FloorDiv);
-    IR_EXPR_FUNCTOR_DISPATCH(FloorMod);
-    IR_EXPR_FUNCTOR_DISPATCH(Min);
-    IR_EXPR_FUNCTOR_DISPATCH(Max);
-    IR_EXPR_FUNCTOR_DISPATCH(EQ);
-    IR_EXPR_FUNCTOR_DISPATCH(NE);
-    IR_EXPR_FUNCTOR_DISPATCH(LT);
-    IR_EXPR_FUNCTOR_DISPATCH(LE);
-    IR_EXPR_FUNCTOR_DISPATCH(GT);
-    IR_EXPR_FUNCTOR_DISPATCH(GE);
-    IR_EXPR_FUNCTOR_DISPATCH(And);
-    IR_EXPR_FUNCTOR_DISPATCH(Or);
-    IR_EXPR_FUNCTOR_DISPATCH(Reduce);
-    IR_EXPR_FUNCTOR_DISPATCH(Cast);
-    IR_EXPR_FUNCTOR_DISPATCH(Not);
-    IR_EXPR_FUNCTOR_DISPATCH(Select);
-    IR_EXPR_FUNCTOR_DISPATCH(Ramp);
-    IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
-    IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
-    IR_EXPR_FUNCTOR_DISPATCH(IntImm);
-    IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
-    IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
-    IR_EXPR_FUNCTOR_DISPATCH(StringImm);
+    IR_EXPR_FUNCTOR_DISPATCH(VarNode);
+    IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
+    IR_EXPR_FUNCTOR_DISPATCH(LetNode);
+    IR_EXPR_FUNCTOR_DISPATCH(CallNode);
+    IR_EXPR_FUNCTOR_DISPATCH(AddNode);
+    IR_EXPR_FUNCTOR_DISPATCH(SubNode);
+    IR_EXPR_FUNCTOR_DISPATCH(MulNode);
+    IR_EXPR_FUNCTOR_DISPATCH(DivNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ModNode);
+    IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode);
+    IR_EXPR_FUNCTOR_DISPATCH(FloorModNode);
+    IR_EXPR_FUNCTOR_DISPATCH(MinNode);
+    IR_EXPR_FUNCTOR_DISPATCH(MaxNode);
+    IR_EXPR_FUNCTOR_DISPATCH(EQNode);
+    IR_EXPR_FUNCTOR_DISPATCH(NENode);
+    IR_EXPR_FUNCTOR_DISPATCH(LTNode);
+    IR_EXPR_FUNCTOR_DISPATCH(LENode);
+    IR_EXPR_FUNCTOR_DISPATCH(GTNode);
+    IR_EXPR_FUNCTOR_DISPATCH(GENode);
+    IR_EXPR_FUNCTOR_DISPATCH(AndNode);
+    IR_EXPR_FUNCTOR_DISPATCH(OrNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ReduceNode);
+    IR_EXPR_FUNCTOR_DISPATCH(CastNode);
+    IR_EXPR_FUNCTOR_DISPATCH(NotNode);
+    IR_EXPR_FUNCTOR_DISPATCH(SelectNode);
+    IR_EXPR_FUNCTOR_DISPATCH(RampNode);
+    IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
+    IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
+    IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
+    IR_EXPR_FUNCTOR_DISPATCH(UIntImmNode);
+    IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
+    IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
     return vtable;
   }
 };
@@ -241,20 +241,20 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     return vtable(n, this, std::forward<Args>(args)...);
   }
   // Functions that can be overriden by subclass
-  virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmtDefault_(const Object* op, Args ...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     return R();
@@ -264,20 +264,20 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   // initialize the vtable.
   static FType InitVTable() {
     FType vtable;
-    IR_STMT_FUNCTOR_DISPATCH(LetStmt);
-    IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
-    IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
-    IR_STMT_FUNCTOR_DISPATCH(For);
-    IR_STMT_FUNCTOR_DISPATCH(Allocate);
-    IR_STMT_FUNCTOR_DISPATCH(Store);
-    IR_STMT_FUNCTOR_DISPATCH(Free);
-    IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
-    IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
-    IR_STMT_FUNCTOR_DISPATCH(Provide);
-    IR_STMT_FUNCTOR_DISPATCH(Realize);
-    IR_STMT_FUNCTOR_DISPATCH(Prefetch);
+    IR_STMT_FUNCTOR_DISPATCH(LetStmtNode);
+    IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
+    IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
+    IR_STMT_FUNCTOR_DISPATCH(ForNode);
+    IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
+    IR_STMT_FUNCTOR_DISPATCH(StoreNode);
+    IR_STMT_FUNCTOR_DISPATCH(FreeNode);
+    IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
+    IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode);
+    IR_STMT_FUNCTOR_DISPATCH(ProvideNode);
+    IR_STMT_FUNCTOR_DISPATCH(RealizeNode);
+    IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
     IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
-    IR_STMT_FUNCTOR_DISPATCH(Evaluate);
+    IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
     return vtable;
   }
 };
@@ -298,38 +298,38 @@ class TVM_DLL ExprVisitor :
  protected:
   using ExprFunctor::VisitExpr;
   // list of functions to override.
-  void VisitExpr_(const Variable* op) override;
-  void VisitExpr_(const Load* op) override;
-  void VisitExpr_(const Let* op) override;
-  void VisitExpr_(const Call* op) override;
-  void VisitExpr_(const Add* op) override;
-  void VisitExpr_(const Sub* op) override;
-  void VisitExpr_(const Mul* op) override;
-  void VisitExpr_(const Div* op) override;
-  void VisitExpr_(const Mod* op) override;
-  void VisitExpr_(const FloorDiv* op) override;
-  void VisitExpr_(const FloorMod* op) override;
-  void VisitExpr_(const Min* op) override;
-  void VisitExpr_(const Max* op) override;
-  void VisitExpr_(const EQ* op) override;
-  void VisitExpr_(const NE* op) override;
-  void VisitExpr_(const LT* op) override;
-  void VisitExpr_(const LE* op) override;
-  void VisitExpr_(const GT* op) override;
-  void VisitExpr_(const GE* op) override;
-  void VisitExpr_(const And* op) override;
-  void VisitExpr_(const Or* op) override;
-  void VisitExpr_(const Reduce* op) override;
-  void VisitExpr_(const Cast* op) override;
-  void VisitExpr_(const Not* op) override;
-  void VisitExpr_(const Select* op) override;
-  void VisitExpr_(const Ramp* op) override;
-  void VisitExpr_(const Broadcast* op) override;
-  void VisitExpr_(const Shuffle* op) override;
-  void VisitExpr_(const IntImm* op) override;
-  void VisitExpr_(const UIntImm* op) override;
-  void VisitExpr_(const FloatImm* op) override;
-  void VisitExpr_(const StringImm* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitExpr_(const LetNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const AddNode* op) override;
+  void VisitExpr_(const SubNode* op) override;
+  void VisitExpr_(const MulNode* op) override;
+  void VisitExpr_(const DivNode* op) override;
+  void VisitExpr_(const ModNode* op) override;
+  void VisitExpr_(const FloorDivNode* op) override;
+  void VisitExpr_(const FloorModNode* op) override;
+  void VisitExpr_(const MinNode* op) override;
+  void VisitExpr_(const MaxNode* op) override;
+  void VisitExpr_(const EQNode* op) override;
+  void VisitExpr_(const NENode* op) override;
+  void VisitExpr_(const LTNode* op) override;
+  void VisitExpr_(const LENode* op) override;
+  void VisitExpr_(const GTNode* op) override;
+  void VisitExpr_(const GENode* op) override;
+  void VisitExpr_(const AndNode* op) override;
+  void VisitExpr_(const OrNode* op) override;
+  void VisitExpr_(const ReduceNode* op) override;
+  void VisitExpr_(const CastNode* op) override;
+  void VisitExpr_(const NotNode* op) override;
+  void VisitExpr_(const SelectNode* op) override;
+  void VisitExpr_(const RampNode* op) override;
+  void VisitExpr_(const BroadcastNode* op) override;
+  void VisitExpr_(const ShuffleNode* op) override;
+  void VisitExpr_(const IntImmNode* op) override;
+  void VisitExpr_(const UIntImmNode* op) override;
+  void VisitExpr_(const FloatImmNode* op) override;
+  void VisitExpr_(const StringImmNode* op) override;
 };
 
 /*!
@@ -343,38 +343,38 @@ class TVM_DLL ExprMutator :
  protected:
   using ExprFunctor::VisitExpr;
   // list of functions to override.
-  Expr VisitExpr_(const Variable* op) override;
-  Expr VisitExpr_(const Load* op) override;
-  Expr VisitExpr_(const Let* op) override;
-  Expr VisitExpr_(const Call* op) override;
-  Expr VisitExpr_(const Add* op) override;
-  Expr VisitExpr_(const Sub* op) override;
-  Expr VisitExpr_(const Mul* op) override;
-  Expr VisitExpr_(const Div* op) override;
-  Expr VisitExpr_(const Mod* op) override;
-  Expr VisitExpr_(const FloorDiv* op) override;
-  Expr VisitExpr_(const FloorMod* op) override;
-  Expr VisitExpr_(const Min* op) override;
-  Expr VisitExpr_(const Max* op) override;
-  Expr VisitExpr_(const EQ* op) override;
-  Expr VisitExpr_(const NE* op) override;
-  Expr VisitExpr_(const LT* op) override;
-  Expr VisitExpr_(const LE* op) override;
-  Expr VisitExpr_(const GT* op) override;
-  Expr VisitExpr_(const GE* op) override;
-  Expr VisitExpr_(const And* op) override;
-  Expr VisitExpr_(const Or* op) override;
-  Expr VisitExpr_(const Reduce* op) override;
-  Expr VisitExpr_(const Cast* op) override;
-  Expr VisitExpr_(const Not* op) override;
-  Expr VisitExpr_(const Select* op) override;
-  Expr VisitExpr_(const Ramp* op) override;
-  Expr VisitExpr_(const Broadcast* op) override;
-  Expr VisitExpr_(const Shuffle* op) override;
-  Expr VisitExpr_(const IntImm* op) override;
-  Expr VisitExpr_(const UIntImm* op) override;
-  Expr VisitExpr_(const FloatImm* op) override;
-  Expr VisitExpr_(const StringImm* op) override;
+  Expr VisitExpr_(const VarNode* op) override;
+  Expr VisitExpr_(const LoadNode* op) override;
+  Expr VisitExpr_(const LetNode* op) override;
+  Expr VisitExpr_(const CallNode* op) override;
+  Expr VisitExpr_(const AddNode* op) override;
+  Expr VisitExpr_(const SubNode* op) override;
+  Expr VisitExpr_(const MulNode* op) override;
+  Expr VisitExpr_(const DivNode* op) override;
+  Expr VisitExpr_(const ModNode* op) override;
+  Expr VisitExpr_(const FloorDivNode* op) override;
+  Expr VisitExpr_(const FloorModNode* op) override;
+  Expr VisitExpr_(const MinNode* op) override;
+  Expr VisitExpr_(const MaxNode* op) override;
+  Expr VisitExpr_(const EQNode* op) override;
+  Expr VisitExpr_(const NENode* op) override;
+  Expr VisitExpr_(const LTNode* op) override;
+  Expr VisitExpr_(const LENode* op) override;
+  Expr VisitExpr_(const GTNode* op) override;
+  Expr VisitExpr_(const GENode* op) override;
+  Expr VisitExpr_(const AndNode* op) override;
+  Expr VisitExpr_(const OrNode* op) override;
+  Expr VisitExpr_(const ReduceNode* op) override;
+  Expr VisitExpr_(const CastNode* op) override;
+  Expr VisitExpr_(const NotNode* op) override;
+  Expr VisitExpr_(const SelectNode* op) override;
+  Expr VisitExpr_(const RampNode* op) override;
+  Expr VisitExpr_(const BroadcastNode* op) override;
+  Expr VisitExpr_(const ShuffleNode* op) override;
+  Expr VisitExpr_(const IntImmNode* op) override;
+  Expr VisitExpr_(const UIntImmNode* op) override;
+  Expr VisitExpr_(const FloatImmNode* op) override;
+  Expr VisitExpr_(const StringImmNode* op) override;
 };
 
 /*!
@@ -396,20 +396,20 @@ class TVM_DLL StmtVisitor :
    */
   virtual void VisitExpr(const Expr& e) {}
   // statement visitor
-  void VisitStmt_(const AttrStmt* op) override;
-  void VisitStmt_(const IfThenElse* op) override;
-  void VisitStmt_(const LetStmt* op) override;
-  void VisitStmt_(const For* op) override;
-  void VisitStmt_(const Allocate* op) override;
-  void VisitStmt_(const Store* op) override;
-  void VisitStmt_(const Free* op) override;
-  void VisitStmt_(const AssertStmt* op) override;
-  void VisitStmt_(const ProducerConsumer* op) override;
-  void VisitStmt_(const Provide* op) override;
-  void VisitStmt_(const Realize* op) override;
-  void VisitStmt_(const Prefetch* op) override;
+  void VisitStmt_(const AttrStmtNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
+  void VisitStmt_(const LetStmtNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const FreeNode* op) override;
+  void VisitStmt_(const AssertStmtNode* op) override;
+  void VisitStmt_(const ProducerConsumerNode* op) override;
+  void VisitStmt_(const ProvideNode* op) override;
+  void VisitStmt_(const RealizeNode* op) override;
+  void VisitStmt_(const PrefetchNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const Evaluate* op) override;
+  void VisitStmt_(const EvaluateNode* op) override;
 };
 
 /*!
@@ -490,20 +490,20 @@ class TVM_DLL StmtMutator :
     return e;
   }
   // statement visitor
-  Stmt VisitStmt_(const AttrStmt* op) override;
-  Stmt VisitStmt_(const IfThenElse* op) override;
-  Stmt VisitStmt_(const LetStmt* op) override;
-  Stmt VisitStmt_(const For* op) override;
-  Stmt VisitStmt_(const Allocate* op) override;
-  Stmt VisitStmt_(const Store* op) override;
-  Stmt VisitStmt_(const Free* op) override;
-  Stmt VisitStmt_(const AssertStmt* op) override;
-  Stmt VisitStmt_(const ProducerConsumer* op) override;
-  Stmt VisitStmt_(const Provide* op) override;
-  Stmt VisitStmt_(const Realize* op) override;
-  Stmt VisitStmt_(const Prefetch* op) override;
+  Stmt VisitStmt_(const AttrStmtNode* op) override;
+  Stmt VisitStmt_(const IfThenElseNode* op) override;
+  Stmt VisitStmt_(const LetStmtNode* op) override;
+  Stmt VisitStmt_(const ForNode* op) override;
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  Stmt VisitStmt_(const StoreNode* op) override;
+  Stmt VisitStmt_(const FreeNode* op) override;
+  Stmt VisitStmt_(const AssertStmtNode* op) override;
+  Stmt VisitStmt_(const ProducerConsumerNode* op) override;
+  Stmt VisitStmt_(const ProvideNode* op) override;
+  Stmt VisitStmt_(const RealizeNode* op) override;
+  Stmt VisitStmt_(const PrefetchNode* op) override;
   Stmt VisitStmt_(const SeqStmtNode* op) override;
-  Stmt VisitStmt_(const Evaluate* op) override;
+  Stmt VisitStmt_(const EvaluateNode* op) override;
   /*!
    * \brief Alternative advance method for SeqStmtNode.
    *
index 5a81d59..aa1415e 100644 (file)
@@ -132,7 +132,7 @@ bool ExprUseVar(const Expr& e, const Var& v);
  * \param vset The variable set.
  * \return Whether e uses vset.
  */
-bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
+bool ExprUseVar(const Expr& e, const std::unordered_set<const VarNode*>& vset);
 
 /*!
  * \brief Convert a IR node to be SSA form.
@@ -148,7 +148,7 @@ TVM_DLL Stmt ConvertSSA(Stmt stmt);
  * \return The converted form.
  */
 Stmt Substitute(Stmt stmt,
-                const std::unordered_map<const Variable*, Expr>& value_map);
+                const std::unordered_map<const VarNode*, Expr>& value_map);
 
 /*!
  * \brief Substitute the var specified in key->var to be value.
@@ -157,7 +157,7 @@ Stmt Substitute(Stmt stmt,
  * \return The converted expression.
  */
 Expr Substitute(Expr expr,
-                const std::unordered_map<const Variable*, Expr>& value_map);
+                const std::unordered_map<const VarNode*, Expr>& value_map);
 
 /*!
  * \brief Substitute the var specified in key->var to be value.
index 681d068..ad8f825 100644 (file)
@@ -109,7 +109,7 @@ class OperationNode : public ir::FunctionBaseNode {
   virtual void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
   /*!
    * \brief Gather the bound from output tensor.
@@ -173,7 +173,7 @@ class PlaceholderOpNode : public OperationNode {
   void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   void GatherBound(
       const Operation& self,
@@ -251,7 +251,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
   void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   Stmt BuildProvide(
       const Stage& stage,
@@ -304,7 +304,7 @@ class TensorComputeOpNode : public BaseComputeOpNode {
   void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   Stmt BuildProvide(
       const Stage& stage,
@@ -379,7 +379,7 @@ class ScanOpNode : public OperationNode {
   void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   void GatherBound(
       const Operation& self,
@@ -446,7 +446,7 @@ class ExternOpNode : public OperationNode {
   void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   void GatherBound(
       const Operation& self,
@@ -514,7 +514,7 @@ class HybridOpNode : public OperationNode {
   void PropBoundToInputs(
       const Operation& self,
       arith::Analyzer* analyzer,
-      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      const std::unordered_map<const VarNode*, IntSet>& dom_map,
       std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
   void GatherBound(
       const Operation& self,
index c8a02a8..31e85f9 100644 (file)
@@ -39,7 +39,7 @@
 namespace tvm {
 namespace relay {
 
-using Any = tvm::ir::Any;
+using Any = tvm::ir::AnyNode;
 using Kind = TypeKind;
 using Type = tvm::Type;
 using TypeNode = tvm::TypeNode;
index 034405f..ba04239 100644 (file)
@@ -33,7 +33,7 @@ namespace ir {
 
 TVM_REGISTER_GLOBAL("_Var")
 .set_body_typed([](std::string s, DataType t) {
-    return Variable::make(t, s);
+    return VarNode::make(t, s);
   });
 
 TVM_REGISTER_GLOBAL("make.abs")
@@ -73,7 +73,7 @@ TVM_REGISTER_GLOBAL("make.For")
 .set_body_typed([](
   VarExpr loop_var, Expr min, Expr extent,
   int for_type, int device_api, Stmt body) {
-  return For::make(loop_var,
+  return ForNode::make(loop_var,
                    min,
                    extent,
                    static_cast<ForType>(for_type),
@@ -85,9 +85,9 @@ TVM_REGISTER_GLOBAL("make.Load")
 .set_body([](TVMArgs args,  TVMRetValue *ret) {
     DataType t = args[0];
     if (args.size() == 3) {
-      *ret = Load::make(t, args[1], args[2], const_true(t.lanes()));
+      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
     } else {
-      *ret = Load::make(t, args[1], args[2], args[3]);
+      *ret = LoadNode::make(t, args[1], args[2], args[3]);
     }
   });
 
@@ -95,14 +95,14 @@ TVM_REGISTER_GLOBAL("make.Store")
 .set_body([](TVMArgs args,  TVMRetValue *ret) {
     Expr value = args[1];
     if (args.size() == 3) {
-      *ret = Store::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
     } else {
-      *ret = Store::make(args[0], value, args[2], args[3]);
+      *ret = StoreNode::make(args[0], value, args[2], args[3]);
     }
   });
 
 TVM_REGISTER_GLOBAL("make.Realize")
-.set_body_typed(Realize::make);
+.set_body_typed(RealizeNode::make);
 
 TVM_REGISTER_GLOBAL("make.Call")
 .set_body_typed([](
@@ -110,10 +110,10 @@ TVM_REGISTER_GLOBAL("make.Call")
   Array<Expr> args, int call_type,
   FunctionRef func, int value_index
 ) {
-  return Call::make(type,
+  return CallNode::make(type,
                     name,
                     args,
-                    static_cast<Call::CallType>(call_type),
+                    static_cast<CallNode::CallType>(call_type),
                     func,
                     value_index);
 });
@@ -122,9 +122,10 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
 .set_body_typed(CommReducerNode::make);
 
 // make from two arguments
-#define REGISTER_MAKE(Node)                                     \
-  TVM_REGISTER_GLOBAL("make."#Node)                             \
-  .set_body_typed(Node::make);                                  \
+#define REGISTER_MAKE(NodeName)                                     \
+  TVM_REGISTER_GLOBAL("make."#NodeName)                             \
+  .set_body_typed(NodeName ## Node::make);                          \
+
 
 REGISTER_MAKE(Reduce);
 REGISTER_MAKE(AttrStmt);
@@ -174,7 +175,7 @@ TVM_REGISTER_GLOBAL("make.Allocate")
   .set_body_typed([](
     VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
   ){
-    return Allocate::make(buffer_var, type, extents, condition, body);
+    return AllocateNode::make(buffer_var, type, extents, condition, body);
   });
 
 // operator overloading, smarter than make
index 804d8f1..4e635ad 100644 (file)
@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("_const")
   });
 
 TVM_REGISTER_GLOBAL("_str")
-.set_body_typed(ir::StringImm::make);
+.set_body_typed(ir::StringImmNode::make);
 
 
 TVM_REGISTER_GLOBAL("_Array")
@@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("_MapItems")
       auto* n = static_cast<const StrMapNode*>(ptr);
       auto rkvs = make_object<ArrayNode>();
       for (const auto& kv : n->data) {
-        rkvs->data.push_back(ir::StringImm::make(kv.first));
+        rkvs->data.push_back(ir::StringImmNode::make(kv.first));
         rkvs->data.push_back(kv.second);
       }
       *ret = Array<ObjectRef>(rkvs);
index 404f88d..68e0b05 100644 (file)
@@ -78,7 +78,7 @@ void ConstraintContext::ExitWithScope() {
 }
 
 bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
-  if (const auto* ptr = expr.as<ir::IntImm>()) {
+  if (const auto* ptr = expr.as<ir::IntImmNode>()) {
     return ptr->value >= lower_bound;
   }
   auto bd = this->const_int_bound(this->rewrite_simplify(expr));
@@ -87,15 +87,15 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
 }
 
 bool Analyzer::CanProve(const Expr& expr) {
-  if (const auto* ptr = expr.as<ir::UIntImm>()) {
+  if (const auto* ptr = expr.as<ir::UIntImmNode>()) {
     return ptr->value != 0;
   }
   auto res = this->rewrite_simplify(expr);
-  if (const auto* ptr = res.as<ir::UIntImm>()) {
+  if (const auto* ptr = res.as<ir::UIntImmNode>()) {
     return ptr->value != 0;
   }
   res = this->canonical_simplify(expr);
-  if (const auto* ptr = res.as<ir::UIntImm>()) {
+  if (const auto* ptr = res.as<ir::UIntImmNode>()) {
     return ptr->value != 0;
   }
   return false;
index bb2e340..40f86de 100644 (file)
@@ -78,8 +78,8 @@ class BoundDeducer: public ExprVisitor {
   friend class BoundDeduceInputChecker;
   friend class Converter;
   BoundDeducer(Expr target, Expr expr,
-               const std::unordered_map<const Variable*, IntSet>& hint_map,
-               const std::unordered_map<const Variable*, IntSet>& relax_map)
+               const std::unordered_map<const VarNode*, IntSet>& hint_map,
+               const std::unordered_map<const VarNode*, IntSet>& relax_map)
   : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
 
   void Deduce();
@@ -94,29 +94,29 @@ class BoundDeducer: public ExprVisitor {
     }
   }
 
-  void VisitExpr_(const LT* op) final {
+  void VisitExpr_(const LTNode* op) final {
     LOG(FATAL) << "unable to deduce due to multiple comparison operator";
   }
 
-  void VisitExpr_(const LE* op) final {
+  void VisitExpr_(const LENode* op) final {
     LOG(FATAL) << "unable to deduce due to multiple comparison operator";
   }
 
-  void VisitExpr_(const GT* op) final {
+  void VisitExpr_(const GTNode* op) final {
     LOG(FATAL) << "unable to deduce due to multiple comparison operator";
   }
 
-  void VisitExpr_(const GE* op) final {
+  void VisitExpr_(const GENode* op) final {
     LOG(FATAL) << "unable to deduce due to multiple comparison operator";
   }
 
-  void VisitExpr_(const Add* op) final {
+  void VisitExpr_(const AddNode* op) final {
     bool left = op->a.get() == path_[iter_];
     result_ -= left ? op->b : op->a;
     this->VisitExpr(left ? op->a : op->b);
   }
 
-  void VisitExpr_(const Sub* op) final {
+  void VisitExpr_(const SubNode* op) final {
     bool left = op->a.get() == path_[iter_];
     if (left) {
       result_ += op->b;
@@ -128,7 +128,7 @@ class BoundDeducer: public ExprVisitor {
     this->VisitExpr(left ? op->a : op->b);
   }
 
-  void VisitExpr_(const Mul* op) final {
+  void VisitExpr_(const MulNode* op) final {
     bool left = op->a.get() == path_[iter_];
     Expr operand = left ? op->b : op->a;
     Expr target_var = left ? op->a : op->b;
@@ -187,8 +187,8 @@ class BoundDeducer: public ExprVisitor {
   CompareOp ReverseOp(CompareOp comp_op);
   Expr target_;
   Expr expr_;
-  const std::unordered_map<const Variable*, IntSet>& hint_map_;
-  const std::unordered_map<const Variable*, IntSet>& relax_map_;
+  const std::unordered_map<const VarNode*, IntSet>& hint_map_;
+  const std::unordered_map<const VarNode*, IntSet>& relax_map_;
   ExprIntSetMap expr_map_;
   std::vector<const Object*> path_;
   size_t iter_{0};
@@ -233,7 +233,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
 
 void BoundDeducer::Transform() {
   // We will ensure to set expr_ such that it contains target_
-  if (const LT* op = expr_.as<LT>()) {
+  if (const LTNode* op = expr_.as<LTNode>()) {
     if (GetPath(target_, op->a).empty()) {
       // a < b -> b >= a + 1
       comp_op = kGreater;
@@ -245,7 +245,7 @@ void BoundDeducer::Transform() {
       expr_ = op->a;
       result_ = op->b - 1;
     }
-  } else if (const LE* op = expr_.as<LE>()) {
+  } else if (const LENode* op = expr_.as<LENode>()) {
     if (GetPath(target_, op->a).empty()) {
       // a <= b -> b >= a
       comp_op = kGreater;
@@ -256,7 +256,7 @@ void BoundDeducer::Transform() {
       expr_ = op->a;
       result_ = op->b;
     }
-  } else if (const GT* op = expr_.as<GT>()) {
+  } else if (const GTNode* op = expr_.as<GTNode>()) {
     if (GetPath(target_, op->a).empty()) {
       // a > b -> b <= a - 1
       comp_op = kLess;
@@ -268,7 +268,7 @@ void BoundDeducer::Transform() {
       expr_ = op->a;
       result_ = op->b + 1;
     }
-  } else if (const GE* op = expr_.as<GE>()) {
+  } else if (const GENode* op = expr_.as<GENode>()) {
     if (GetPath(target_, op->a).empty()) {
       // a >= b -> b <= a
       comp_op = kLess;
@@ -279,7 +279,7 @@ void BoundDeducer::Transform() {
       expr_ = op->a;
       result_ = op->b;
     }
-  } else if (const EQ* op = expr_.as<EQ>()) {
+  } else if (const EQNode* op = expr_.as<EQNode>()) {
     comp_op = kEqual;
     if (GetPath(target_, op->a).empty()) {
       // if the b == a -> a == b
@@ -330,8 +330,8 @@ void BoundDeducer::Relax() {
 }
 
 IntSet DeduceBound(Expr v, Expr e,
-  const std::unordered_map<const Variable*, IntSet>& hint_map,
-  const std::unordered_map<const Variable*, IntSet>& relax_map) {
+  const std::unordered_map<const VarNode*, IntSet>& hint_map,
+  const std::unordered_map<const VarNode*, IntSet>& relax_map) {
   BoundDeducer d(v, e, hint_map, relax_map);
   d.Deduce();
   if (!d.success_) return IntSet::nothing();
@@ -352,11 +352,11 @@ IntSet DeduceBound(Expr v, Expr e,
 IntSet DeduceBound(Expr v, Expr e,
                    const Map<Var, IntSet>& hint_map,
                    const Map<Var, IntSet>& relax_map) {
-  std::unordered_map<const Variable*, IntSet> hmap;
+  std::unordered_map<const VarNode*, IntSet> hmap;
   for (auto kv : hint_map) {
     hmap[kv.first.get()] = kv.second;
   }
-  std::unordered_map<const Variable*, IntSet> rmap;
+  std::unordered_map<const VarNode*, IntSet> rmap;
   for (auto kv : relax_map) {
     rmap[kv.first.get()] = kv.second;
   }
index d05ee2d..e33b0c5 100644 (file)
@@ -450,14 +450,14 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
   }
 
   using Rewriter::VisitExpr_;
-  Expr VisitExpr_(const Add* op) final;
-  Expr VisitExpr_(const Sub* op) final;
-  Expr VisitExpr_(const Mul* op) final;
-  Expr VisitExpr_(const Div* op) final;
-  Expr VisitExpr_(const Mod* op) final;
-  Expr VisitExpr_(const FloorDiv* op) final;
-  Expr VisitExpr_(const FloorMod* op) final;
-  Expr VisitExpr_(const Reduce* op) final;
+  Expr VisitExpr_(const AddNode* op) final;
+  Expr VisitExpr_(const SubNode* op) final;
+  Expr VisitExpr_(const MulNode* op) final;
+  Expr VisitExpr_(const DivNode* op) final;
+  Expr VisitExpr_(const ModNode* op) final;
+  Expr VisitExpr_(const FloorDivNode* op) final;
+  Expr VisitExpr_(const FloorModNode* op) final;
+  Expr VisitExpr_(const ReduceNode* op) final;
 
  private:
   /*!
@@ -553,7 +553,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
     }
     ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
     n->dtype = expr.dtype();
-    if (const auto* op = expr.as<IntImm>()) {
+    if (const auto* op = expr.as<IntImmNode>()) {
       n->base = op->value;
       return SumExpr(n);
     } else {
@@ -562,11 +562,11 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
     }
   }
   // Simplify the combiner used in reduce.
-  Expr SimplifyReduceCombiner(const Reduce* op);
+  Expr SimplifyReduceCombiner(const ReduceNode* op);
 };
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Add* op) {
+VisitExpr_(const AddNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -575,13 +575,13 @@ VisitExpr_(const Add* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<Add>(a, b);
+  Expr const_res = TryConstFold<AddNode>(a, b);
   if (const_res.defined()) return const_res;
 
   // canonical form simplification.
   SumExpr ret = ToSumExpr(std::move(a));
 
-  if (const auto* op = b.as<IntImm>()) {
+  if (const auto* op = b.as<IntImmNode>()) {
     ret.CopyOnWrite()->AddToSelf(op->value);
   } else if (const auto* op = b.as<SumExprNode>()) {
     ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
@@ -592,7 +592,7 @@ VisitExpr_(const Add* op) {
 }
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Sub* op) {
+VisitExpr_(const SubNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -601,13 +601,13 @@ VisitExpr_(const Sub* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<Sub>(a, b);
+  Expr const_res = TryConstFold<SubNode>(a, b);
   if (const_res.defined()) return const_res;
 
   // canonical form simplification.
   SumExpr ret = ToSumExpr(std::move(a));
 
-  if (const auto* op = b.as<IntImm>()) {
+  if (const auto* op = b.as<IntImmNode>()) {
     ret.CopyOnWrite()->AddToSelf(-op->value);
   } else if (const auto* op = b.as<SumExprNode>()) {
     ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
@@ -619,7 +619,7 @@ VisitExpr_(const Sub* op) {
 
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Mul* op) {
+VisitExpr_(const MulNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -628,14 +628,14 @@ VisitExpr_(const Mul* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<Mul>(a, b);
+  Expr const_res = TryConstFold<MulNode>(a, b);
   if (const_res.defined()) return const_res;
 
   // x * c
-  if (a.as<IntImm>()) {
+  if (a.as<IntImmNode>()) {
     std::swap(a, b);
   }
-  if (const auto* bconst = b.as<IntImm>()) {
+  if (const auto* bconst = b.as<IntImmNode>()) {
     if (a.as<SumExprNode>()) {
       SumExpr ret = Downcast<SumExpr>(std::move(a));
       ret.CopyOnWrite()->MulToSelf(bconst->value);
@@ -653,7 +653,7 @@ VisitExpr_(const Mul* op) {
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<Expr>(op);
   } else {
-    return Mul::make(a, b);
+    return MulNode::make(a, b);
   }
 }
 
@@ -726,7 +726,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
 }
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Div* op) {
+VisitExpr_(const DivNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -735,7 +735,7 @@ VisitExpr_(const Div* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<Div>(a, b);
+  Expr const_res = TryConstFold<DivNode>(a, b);
   if (const_res.defined()) return const_res;
   PVar<Integer> c1;
   // x / c1
@@ -756,7 +756,7 @@ VisitExpr_(const Div* op) {
           analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
         lhs.CopyOnWrite()->DivideBy(cval);
         Expr temp = Normalize(extra);
-        if (const auto* pconst = temp.as<IntImm>()) {
+        if (const auto* pconst = temp.as<IntImmNode>()) {
           lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
         } else {
           // if 0 <= extra < cval, it means the extra can be eliminated.
@@ -782,12 +782,12 @@ VisitExpr_(const Div* op) {
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<Expr>(op);
   } else {
-    return Div::make(a, b);
+    return DivNode::make(a, b);
   }
 }
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const FloorDiv* op) {
+VisitExpr_(const FloorDivNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -795,7 +795,7 @@ VisitExpr_(const FloorDiv* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<FloorDiv>(a, b);
+  Expr const_res = TryConstFold<FloorDivNode>(a, b);
   if (const_res.defined()) return const_res;
   PVar<Integer> c1;
   // x / c1
@@ -813,7 +813,7 @@ VisitExpr_(const FloorDiv* op) {
       // continue simplification.
       lhs.CopyOnWrite()->DivideBy(cval);
       Expr temp = Normalize(extra);
-      if (const auto* pconst = temp.as<IntImm>()) {
+      if (const auto* pconst = temp.as<IntImmNode>()) {
         lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
       } else {
         // if 0 <= extra < cval, it means the extra can be eliminated.
@@ -838,7 +838,7 @@ VisitExpr_(const FloorDiv* op) {
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<Expr>(op);
   } else {
-    return FloorDiv::make(a, b);
+    return FloorDivNode::make(a, b);
   }
 }
 
@@ -893,7 +893,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
 }
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Mod* op) {
+VisitExpr_(const ModNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -902,7 +902,7 @@ VisitExpr_(const Mod* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<Mod>(a, b);
+  Expr const_res = TryConstFold<ModNode>(a, b);
   if (const_res.defined()) return const_res;
 
   PVar<Integer> c1;
@@ -919,7 +919,7 @@ VisitExpr_(const Mod* op) {
       if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
           analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
         Expr temp = Normalize(extra);
-        if (temp.as<IntImm>()) {
+        if (temp.as<IntImmNode>()) {
           return truncmod(temp, c1.Eval());
         } else {
           // If temp < cval && temp >=0 then can remove the mod.
@@ -958,12 +958,12 @@ VisitExpr_(const Mod* op) {
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<Expr>(op);
   } else {
-    return Mod::make(a, b);
+    return ModNode::make(a, b);
   }
 }
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const FloorMod* op) {
+VisitExpr_(const FloorModNode* op) {
   if (!IsIndexType(op->dtype)) {
     return Rewriter::VisitExpr_(op);
   }
@@ -972,7 +972,7 @@ VisitExpr_(const FloorMod* op) {
   Expr b = this->CanonicalMutate(op->b);
 
   // const folding
-  Expr const_res = TryConstFold<FloorMod>(a, b);
+  Expr const_res = TryConstFold<FloorModNode>(a, b);
   if (const_res.defined()) return const_res;
 
   PVar<Integer> c1;
@@ -983,7 +983,7 @@ VisitExpr_(const FloorMod* op) {
       SumExpr lhs, extra;
       SeparateDivisibleParts(psum, cval, &lhs, &extra);
       Expr temp = Normalize(extra);
-      if (temp.as<IntImm>()) {
+      if (temp.as<IntImmNode>()) {
         return floormod(temp, c1.Eval());
       } else {
         // If temp < cval && temp >=0 then can remove the mod.
@@ -1018,13 +1018,13 @@ VisitExpr_(const FloorMod* op) {
   if (op->a.same_as(a) && op->b.same_as(b)) {
     return GetRef<Expr>(op);
   } else {
-    return FloorMod::make(a, b);
+    return FloorModNode::make(a, b);
   }
 }
 
 // Simplify reduce expression.
 Expr CanonicalSimplifier::Impl::
-SimplifyReduceCombiner(const Reduce* op) {
+SimplifyReduceCombiner(const ReduceNode* op) {
   // First simplify the results
   Array<Expr> simplified_result;
   for (const auto& res : op->combiner->result) {
@@ -1089,15 +1089,15 @@ SimplifyReduceCombiner(const Reduce* op) {
 
   CommReducer new_combiner =
       CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
-  return Reduce::make(
+  return ReduceNode::make(
       new_combiner, new_source, op->axis, op->condition, new_value_index);
 }
 
 Expr CanonicalSimplifier::Impl::
-VisitExpr_(const Reduce* op) {
+VisitExpr_(const ReduceNode* op) {
   // Recursively call simplification when necessary.
   Expr ret = RewriteSimplifier::Impl::VisitExpr_(op);
-  op = ret.as<Reduce>();
+  op = ret.as<ReduceNode>();
   // already been simplified by const reduction axis removal
   if (op == nullptr) return ret;
   if (op->axis.empty()) {
@@ -1106,7 +1106,7 @@ VisitExpr_(const Reduce* op) {
     // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
     // instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
     return this->VisitExpr(
-        Select::make(op->condition,
+        SelectNode::make(op->condition,
                      op->source[op->value_index],
                      op->combiner->identity_element[op->value_index]));
   }
index 806587a..aca26e8 100644 (file)
@@ -77,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
 }
 
 template<>
-inline Expr Compute<ir::Add>(Expr a, Expr b) {
+inline Expr Compute<ir::AddNode>(Expr a, Expr b) {
   return a + b;
 }
 
 template<>
-inline Expr Compute<ir::Sub>(Expr a, Expr b) {
+inline Expr Compute<ir::SubNode>(Expr a, Expr b) {
   return a - b;
 }
 
 template<>
-inline Expr Compute<ir::Mul>(Expr a, Expr b) {
+inline Expr Compute<ir::MulNode>(Expr a, Expr b) {
   return a * b;
 }
 
 template<>
-inline Expr Compute<ir::Div>(Expr a, Expr b) {
+inline Expr Compute<ir::DivNode>(Expr a, Expr b) {
   return truncdiv(a, b);
 }
 
 template<>
-inline Expr Compute<ir::Mod>(Expr a, Expr b) {
+inline Expr Compute<ir::ModNode>(Expr a, Expr b) {
   return truncmod(a, b);
 }
 
 template<>
-inline Expr Compute<ir::Max>(Expr a, Expr b) {
+inline Expr Compute<ir::MaxNode>(Expr a, Expr b) {
   return max(a, b);
 }
 
 template<>
-inline Expr Compute<ir::Min>(Expr a, Expr b) {
+inline Expr Compute<ir::MinNode>(Expr a, Expr b) {
   return min(a, b);
 }
 
index 8b4ea2f..db98a7e 100644 (file)
@@ -76,21 +76,21 @@ inline bool IsIndexType(const DataType& type) {
 
 
 #define TVM_ARITH_CONST_PROPAGATION(BODY)                               \
-  using ir::IntImm;                                                     \
-  using ir::UIntImm;                                                    \
-  using ir::FloatImm;                                                   \
-  const IntImm* pa = a.as<IntImm>();                                    \
-  const IntImm* pb = b.as<IntImm>();                                    \
-  const FloatImm* fa = a.as<FloatImm>();                                \
-  const FloatImm* fb = b.as<FloatImm>();                                \
+  using ir::IntImmNode;                                                 \
+  using ir::UIntImmNode;                                                \
+  using ir::FloatImmNode;                                               \
+  const IntImmNode* pa = a.as<IntImmNode>();                            \
+  const IntImmNode* pb = b.as<IntImmNode>();                            \
+  const FloatImmNode* fa = a.as<FloatImmNode>();                        \
+  const FloatImmNode* fb = b.as<FloatImmNode>();                        \
   BODY;
 
 
 #define TVM_INDEX_CONST_PROPAGATION(BODY)                               \
-  using ir::IntImm;                                                     \
-  using ir::UIntImm;                                                    \
-  const IntImm* pa = a.as<IntImm>();                                    \
-  const IntImm* pb = b.as<IntImm>();                                    \
+  using ir::IntImmNode;                                                 \
+  using ir::UIntImmNode;                                                \
+  const IntImmNode* pa = a.as<IntImmNode>();                            \
+  const IntImmNode* pb = b.as<IntImmNode>();                            \
   const DataType& ta = a.dtype();                                       \
   const DataType& tb = b.dtype();                                       \
   if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) {               \
@@ -100,13 +100,13 @@ inline bool IsIndexType(const DataType& type) {
 
 // specialization of constant folders.
 template<>
-inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::AddNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
+      if (pa && pb) return IntImmNode::make(rtype, pa->value + pb->value);
       if (pa && pa->value == 0) return b;
       if (pb && pb->value == 0) return a;
-      if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
+      if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value);
       if (fa && fa->value == 0) return b;
       if (fb && fb->value == 0) return a;
     });
@@ -114,22 +114,22 @@ inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::SubNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
+      if (pa && pb) return IntImmNode::make(rtype, pa->value - pb->value);
       if (pb && pb->value == 0) return a;
-      if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
+      if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
       if (fb && fb->value == 0) return a;
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::MulNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
+      if (pa && pb) return IntImmNode::make(rtype, pa->value * pb->value);
       if (pa) {
         if (pa->value == 1) return b;
         if (pa->value == 0) return a;
@@ -138,7 +138,7 @@ inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
         if (pb->value == 1) return a;
         if (pb->value == 0) return b;
       }
-      if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
+      if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value);
       if (fa) {
         if (fa->value == 1) return b;
         if (fa->value == 0) return a;
@@ -152,14 +152,14 @@ inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::DivNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
         // due to division and mod can have different modes
         // NOTE: this will assumes truc div.
         CHECK_NE(pb->value, 0) << "Divide by zero";
-        return IntImm::make(rtype, pa->value / pb->value);
+        return IntImmNode::make(rtype, pa->value / pb->value);
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -169,7 +169,7 @@ inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
       }
       if (fa && fb && fb->value != 0) {
-        return FloatImm::make(rtype, fa->value / fb->value);
+        return FloatImmNode::make(rtype, fa->value / fb->value);
       }
       if (fa && fa->value == 0) return a;
       if (fb) {
@@ -181,11 +181,11 @@ inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::ModNode>(Expr a, Expr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
-        return IntImm::make(rtype, pa->value % pb->value);
+        return IntImmNode::make(rtype, pa->value % pb->value);
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -199,12 +199,12 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::FloorDivNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
-        return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
+        return IntImmNode::make(rtype, arith::floordiv(pa->value, pb->value));
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -214,7 +214,7 @@ inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
         CHECK_NE(pb->value, 0) << "Divide by zero";
       }
       if (fa && fb && fb->value != 0) {
-        return FloatImm::make(rtype, std::floor(fa->value / fb->value));
+        return FloatImmNode::make(rtype, std::floor(fa->value / fb->value));
       }
       if (fa && fa->value == 0) return a;
       if (fb) {
@@ -226,11 +226,11 @@ inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::FloorModNode>(Expr a, Expr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
-        return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
+        return IntImmNode::make(rtype, arith::floormod(pa->value, pb->value));
       }
       if (pa) {
         if (pa->value == 0) return a;
@@ -244,86 +244,86 @@ inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::MinNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
-      if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
+      if (pa && pb) return IntImmNode::make(rtype, std::min(pa->value, pb->value));
+      if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::MaxNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
-      if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
+      if (pa && pb) return IntImmNode::make(rtype, std::max(pa->value, pb->value));
+      if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
     });
   if (a.same_as(b)) return a;
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::GTNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value > pb->value);
-      if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value > fb->value);
+      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value > pb->value);
+      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value > fb->value);
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::GENode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value >= pb->value);
-      if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value >= fb->value);
+      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value >= pb->value);
+      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value >= fb->value);
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::LTNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value < pb->value);
-      if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value < fb->value);
+      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value < pb->value);
+      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value < fb->value);
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::LENode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value <= pb->value);
-      if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value <= fb->value);
+      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value <= pb->value);
+      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value <= fb->value);
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::EQNode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value == pb->value);
-      if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value == fb->value);
+      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value == pb->value);
+      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value == fb->value);
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::NE>(Expr a, Expr b) {
+inline Expr TryConstFold<ir::NENode>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
-      if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value != pb->value);
-      if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value != fb->value);
+      if (pa && pb) return UIntImmNode::make(DataType::UInt(1), pa->value != pb->value);
+      if (fa && fb) return UIntImmNode::make(DataType::UInt(1), fa->value != fb->value);
     });
   return Expr();
 }
 
 template<>
-inline Expr TryConstFold<ir::And>(Expr a, Expr b) {
-  using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
-  const UIntImm* pb = b.as<UIntImm>();
+inline Expr TryConstFold<ir::AndNode>(Expr a, Expr b) {
+  using ir::UIntImmNode;
+  const UIntImmNode* pa = a.as<UIntImmNode>();
+  const UIntImmNode* pb = b.as<UIntImmNode>();
   if (pa && pa->value) return b;
   if (pa && !pa->value) return a;
   if (pb && pb->value) return a;
@@ -332,10 +332,10 @@ inline Expr TryConstFold<ir::And>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::Or>(Expr a, Expr b) {
-  using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
-  const UIntImm* pb = b.as<UIntImm>();
+inline Expr TryConstFold<ir::OrNode>(Expr a, Expr b) {
+  using ir::UIntImmNode;
+  const UIntImmNode* pa = a.as<UIntImmNode>();
+  const UIntImmNode* pb = b.as<UIntImmNode>();
   if (pa && pa->value) return a;
   if (pa && !pa->value) return b;
   if (pb && pb->value) return b;
@@ -344,11 +344,11 @@ inline Expr TryConstFold<ir::Or>(Expr a, Expr b) {
 }
 
 template<>
-inline Expr TryConstFold<ir::Not>(Expr a) {
-  using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
+inline Expr TryConstFold<ir::NotNode>(Expr a) {
+  using ir::UIntImmNode;
+  const UIntImmNode* pa = a.as<UIntImmNode>();
   if (pa) {
-    return UIntImm::make(DataType::UInt(1), !(pa->value));
+    return UIntImmNode::make(DataType::UInt(1), !(pa->value));
   }
   return Expr();
 }
index ef405d8..d3f885a 100644 (file)
@@ -140,17 +140,17 @@ class ConstIntBoundAnalyzer::Impl :
     return res;
   }
 
-  Entry VisitExpr_(const Cast* op) final {
+  Entry VisitExpr_(const CastNode* op) final {
     Entry a = VisitExpr(op->value);
     Entry b = Everything(op->dtype);
     return Intersect(a, b);
   }
 
-  Entry VisitExpr_(const IntImm* op) final {
+  Entry VisitExpr_(const IntImmNode* op) final {
     return MakeBound(op->value, op->value);
   }
 
-  Entry VisitExpr_(const UIntImm* op) final {
+  Entry VisitExpr_(const UIntImmNode* op) final {
     if (op->value <= static_cast<uint64_t>(kPosInf)) {
       return MakeBound(op->value, op->value);
     } else {
@@ -158,7 +158,7 @@ class ConstIntBoundAnalyzer::Impl :
     }
   }
 
-  Entry VisitExpr_(const Add* op) final {
+  Entry VisitExpr_(const AddNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     Entry ret;
@@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl :
     return ret;
   }
 
-  Entry VisitExpr_(const Sub* op) final {
+  Entry VisitExpr_(const SubNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     Entry ret;
@@ -176,13 +176,13 @@ class ConstIntBoundAnalyzer::Impl :
     return ret;
   }
 
-  Entry VisitExpr_(const Mul* op) final {
+  Entry VisitExpr_(const MulNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     return BinaryOpBoundry(a, b, InfAwareMul);
   }
 
-  Entry VisitExpr_(const Div* op) final {
+  Entry VisitExpr_(const DivNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     CHECK(!b.is_const(0)) << "divide by zero";
@@ -192,7 +192,7 @@ class ConstIntBoundAnalyzer::Impl :
     return BinaryOpBoundry(a, b, InfAwareDiv);
   }
 
-  Entry VisitExpr_(const Mod* op) final {
+  Entry VisitExpr_(const ModNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     if (b.min_value > 0) {
@@ -215,7 +215,7 @@ class ConstIntBoundAnalyzer::Impl :
     }
   }
 
-  Entry VisitExpr_(const FloorDiv* op) final {
+  Entry VisitExpr_(const FloorDivNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     CHECK(!b.is_const(0)) << "floordiv by zero";
@@ -225,7 +225,7 @@ class ConstIntBoundAnalyzer::Impl :
     return BinaryOpBoundry(a, b, InfAwareFloorDiv);
   }
 
-  Entry VisitExpr_(const FloorMod* op) final {
+  Entry VisitExpr_(const FloorModNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     if (b.min_value > 0) {
@@ -246,7 +246,7 @@ class ConstIntBoundAnalyzer::Impl :
     }
   }
 
-  Entry VisitExpr_(const Min* op) final {
+  Entry VisitExpr_(const MinNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     Entry ret;
@@ -255,7 +255,7 @@ class ConstIntBoundAnalyzer::Impl :
     return ret;
   }
 
-  Entry VisitExpr_(const Max* op) final {
+  Entry VisitExpr_(const MaxNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     Entry ret;
@@ -264,25 +264,25 @@ class ConstIntBoundAnalyzer::Impl :
     return ret;
   }
 
-  Entry VisitExpr_(const Select* op) final {
+  Entry VisitExpr_(const SelectNode* op) final {
     Entry a = VisitExpr(op->true_value);
     Entry b = VisitExpr(op->false_value);
     return Union(a, b);
   }
 
-  Entry VisitExpr_(const Call* op) final {
+  Entry VisitExpr_(const CallNode* op) final {
     // only special handle >> and & which can be
     // used for index calculation.
-    if (op->is_intrinsic(Call::shift_right)) {
+    if (op->is_intrinsic(CallNode::shift_right)) {
       return VisitRightShift(op);
-    } else if (op->is_intrinsic(Call::bitwise_and)) {
+    } else if (op->is_intrinsic(CallNode::bitwise_and)) {
       return VisitBitwiseAnd(op);
     } else {
       return Everything(op->dtype);
     }
   }
 
-  Entry VisitExpr_(const Variable* op) final {
+  Entry VisitExpr_(const VarNode* op) final {
     Var v = GetRef<Var>(op);
     auto it = var_map_.find(v);
     if (it != var_map_.end()) {
@@ -292,13 +292,13 @@ class ConstIntBoundAnalyzer::Impl :
     }
   }
 
-  Entry VisitRightShift(const Call* op) {
+  Entry VisitRightShift(const CallNode* op) {
     Entry a = VisitExpr(op->args[0]);
     Entry b = VisitExpr(op->args[1]);
     return BinaryOpBoundry(a, b, InfAwareRightShift);
   }
 
-  Entry VisitBitwiseAnd(const Call* op) {
+  Entry VisitBitwiseAnd(const CallNode* op) {
     Entry a = VisitExpr(op->args[0]);
     Entry b = VisitExpr(op->args[1]);
     // handle positive index case.
@@ -375,7 +375,7 @@ class ConstIntBoundAnalyzer::Impl :
       return kNegInf;
     }
     if (y == kPosInf || y == kNegInf) return y;
-    if (WillOverflow<Add>(x, y, kNegInf, kPosInf)) {
+    if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) {
       if (x > 0) return kPosInf;
       return kNegInf;
     }
@@ -388,7 +388,7 @@ class ConstIntBoundAnalyzer::Impl :
    * \return the result.
    */
   static int64_t InfAwareMul(int64_t x, int64_t y) {
-    if (!WillOverflow<Mul>(x, y, kNegInf, kPosInf)) return x * y;
+    if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y;
     if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf;
     return kNegInf;
   }
index b8ec974..7785801 100644 (file)
@@ -60,7 +60,7 @@ class LinearEqDetector
     return true;
   }
 
-  LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final {
     if (fail_) return LinearEqEntry();
     LinearEqEntry a = VisitExpr(op->a, op->a);
     LinearEqEntry b = VisitExpr(op->b, op->b);
@@ -70,7 +70,7 @@ class LinearEqDetector
     return ret;
   }
 
-  LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final {
     if (fail_) return LinearEqEntry();
     LinearEqEntry a = VisitExpr(op->a, op->a);
     LinearEqEntry b = VisitExpr(op->b, op->b);
@@ -80,7 +80,7 @@ class LinearEqDetector
     return ret;
   }
 
-  LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final {
     if (fail_) return LinearEqEntry();
     LinearEqEntry a = VisitExpr(op->a, op->a);
     LinearEqEntry b = VisitExpr(op->b, op->b);
@@ -96,7 +96,7 @@ class LinearEqDetector
     ret.coeff = MulCombine(a.base, b.coeff);
     return ret;
   }
-  LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final {
+  LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final {
     LinearEqEntry ret;
     if (op == var_.get()) {
       ret.coeff = make_const(op->dtype, 1);
@@ -152,7 +152,7 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
     base = std::move(ret.base);
   }
 
-  std::unordered_set<const Variable*> vset;
+  std::unordered_set<const VarNode*> vset;
   for (size_t i = vars.size(); i > 1; --i) {
     vset.insert(vars[i - 1].get());
     // The previous coeff contains the variable
@@ -167,11 +167,11 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
 // Detect clip condition as min max value
 bool DetectClipBound(
     const Expr& cond,
-    std::unordered_map<const Variable*, IntervalEntry>* bmap) {
+    std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
   int flag = 0;
   Var var;
   auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
-    if (const Variable* v = n.as<Variable>()) {
+    if (const VarNode* v = n.as<VarNode>()) {
       if (bmap->count(v)) {
         if (flag == 0) {
           var = Downcast<Var>(n);
@@ -188,16 +188,16 @@ bool DetectClipBound(
   if (flag != 1) return false;
   // canonical form: exp >= 0
   Expr canonical;
-  if (const LT* op = cond.as<LT>()) {
+  if (const LTNode* op = cond.as<LTNode>()) {
     if (!op->a.dtype().is_int()) return false;
     canonical = op->b - op->a - make_const(op->a.dtype(), 1);
-  } else if (const LE* op = cond.as<LE>()) {
+  } else if (const LENode* op = cond.as<LENode>()) {
     if (!op->a.dtype().is_int()) return false;
     canonical = op->b - op->a;
-  } else if (const GT* op = cond.as<GT>()) {
+  } else if (const GTNode* op = cond.as<GTNode>()) {
     if (!op->a.dtype().is_int()) return false;
     canonical = op->a - op->b - make_const(op->a.dtype(), 1);
-  } else if (const GE* op = cond.as<GE>()) {
+  } else if (const GENode* op = cond.as<GENode>()) {
     if (!op->a.dtype().is_int()) return false;
     canonical = op->a - op->b;
   } else {
@@ -210,7 +210,7 @@ bool DetectClipBound(
   if (is_const_int(ret.coeff, 1)) {
     // var + shift >=0 -> var >= -shift
     if (p.min_value.defined()) {
-      p.min_value = ir::Max::make(p.min_value, -ret.base);
+      p.min_value = ir::MaxNode::make(p.min_value, -ret.base);
     } else {
       p.min_value = -ret.base;
     }
@@ -219,7 +219,7 @@ bool DetectClipBound(
   if (is_const_int(ret.coeff, -1)) {
     // -var + shift >=0 -> var <= shift
     if (p.max_value.defined()) {
-      p.max_value = ir::Min::make(p.max_value, ret.base);
+      p.max_value = ir::MinNode::make(p.max_value, ret.base);
     } else {
       p.max_value = ret.base;
     }
@@ -243,8 +243,8 @@ void SplitCommExpr(const Expr& e, std::vector<Expr>* ret) {
 // e must be connected by and.
 Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars) {
   std::vector<Expr> splits;
-  SplitCommExpr<ir::And>(e, &splits);
-  std::unordered_map<const Variable*, IntervalEntry> rmap;
+  SplitCommExpr<ir::AndNode>(e, &splits);
+  std::unordered_map<const VarNode*, IntervalEntry> rmap;
   for (Var v : vars) {
     rmap[v.get()] = IntervalEntry();
   }
index 02f3578..1821c16 100644 (file)
@@ -53,15 +53,15 @@ class FuncTouchedDomain final : public StmtExprVisitor {
     return ret;
   }
 
-  void VisitStmt_(const For *op) final {
-    const Variable* var = op->loop_var.get();
+  void VisitStmt_(const ForNode *op) final {
+    const VarNode* var = op->loop_var.get();
     dom_map_[var] = IntSet::range(
         Range::make_by_min_extent(op->min, op->extent));
     StmtExprVisitor::VisitStmt_(op);
     dom_map_.erase(var);
   }
 
-  void VisitStmt_(const LetStmt* op) final {
+  void VisitStmt_(const LetStmtNode* op) final {
     dom_map_[op->var.get()] =
         arith::EvalSet(op->value, dom_map_);
     StmtExprVisitor::VisitStmt_(op);
@@ -69,11 +69,11 @@ class FuncTouchedDomain final : public StmtExprVisitor {
   }
 
   /* TODO: Thread extent unitest not generated.*/
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       const IterVarNode* thread_axis = op->node.as<IterVarNode>();
       CHECK(thread_axis);
-      const Variable* var = thread_axis->var.get();
+      const VarNode* var = thread_axis->var.get();
       dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
       StmtExprVisitor::VisitStmt_(op);
       dom_map_.erase(var);
@@ -82,7 +82,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
     }
   }
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     if (consider_calls_ && tensor_->op.same_as(op->func)
         && tensor_->value_index == op->value_index) {
       Touch(op->args);
@@ -90,7 +90,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
     StmtExprVisitor::VisitExpr_(op);
   }
 
-  void VisitStmt_(const Provide* op) final {
+  void VisitStmt_(const ProvideNode* op) final {
     if (consider_provides_ && tensor_->op.same_as(op->func)
         && tensor_->value_index == op->value_index) {
       Touch(op->args);
@@ -111,7 +111,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
   const Tensor &tensor_;
   bool consider_calls_, consider_provides_;
   std::vector<std::vector<IntSet> > bounds_;
-  std::unordered_map<const Variable*, IntSet> dom_map_;
+  std::unordered_map<const VarNode*, IntSet> dom_map_;
 };
 
 Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) {
index e3adf1f..fd51091 100644 (file)
@@ -47,30 +47,30 @@ inline bool WillOverflow(int64_t x,
 }
 
 template<>
-inline bool WillOverflow<ir::Add>(int64_t x,
-                                  int64_t y,
-                                  int64_t min_value,
-                                  int64_t max_value) {
+inline bool WillOverflow<ir::AddNode>(int64_t x,
+                                      int64_t y,
+                                      int64_t min_value,
+                                      int64_t max_value) {
   if ((y > 0) && (x > max_value - y)) return true;
   if ((y < 0) && (x < min_value - y)) return true;
   return false;
 }
 
 template<>
-inline bool WillOverflow<ir::Sub>(int64_t x,
-                                  int64_t y,
-                                  int64_t min_value,
-                                  int64_t max_value) {
+inline bool WillOverflow<ir::SubNode>(int64_t x,
+                                      int64_t y,
+                                      int64_t min_value,
+                                      int64_t max_value) {
   if ((y > 0) && (x < min_value + y)) return true;
   if ((y < 0) && (x > max_value + y)) return true;
   return false;
 }
 
 template<>
-inline bool WillOverflow<ir::Mul>(int64_t x,
-                                  int64_t y,
-                                  int64_t min_value,
-                                  int64_t max_value) {
+inline bool WillOverflow<ir::MulNode>(int64_t x,
+                                      int64_t y,
+                                      int64_t min_value,
+                                      int64_t max_value) {
   if (y == 0) return false;
   if (y > 0) {
     if (x < min_value / y)  return true;
@@ -84,10 +84,10 @@ inline bool WillOverflow<ir::Mul>(int64_t x,
 }
 
 template<>
-inline bool WillOverflow<ir::Mod>(int64_t x,
-                                  int64_t y,
-                                  int64_t min_value,
-                                  int64_t max_value) {
+inline bool WillOverflow<ir::ModNode>(int64_t x,
+                                      int64_t y,
+                                      int64_t min_value,
+                                      int64_t max_value) {
   return y == 0;
 }
 
index bf1cdf0..c60c825 100644 (file)
@@ -83,15 +83,15 @@ struct is_logical_op {
     static const bool value = true;             \
   };
 
-TVM_DECLARE_LOGICAL_OP(And);
-TVM_DECLARE_LOGICAL_OP(Or);
-TVM_DECLARE_LOGICAL_OP(EQ);
-TVM_DECLARE_LOGICAL_OP(NE);
-TVM_DECLARE_LOGICAL_OP(GE);
-TVM_DECLARE_LOGICAL_OP(GT);
-TVM_DECLARE_LOGICAL_OP(LE);
-TVM_DECLARE_LOGICAL_OP(LT);
-TVM_DECLARE_LOGICAL_OP(Not);
+TVM_DECLARE_LOGICAL_OP(AndNode);
+TVM_DECLARE_LOGICAL_OP(OrNode);
+TVM_DECLARE_LOGICAL_OP(EQNode);
+TVM_DECLARE_LOGICAL_OP(NENode);
+TVM_DECLARE_LOGICAL_OP(GENode);
+TVM_DECLARE_LOGICAL_OP(GTNode);
+TVM_DECLARE_LOGICAL_OP(LENode);
+TVM_DECLARE_LOGICAL_OP(LTNode);
+TVM_DECLARE_LOGICAL_OP(NotNode);
 
 /*!
  * \brief Combine two interval set under arithmetic operations.
@@ -118,9 +118,9 @@ inline IntervalSet Combine(Analyzer* analyzer,
 }
 
 template<>
-inline IntervalSet Combine<ir::Add>(Analyzer* analyer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::AddNode>(Analyzer* analyer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value + b->min_value);
   }
@@ -136,9 +136,9 @@ inline IntervalSet Combine<ir::Add>(Analyzer* analyer,
 }
 
 template<>
-inline IntervalSet Combine<ir::Sub>(Analyzer* analyer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::SubNode>(Analyzer* analyer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value - b->min_value);
   }
@@ -155,9 +155,9 @@ inline IntervalSet Combine<ir::Sub>(Analyzer* analyer,
 
 
 template<>
-inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::MulNode>(Analyzer* analyzer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value * b->min_value);
   }
@@ -178,11 +178,11 @@ inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
       Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (a->HasUpperBound() && a->HasLowerBound()) {
-      using ir::Select;
+      using ir::SelectNode;
       Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
       Expr e1 = a->min_value * b->min_value;
       Expr e2 = a->max_value * b->min_value;
-      return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
+      return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
     }
   }
   DLOG(WARNING) << "Return Everything in CombineInterval Mul";
@@ -190,9 +190,9 @@ inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
 }
 
 template<>
-inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::DivNode>(Analyzer* analyzer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value / b->min_value);
   }
@@ -213,11 +213,11 @@ inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
       Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (a->HasUpperBound() && a->HasLowerBound()) {
-      using ir::Select;
+      using ir::SelectNode;
       Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
       Expr e1 = a->min_value / b->min_value;
       Expr e2 = a->max_value / b->min_value;
-      return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
+      return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
     }
   }
   DLOG(WARNING) << "Return Everything in CombineInterval Div";
@@ -225,9 +225,9 @@ inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
 }
 
 template<>
-inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::ModNode>(Analyzer* analyzer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
   }
@@ -256,9 +256,9 @@ inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
 
 
 template<>
-inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
-                                         IntervalSet a,
-                                         IntervalSet b) {
+inline IntervalSet Combine<ir::FloorDivNode>(Analyzer* analyzer,
+                                             IntervalSet a,
+                                             IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
   }
@@ -279,11 +279,11 @@ inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
       Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
       return IntervalSet(min_value, max_value);
     } else if (a->HasUpperBound() && a->HasLowerBound()) {
-      using ir::Select;
+      using ir::SelectNode;
       Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
       Expr e1 = floordiv(a->min_value, b->min_value);
       Expr e2 = floordiv(a->max_value, b->min_value);
-      return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
+      return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1));
     }
   }
   DLOG(WARNING) << "Return Everything in CombineInterval Div";
@@ -291,9 +291,9 @@ inline IntervalSet Combine<ir::FloorDiv>(Analyzer* analyzer,
 }
 
 template<>
-inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
-                                         IntervalSet a,
-                                         IntervalSet b) {
+inline IntervalSet Combine<ir::FloorModNode>(Analyzer* analyzer,
+                                             IntervalSet a,
+                                             IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
   }
@@ -317,9 +317,9 @@ inline IntervalSet Combine<ir::FloorMod>(Analyzer* analyzer,
 }
 
 template<>
-inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::MaxNode>(Analyzer* analzyer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(max(a->min_value,  b->min_value));
   }
@@ -330,9 +330,9 @@ inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
 }
 
 template<>
-inline IntervalSet Combine<ir::Min>(Analyzer* analzyer,
-                                    IntervalSet a,
-                                    IntervalSet b) {
+inline IntervalSet Combine<ir::MinNode>(Analyzer* analzyer,
+                                        IntervalSet a,
+                                        IntervalSet b) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
   }
@@ -380,15 +380,15 @@ class IntervalSetEvaluator :
     return IntervalSet(min_set->min_value, max_set->max_value);
   }
 
-  IntervalSet VisitExpr_(const IntImm* op) final {
+  IntervalSet VisitExpr_(const IntImmNode* op) final {
     return IntervalSet::SinglePoint(GetRef<Expr>(op));
   }
 
-  IntervalSet VisitExpr_(const UIntImm* op) final {
+  IntervalSet VisitExpr_(const UIntImmNode* op) final {
     return IntervalSet::SinglePoint(GetRef<Expr>(op));
   }
 
-  IntervalSet VisitExpr_(const Variable* op) final {
+  IntervalSet VisitExpr_(const VarNode* op) final {
     Var var = GetRef<Var>(op);
     auto it = dom_map_.find(var);
     if (it != dom_map_.end()) {
@@ -405,75 +405,75 @@ class IntervalSetEvaluator :
     }
   }
 
-  IntervalSet VisitExpr_(const Add* op) final {
+  IntervalSet VisitExpr_(const AddNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Sub* op) final {
+  IntervalSet VisitExpr_(const SubNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Mul* op) final {
+  IntervalSet VisitExpr_(const MulNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Div* op) final {
+  IntervalSet VisitExpr_(const DivNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Mod* op) final {
+  IntervalSet VisitExpr_(const ModNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const FloorDiv* op) final {
+  IntervalSet VisitExpr_(const FloorDivNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const FloorMod* op) final {
+  IntervalSet VisitExpr_(const FloorModNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Min* op) final {
+  IntervalSet VisitExpr_(const MinNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Max* op) final {
+  IntervalSet VisitExpr_(const MaxNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const EQ* op) final {
+  IntervalSet VisitExpr_(const EQNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const NE* op) final {
+  IntervalSet VisitExpr_(const NENode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const LT* op) final {
+  IntervalSet VisitExpr_(const LTNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const LE* op) final {
+  IntervalSet VisitExpr_(const LENode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const GT* op) final {
+  IntervalSet VisitExpr_(const GTNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const GE* op) final {
+  IntervalSet VisitExpr_(const GENode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const And* op) final {
+  IntervalSet VisitExpr_(const AndNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Or* op) final {
+  IntervalSet VisitExpr_(const OrNode* op) final {
     return VisitBinaryExpr_(op);
   }
 
-  IntervalSet VisitExpr_(const Ramp* op) final {
+  IntervalSet VisitExpr_(const RampNode* op) final {
     CHECK(eval_vec_);
     IntervalSet base = Eval(op->base);
     PVar<Integer> stride;
@@ -481,12 +481,12 @@ class IntervalSetEvaluator :
       DataType t = op->base.dtype();
       int64_t vstride = stride.Eval()->value;
       if (vstride> 0) {
-        return Combine<Add>(
+        return Combine<AddNode>(
             analyzer_,
             base,
             IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
       } else {
-        return Combine<Add>(
+        return Combine<AddNode>(
             analyzer_,
             base,
             IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
@@ -496,12 +496,12 @@ class IntervalSetEvaluator :
     return IntervalSet::Everything();
   }
 
-  IntervalSet VisitExpr_(const Broadcast* op) final {
+  IntervalSet VisitExpr_(const BroadcastNode* op) final {
     CHECK(eval_vec_);
     return VisitExpr(op->value);
   }
 
-  IntervalSet VisitExpr_(const Select* op) final {
+  IntervalSet VisitExpr_(const SelectNode* op) final {
     IntervalSet true_set = this->Eval(op->true_value);
     IntervalSet false_set = this->Eval(op->false_value);
     return Union(analyzer_, false_set, true_set);
@@ -720,7 +720,7 @@ Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
 }
 
 Map<Var, IntSet> ConvertDomMap(
-    const std::unordered_map<const Variable*, IntSet>& dom_map) {
+    const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   Map<Var, IntSet> dmap;
   for (auto kv : dom_map) {
     dmap.Set(GetRef<Var>(kv.first), kv.second);
@@ -746,7 +746,7 @@ IntSet EvalSet(Expr e,
 }
 
 IntSet EvalSet(Expr e,
-               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+               const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   return EvalSet(e, ConvertDomMap(dom_map));
 }
 
@@ -761,12 +761,12 @@ IntSet EvalSet(Range r,
 }
 
 IntSet EvalSet(Range r,
-               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+               const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   return EvalSet(r, ConvertDomMap(dom_map));
 }
 
 IntSet EvalSet(IntSet s,
-               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+               const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   Analyzer ana;
   auto dmap = ConvertDomMap(dom_map);
   IntervalSetEvaluator m(&ana, dmap);
@@ -796,7 +796,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
 
 ExprIntSetMap EvalSetForEachSubExpr(
     Expr e,
-    const std::unordered_map<const Variable*, IntSet>& dom_map) {
+    const std::unordered_map<const VarNode*, IntSet>& dom_map) {
   Analyzer ana;
   auto dmap = ConvertDomMap(dom_map);
   SubExprIntervalSetEvaluator m(&ana, dmap);
index bfce2c2..961c476 100644 (file)
@@ -30,14 +30,14 @@ namespace arith {
 using namespace ir;
 
 Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const For* op) {
+VisitStmt_(const ForNode* op) {
   analyzer_->Bind(op->loop_var,
                   Range::make_by_min_extent(op->min, op->extent));
   return StmtExprMutator::VisitStmt_(op);
 }
 
 Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const LetStmt* op) {
+VisitStmt_(const LetStmtNode* op) {
   Expr value = this->VisitExpr(op->value);
   if (!ir::HasSideEffect(value)) {
     analyzer_->Bind(op->var, value);
@@ -57,7 +57,7 @@ VisitStmt_(const LetStmt* op) {
 }
 
 Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const IfThenElse* op) {
+VisitStmt_(const IfThenElseNode* op) {
   Expr condition = this->VisitExpr(op->condition);
   Stmt then_case, else_case;
   {
@@ -66,7 +66,7 @@ VisitStmt_(const IfThenElse* op) {
   }
   if (op->else_case.defined()) {
       With<ConstraintContext> ctx(analyzer_,
-                                  analyzer_->rewrite_simplify(Not::make(condition)));
+                                  analyzer_->rewrite_simplify(NotNode::make(condition)));
       else_case = this->VisitStmt(op->else_case);
   }
   if (is_one(condition)) return then_case;
@@ -74,7 +74,7 @@ VisitStmt_(const IfThenElse* op) {
     if (else_case.defined()) {
       return else_case;
     }
-    return Evaluate::make(0);
+    return EvaluateNode::make(0);
   }
 
   if (condition.same_as(op->condition) &&
@@ -91,7 +91,7 @@ VisitStmt_(const IfThenElse* op) {
 }
 
 Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const AttrStmt* op) {
+VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::thread_extent ||
       op->attr_key == attr::virtual_thread) {
     IterVar iv = Downcast<IterVar>(op->node);
@@ -106,7 +106,7 @@ VisitStmt_(const AttrStmt* op) {
 }
 
 Stmt IRMutatorWithAnalyzer::
-VisitStmt_(const AssertStmt* op) {
+VisitStmt_(const AssertStmtNode* op) {
   Expr condition = this->VisitExpr(op->condition);
   Expr message = this->VisitExpr(op->message);
   With<ConstraintContext> ctx(analyzer_, condition);
@@ -126,7 +126,7 @@ VisitStmt_(const AssertStmt* op) {
 }
 
 Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Call* op) {
+VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
   if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
     Expr cond = this->VisitExpr(op->args[0]);
@@ -137,7 +137,7 @@ VisitExpr_(const Call* op) {
     }
     {
       With<ConstraintContext> constraint(analyzer_,
-                                         analyzer_->rewrite_simplify(Not::make(cond)));
+                                         analyzer_->rewrite_simplify(NotNode::make(cond)));
       false_value = this->VisitExpr(op->args[2]);
     }
     if (is_zero(cond)) {
@@ -151,7 +151,7 @@ VisitExpr_(const Call* op) {
         false_value.same_as(op->args[2])) {
       return GetRef<Expr>(op);
     } else {
-      return Call::make(op->dtype, op->name,
+      return CallNode::make(op->dtype, op->name,
                         {cond, true_value, false_value},
                         op->call_type);
     }
@@ -160,7 +160,7 @@ VisitExpr_(const Call* op) {
 }
 
 Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Let* op) {
+VisitExpr_(const LetNode* op) {
   Expr value = this->VisitExpr(op->value);
   if (!ir::HasSideEffect(value)) {
     analyzer_->Bind(op->var, value);
@@ -172,12 +172,12 @@ VisitExpr_(const Let* op) {
       body.same_as(op->body)) {
     return GetRef<Expr>(op);
   } else {
-    return Let::make(op->var, value, body);
+    return LetNode::make(op->var, value, body);
   }
 }
 
 Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Select* op) {
+VisitExpr_(const SelectNode* op) {
   Expr cond = this->VisitExpr(op->condition);
   Expr true_value, false_value;
   {
@@ -186,7 +186,7 @@ VisitExpr_(const Select* op) {
   }
   {
     With<ConstraintContext> constraint(analyzer_,
-                                       analyzer_->rewrite_simplify(Not::make(cond)));
+                                       analyzer_->rewrite_simplify(NotNode::make(cond)));
     false_value = VisitExpr(op->false_value);
   }
   if (is_zero(cond)) {
@@ -201,12 +201,12 @@ VisitExpr_(const Select* op) {
       false_value.same_as(op->false_value)) {
     return GetRef<Expr>(op);
   } else {
-    return Select::make(cond, true_value, false_value);
+    return SelectNode::make(cond, true_value, false_value);
   }
 }
 
 Expr IRMutatorWithAnalyzer::
-VisitExpr_(const Reduce* op) {
+VisitExpr_(const ReduceNode* op) {
   // Setup the domain information before simplification.
   for (const IterVar& iv : op->axis) {
     analyzer_->Bind(iv->var, iv->dom);
index 9e3a86b..1e96c0a 100644 (file)
@@ -49,15 +49,15 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator {
   using StmtExprMutator::VisitExpr_;
 
   // override functions that need to populate the context information.
-  Stmt VisitStmt_(const ir::For* op) override;
-  Stmt VisitStmt_(const ir::LetStmt* op) override;
-  Stmt VisitStmt_(const ir::IfThenElse* op) override;
-  Stmt VisitStmt_(const ir::AttrStmt* op) override;
-  Stmt VisitStmt_(const ir::AssertStmt* op) override;
-  Expr VisitExpr_(const ir::Let* op) override;
-  Expr VisitExpr_(const ir::Select* op) override;
-  Expr VisitExpr_(const ir::Call* op) override;
-  Expr VisitExpr_(const ir::Reduce* op) override;
+  Stmt VisitStmt_(const ir::ForNode* op) override;
+  Stmt VisitStmt_(const ir::LetStmtNode* op) override;
+  Stmt VisitStmt_(const ir::IfThenElseNode* op) override;
+  Stmt VisitStmt_(const ir::AttrStmtNode* op) override;
+  Stmt VisitStmt_(const ir::AssertStmtNode* op) override;
+  Expr VisitExpr_(const ir::LetNode* op) override;
+  Expr VisitExpr_(const ir::SelectNode* op) override;
+  Expr VisitExpr_(const ir::CallNode* op) override;
+  Expr VisitExpr_(const ir::ReduceNode* op) override;
 
  protected:
   /*! \brief internal analyzer field. */
index b8750df..07ec186 100644 (file)
@@ -38,13 +38,13 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
     return analyzer_.Simplify(expr);
   }
 
-  void VisitStmt_(const For* op) {
+  void VisitStmt_(const ForNode* op) {
     analyzer_.Bind(op->loop_var,
                    Range::make_by_min_extent(op->min, op->extent));
     return StmtExprVisitor::VisitStmt_(op);
   }
 
-  void VisitStmt_(const AttrStmt* op) {
+  void VisitStmt_(const AttrStmtNode* op) {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::virtual_thread) {
       IterVar iv = Downcast<IterVar>(op->node);
@@ -57,7 +57,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
     }
   }
 
-  void VisitExpr_(const Reduce* op) {
+  void VisitExpr_(const ReduceNode* op) {
     // Setup the domain information before simplification.
     for (const IterVar& iv : op->axis) {
       analyzer_.Bind(iv->var, iv->dom);
index a83e987..8e2e065 100644 (file)
@@ -124,15 +124,15 @@ class ModularSetAnalyzer::Impl :
     return Everything();
   }
 
-  Entry VisitExpr_(const Cast* op) final {
+  Entry VisitExpr_(const CastNode* op) final {
     return VisitExpr(op->value);
   }
 
-  Entry VisitExpr_(const IntImm* op) final {
+  Entry VisitExpr_(const IntImmNode* op) final {
     return Entry(0, op->value);
   }
 
-  Entry VisitExpr_(const UIntImm* op) final {
+  Entry VisitExpr_(const UIntImmNode* op) final {
     if (op->value < std::numeric_limits<int64_t>::max()) {
       return Entry(0, static_cast<int>(op->value));
     } else {
@@ -140,21 +140,21 @@ class ModularSetAnalyzer::Impl :
     }
   }
 
-  Entry VisitExpr_(const Add* op) final {
+  Entry VisitExpr_(const AddNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
     return Entry(coeff, a.base + b.base);
   }
 
-  Entry VisitExpr_(const Sub* op) final {
+  Entry VisitExpr_(const SubNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
     return Entry(coeff, a.base - b.base);
   }
 
-  Entry VisitExpr_(const Mul* op) final {
+  Entry VisitExpr_(const MulNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     // Simplification rule, x, y, z are in Z
@@ -188,7 +188,7 @@ class ModularSetAnalyzer::Impl :
     return Everything();
   }
 
-  Entry VisitExpr_(const Div* op) final {
+  Entry VisitExpr_(const DivNode* op) final {
     Entry b = VisitExpr(op->b);
     if (b.is_const()) {
       return DivByConst(op->a, b.base, false);
@@ -196,7 +196,7 @@ class ModularSetAnalyzer::Impl :
     return Everything();
   }
 
-  Entry VisitExpr_(const FloorDiv* op) final {
+  Entry VisitExpr_(const FloorDivNode* op) final {
     Entry b = VisitExpr(op->b);
     if (b.is_const()) {
       return DivByConst(op->a, b.base, true);
@@ -204,35 +204,35 @@ class ModularSetAnalyzer::Impl :
     return Everything();
   }
 
-  Entry VisitExpr_(const Min* op) final {
+  Entry VisitExpr_(const MinNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     return Union(a, b);
   }
 
-  Entry VisitExpr_(const Max* op) final {
+  Entry VisitExpr_(const MaxNode* op) final {
     Entry a = VisitExpr(op->a);
     Entry b = VisitExpr(op->b);
     return Union(a, b);
   }
 
-  Entry VisitExpr_(const Select* op) final {
+  Entry VisitExpr_(const SelectNode* op) final {
     Entry a = VisitExpr(op->true_value);
     Entry b = VisitExpr(op->false_value);
     return Union(a, b);
   }
 
-  Entry VisitExpr_(const Call* op) final {
+  Entry VisitExpr_(const CallNode* op) final {
     // only special handle >> which can be
     // used for index calculation.
-    if (op->is_intrinsic(Call::shift_right)) {
+    if (op->is_intrinsic(CallNode::shift_right)) {
       return VisitRightShift(op);
     } else {
       return Everything();
     }
   }
 
-  Entry VisitExpr_(const Variable* op) final {
+  Entry VisitExpr_(const VarNode* op) final {
     Var v = GetRef<Var>(op);
     auto it = var_map_.find(v);
     if (it != var_map_.end()) {
@@ -242,7 +242,7 @@ class ModularSetAnalyzer::Impl :
     }
   }
 
-  Entry VisitRightShift(const Call* op) {
+  Entry VisitRightShift(const CallNode* op) {
     Entry b = VisitExpr(op->args[1]);
     // a c x  / c -> a x
     if (b.is_const()) {
index bff9564..e964abb 100644 (file)
@@ -283,7 +283,7 @@ class PConstWithTypeLike :
   void InitMatch_() const {}
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
+    if (const ir::IntImmNode* ptr = node.as<ir::IntImmNode>()) {
       return ptr->value == value_;
     } else {
       return false;
@@ -325,30 +325,30 @@ class PConstWithTypeLike :
 
 
 // raise ambiguity error for operator overload of / and %
-TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
-TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
+TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a));
+TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a));
 
 // arithmetic expressions
-TVM_PATTERN_BINARY_OP(operator+, ir::Add);
-TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
-TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
-TVM_PATTERN_BINARY_OP(min, ir::Min);
-TVM_PATTERN_BINARY_OP(max, ir::Max);
-TVM_PATTERN_BINARY_OP(div, ir::Div);
-TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
-TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
-TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
-TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
+TVM_PATTERN_BINARY_OP(operator+, ir::AddNode);
+TVM_PATTERN_BINARY_OP(operator-, ir::SubNode);
+TVM_PATTERN_BINARY_OP(operator*, ir::MulNode);
+TVM_PATTERN_BINARY_OP(min, ir::MinNode);
+TVM_PATTERN_BINARY_OP(max, ir::MaxNode);
+TVM_PATTERN_BINARY_OP(div, ir::DivNode);
+TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode);
+TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode);
+TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode);
+TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode);
 
 // logical expressions
-TVM_PATTERN_BINARY_OP(operator>, ir::GT);
-TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
-TVM_PATTERN_BINARY_OP(operator<, ir::LT);
-TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
-TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
-TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
-TVM_PATTERN_BINARY_OP(operator&&, ir::And);
-TVM_PATTERN_BINARY_OP(operator||, ir::Or);
+TVM_PATTERN_BINARY_OP(operator>, ir::GTNode);
+TVM_PATTERN_BINARY_OP(operator>=, ir::GENode);
+TVM_PATTERN_BINARY_OP(operator<, ir::LTNode);
+TVM_PATTERN_BINARY_OP(operator<=, ir::LENode);
+TVM_PATTERN_BINARY_OP(operator==, ir::EQNode);
+TVM_PATTERN_BINARY_OP(operator!=, ir::NENode);
+TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode);
+TVM_PATTERN_BINARY_OP(operator||, ir::OrNode);
 
 /*!
  * \brief Pattern not expression.
@@ -365,7 +365,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
   }
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::Not* ptr = node.as<ir::Not>()) {
+    if (const ir::NotNode* ptr = node.as<ir::NotNode>()) {
       if (!value_.Match_(ptr->a)) return false;
       return true;
     } else {
@@ -374,7 +374,7 @@ class PNotExpr : public Pattern<PNotExpr<TA> > {
   }
 
   Expr Eval() const {
-    return ir::Not::make(value_.Eval());
+    return ir::NotNode::make(value_.Eval());
   }
 
  private:
@@ -411,7 +411,7 @@ class PSelectExpr :
   }
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::Select* ptr = node.as<ir::Select>()) {
+    if (const ir::SelectNode* ptr = node.as<ir::SelectNode>()) {
       if (!condition_.Match_(ptr->condition)) return false;
       if (!true_value_.Match_(ptr->true_value)) return false;
       if (!false_value_.Match_(ptr->false_value)) return false;
@@ -422,7 +422,7 @@ class PSelectExpr :
   }
 
   Expr Eval() const {
-    return ir::Select::make(
+    return ir::SelectNode::make(
         condition_.Eval(), true_value_.Eval(), false_value_.Eval());
   }
 
@@ -473,7 +473,7 @@ class PCastExpr :
   }
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::Cast* ptr = node.as<ir::Cast>()) {
+    if (const ir::CastNode* ptr = node.as<ir::CastNode>()) {
       if (!dtype_.Match_(ptr->dtype)) return false;
       if (!value_.Match_(ptr->value)) return false;
       return true;
@@ -483,7 +483,7 @@ class PCastExpr :
   }
 
   Expr Eval() const {
-    return ir::Cast::make(dtype_.Eval(), value_.Eval());
+    return ir::CastNode::make(dtype_.Eval(), value_.Eval());
   }
 
  private:
@@ -531,7 +531,7 @@ class PRampExpr :
   }
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
+    if (const ir::RampNode* ptr = node.as<ir::RampNode>()) {
       if (!base_.Match_(ptr->base)) return false;
       if (!stride_.Match_(ptr->stride)) return false;
       if (!lanes_.Match_(ptr->lanes)) return false;
@@ -542,7 +542,7 @@ class PRampExpr :
   }
 
   Expr Eval() const {
-    return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
+    return ir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
   }
 
  private:
@@ -593,7 +593,7 @@ class PBroadcastExpr :
   }
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
+    if (const ir::BroadcastNode* ptr = node.as<ir::BroadcastNode>()) {
       if (!value_.Match_(ptr->value)) return false;
       if (!lanes_.Match_(ptr->lanes)) return false;
       return true;
@@ -603,7 +603,7 @@ class PBroadcastExpr :
   }
 
   Expr Eval() const {
-    return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
+    return ir::BroadcastNode::make(value_.Eval(), lanes_.Eval());
   }
 
  private:
@@ -662,10 +662,10 @@ struct PCallExprInitMatchFunctor {
 };
 
 struct PCallExprMatchFunctor {
-  const ir::Call* call_;
+  const ir::CallNode* call_;
   bool matched_{true};
 
-  explicit PCallExprMatchFunctor(const ir::Call* call)
+  explicit PCallExprMatchFunctor(const ir::CallNode* call)
       : call_(call) {}
 
   template<typename T>
@@ -705,7 +705,7 @@ class PCallExpr :
   }
 
   bool Match_(const ObjectRef& node) const {
-    if (const ir::Call* ptr = node.as<ir::Call>()) {
+    if (const ir::CallNode* ptr = node.as<ir::CallNode>()) {
       if (ptr->args.size() != sizeof...(TArgs)) return false;
       if (ptr->name != Op::kName) return false;
       detail::PCallExprMatchFunctor fmatch(ptr);
@@ -727,18 +727,18 @@ class PCallExpr :
 };
 
 // arithemetic intrinsics
-#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)        \
-  struct OpName {                                                     \
-    static Expr Eval(Array<Expr> args) {                              \
-      return ir::Call::make(args[0].dtype(), kName, args,             \
-                            ir::Call::PureIntrinsic);                 \
-    }                                                                 \
-    static constexpr const char* kName = IntrinStr;                   \
-  };                                                                  \
-  template<typename TA, typename TB>                                  \
-  inline PCallExpr<OpName, TA, TB>                                    \
-  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {              \
-    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());             \
+#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)          \
+  struct OpName {                                                       \
+    static Expr Eval(Array<Expr> args) {                                \
+      return ir::CallNode::make(args[0].dtype(), kName, args,           \
+                                ir::CallNode::PureIntrinsic);           \
+    }                                                                   \
+    static constexpr const char* kName = IntrinStr;                     \
+  };                                                                    \
+  template<typename TA, typename TB>                                    \
+  inline PCallExpr<OpName, TA, TB>                                      \
+  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {                \
+    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());         \
   }
 
 TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
@@ -748,18 +748,18 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
 TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
 
 // unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)         \
-  struct OpName {                                                     \
-    static Expr Eval(Array<Expr> args) {                              \
-      return ir::Call::make(args[0].dtype(), kName, args,             \
-                            ir::Call::PureIntrinsic);                 \
-    }                                                                 \
-    static constexpr const char* kName = IntrinStr;                   \
-  };                                                                  \
-  template<typename TA>                                               \
-  inline PCallExpr<OpName, TA>                                        \
-  FuncName(const Pattern<TA>& a) {                                    \
-    return PCallExpr<OpName, TA>(a.derived());                           \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)           \
+  struct OpName {                                                       \
+    static Expr Eval(Array<Expr> args) {                                \
+      return ir::CallNode::make(args[0].dtype(), kName, args,           \
+                                ir::CallNode::PureIntrinsic);           \
+    }                                                                   \
+    static constexpr const char* kName = IntrinStr;                     \
+  };                                                                    \
+  template<typename TA>                                                 \
+  inline PCallExpr<OpName, TA>                                          \
+  FuncName(const Pattern<TA>& a) {                                      \
+    return PCallExpr<OpName, TA>(a.derived());                          \
   }
 
 TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
@@ -767,9 +767,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
 // if_then_else
 struct PIfThenElseOp {
   static Expr Eval(Array<Expr> args) {
-    return ir::Call::make(
+    return ir::CallNode::make(
         args[1].dtype(), kName, args,
-        ir::Call::PureIntrinsic);
+        ir::CallNode::PureIntrinsic);
   }
   static constexpr const char* kName = "tvm_if_then_else";
 };
index f883bf1..2421e10 100644 (file)
@@ -69,7 +69,7 @@ using namespace ir;
 RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::
 TryCompare(const Expr& x, int64_t val) {
   Expr diff = this->VisitExpr(x);
-  if (const auto* ptr = diff.as<IntImm>()) {
+  if (const auto* ptr = diff.as<IntImmNode>()) {
     if (ptr->value == val) {
       return kEQ;
     } else if (ptr->value > val) {
@@ -116,10 +116,10 @@ Update(const Var& var, const Expr& info, bool override) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Add* op) {
+VisitExpr_(const AddNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Add>();
-  Expr const_res = TryConstFold<Add>(op->a, op->b);
+  op = ret.as<AddNode>();
+  Expr const_res = TryConstFold<AddNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
   PVar<Expr> x, y, z, b1, b2, s1, s2;
@@ -231,10 +231,10 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const Expr& const
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Sub* op) {
+VisitExpr_(const SubNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Sub>();
-  Expr const_res = TryConstFold<Sub>(op->a, op->b);
+  op = ret.as<SubNode>();
+  Expr const_res = TryConstFold<SubNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
   PVar<Expr> x, y, z, b1, b2, s1, s2;
@@ -430,10 +430,10 @@ VisitExpr_(const Sub* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Mul* op) {
+VisitExpr_(const MulNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Mul>();
-  Expr const_res = TryConstFold<Mul>(op->a, op->b);
+  op = ret.as<MulNode>();
+  Expr const_res = TryConstFold<MulNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
   PVar<Expr> x, y, z, b1, b2, s1, s2;
@@ -469,10 +469,10 @@ VisitExpr_(const Mul* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Div* op) {
+VisitExpr_(const DivNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Div>();
-  Expr const_res = TryConstFold<Div>(op->a, op->b);
+  op = ret.as<DivNode>();
+  Expr const_res = TryConstFold<DivNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
   PVar<Expr> x, y, z, b1;
@@ -482,7 +482,7 @@ VisitExpr_(const Div* op) {
   PVar<int> lanes;
 
   // x / 2.0 = x * 0.5
-  if (const FloatImm* ptr = op->b.as<FloatImm>()) {
+  if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
     CHECK(op->dtype.is_float());
     return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
   }
@@ -691,10 +691,10 @@ VisitExpr_(const Div* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Mod* op) {
+VisitExpr_(const ModNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Mod>();
-  Expr const_res = TryConstFold<Mod>(op->a, op->b);
+  op = ret.as<ModNode>();
+  Expr const_res = TryConstFold<ModNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -781,10 +781,10 @@ VisitExpr_(const Mod* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const FloorDiv* op) {
+VisitExpr_(const FloorDivNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<FloorDiv>();
-  Expr const_res = TryConstFold<FloorDiv>(op->a, op->b);
+  op = ret.as<FloorDivNode>();
+  Expr const_res = TryConstFold<FloorDivNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
   PVar<Expr> x, y, z, b1;
@@ -925,10 +925,10 @@ VisitExpr_(const FloorDiv* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const FloorMod* op) {
+VisitExpr_(const FloorModNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<FloorMod>();
-  Expr const_res = TryConstFold<FloorMod>(op->a, op->b);
+  op = ret.as<FloorModNode>();
+  Expr const_res = TryConstFold<FloorModNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -995,10 +995,10 @@ VisitExpr_(const FloorMod* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Min* op) {
+VisitExpr_(const MinNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Min>();
-  Expr const_res = TryConstFold<Min>(op->a, op->b);
+  op = ret.as<MinNode>();
+  Expr const_res = TryConstFold<MinNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -1180,10 +1180,10 @@ VisitExpr_(const Min* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Max* op) {
+VisitExpr_(const MaxNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Max>();
-  Expr const_res = TryConstFold<Max>(op->a, op->b);
+  op = ret.as<MaxNode>();
+  Expr const_res = TryConstFold<MaxNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -1353,10 +1353,10 @@ VisitExpr_(const Max* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const EQ* op) {
+VisitExpr_(const EQNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<EQ>();
-  Expr const_res = TryConstFold<EQ>(op->a, op->b);
+  op = ret.as<EQNode>();
+  Expr const_res = TryConstFold<EQNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -1387,30 +1387,30 @@ VisitExpr_(const EQ* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const NE* op) {
-  return this->VisitExpr(Not::make(op->a == op->b));
+VisitExpr_(const NENode* op) {
+  return this->VisitExpr(NotNode::make(op->a == op->b));
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const LE* op) {
-  return this->VisitExpr(Not::make(op->b < op->a));
+VisitExpr_(const LENode* op) {
+  return this->VisitExpr(NotNode::make(op->b < op->a));
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const GT* op) {
+VisitExpr_(const GTNode* op) {
   return this->VisitExpr(op->b < op->a);
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const GE* op) {
-  return this->VisitExpr(Not::make(op->a < op->b));
+VisitExpr_(const GENode* op) {
+  return this->VisitExpr(NotNode::make(op->a < op->b));
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const LT* op) {
+VisitExpr_(const LTNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<LT>();
-  Expr const_res = TryConstFold<LT>(op->a, op->b);
+  op = ret.as<LTNode>();
+  Expr const_res = TryConstFold<LTNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -1563,10 +1563,10 @@ VisitExpr_(const LT* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Not* op) {
+VisitExpr_(const NotNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Not>();
-  Expr const_res = TryConstFold<Not>(op->a);
+  op = ret.as<NotNode>();
+  Expr const_res = TryConstFold<NotNode>(op->a);
   if (const_res.defined()) return const_res;
   // Pattern var to match any expression
   PVar<Expr> x, y;
@@ -1588,10 +1588,10 @@ VisitExpr_(const Not* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const And* op) {
+VisitExpr_(const AndNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<And>();
-  Expr const_res = TryConstFold<And>(op->a, op->b);
+  op = ret.as<AndNode>();
+  Expr const_res = TryConstFold<AndNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -1637,10 +1637,10 @@ VisitExpr_(const And* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Or* op) {
+VisitExpr_(const OrNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Or>();
-  Expr const_res = TryConstFold<Or>(op->a, op->b);
+  op = ret.as<OrNode>();
+  Expr const_res = TryConstFold<OrNode>(op->a, op->b);
   if (const_res.defined()) return const_res;
 
   // Pattern var to match any expression
@@ -1687,9 +1687,9 @@ VisitExpr_(const Or* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Select* op) {
+VisitExpr_(const SelectNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Select>();
+  op = ret.as<SelectNode>();
   if (op == nullptr) return ret;
   // Pattern var to match any expression
   PVar<Expr> x, y;
@@ -1698,25 +1698,25 @@ VisitExpr_(const Select* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Call* op) {
+VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Call>();
+  op = ret.as<CallNode>();
   if (op == nullptr) return ret;
-  if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) {
+  if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) {
     return op->args[0];
-  } else if (op->is_intrinsic(Call::shift_right)) {
-    if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
+  } else if (op->is_intrinsic(CallNode::shift_right)) {
+    if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
       // the operator overload will eagerly constant fold.
       return op->args[0] >> op->args[1];
     }
-  } else if (op->is_intrinsic(Call::bitwise_and)) {
-    if (op->args[0].as<IntImm>() && op->args[1].as<IntImm>()) {
+  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
+    if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
       // the operator overload will eagerly constant fold.
       return op->args[0] & op->args[1];
     }
   }
-  if (op->is_intrinsic(Call::likely)) {
+  if (op->is_intrinsic(CallNode::likely)) {
     for (const auto& constraint : literal_constraints_) {
       // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
       if (Equal(constraint, op->args[0])) {
@@ -1728,7 +1728,7 @@ VisitExpr_(const Call* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Variable* op) {
+VisitExpr_(const VarNode* op) {
   Var var = GetRef<Var>(op);
   auto it = var_map_.find(var);
   if (it != var_map_.end()) {
@@ -1738,14 +1738,14 @@ VisitExpr_(const Variable* op) {
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Cast* op) {
+VisitExpr_(const CastNode* op) {
   Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-  op = ret.as<Cast>();
+  op = ret.as<CastNode>();
   return cast(op->dtype, op->value);
 }
 
 Expr RewriteSimplifier::Impl::
-VisitExpr_(const Let* op) {
+VisitExpr_(const LetNode* op) {
   Expr value = this->VisitExpr(op->value);
   if (!ir::HasSideEffect(value)) {
     // it is fine to discard the let binding
@@ -1758,7 +1758,7 @@ VisitExpr_(const Let* op) {
       body.same_as(op->body)) {
     return GetRef<Expr>(op);
   } else {
-    return Let::make(op->var, value, body);
+    return LetNode::make(op->var, value, body);
   }
 }
 
index cf9dd6e..f2659a9 100644 (file)
@@ -50,29 +50,29 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
       : IRMutatorWithAnalyzer(parent) {}
 
   void Update(const Var& var, const Expr& info, bool override_info);
-  Expr VisitExpr_(const Add* op) override;
-  Expr VisitExpr_(const Sub* op) override;
-  Expr VisitExpr_(const Mul* op) override;
-  Expr VisitExpr_(const Div* op) override;
-  Expr VisitExpr_(const Mod* op) override;
-  Expr VisitExpr_(const FloorDiv* op) override;
-  Expr VisitExpr_(const FloorMod* op) override;
-  Expr VisitExpr_(const Min* op) override;
-  Expr VisitExpr_(const Max* op) override;
-  Expr VisitExpr_(const EQ* op) override;
-  Expr VisitExpr_(const NE* op) override;
-  Expr VisitExpr_(const LT* op) override;
-  Expr VisitExpr_(const LE* op) override;
-  Expr VisitExpr_(const GT* op) override;
-  Expr VisitExpr_(const GE* op) override;
-  Expr VisitExpr_(const And* op) override;
-  Expr VisitExpr_(const Or* op) override;
-  Expr VisitExpr_(const Not* op) override;
-  Expr VisitExpr_(const Select* op) override;
-  Expr VisitExpr_(const Call* op) override;
-  Expr VisitExpr_(const Variable* op) override;
-  Expr VisitExpr_(const Cast* op) override;
-  Expr VisitExpr_(const Let* op) override;
+  Expr VisitExpr_(const AddNode* op) override;
+  Expr VisitExpr_(const SubNode* op) override;
+  Expr VisitExpr_(const MulNode* op) override;
+  Expr VisitExpr_(const DivNode* op) override;
+  Expr VisitExpr_(const ModNode* op) override;
+  Expr VisitExpr_(const FloorDivNode* op) override;
+  Expr VisitExpr_(const FloorModNode* op) override;
+  Expr VisitExpr_(const MinNode* op) override;
+  Expr VisitExpr_(const MaxNode* op) override;
+  Expr VisitExpr_(const EQNode* op) override;
+  Expr VisitExpr_(const NENode* op) override;
+  Expr VisitExpr_(const LTNode* op) override;
+  Expr VisitExpr_(const LENode* op) override;
+  Expr VisitExpr_(const GTNode* op) override;
+  Expr VisitExpr_(const GENode* op) override;
+  Expr VisitExpr_(const AndNode* op) override;
+  Expr VisitExpr_(const OrNode* op) override;
+  Expr VisitExpr_(const NotNode* op) override;
+  Expr VisitExpr_(const SelectNode* op) override;
+  Expr VisitExpr_(const CallNode* op) override;
+  Expr VisitExpr_(const VarNode* op) override;
+  Expr VisitExpr_(const CastNode* op) override;
+  Expr VisitExpr_(const LetNode* op) override;
 
   std::function<void()> EnterConstraint(const Expr& constraint);
 
index 4996cfd..73b5dce 100644 (file)
@@ -50,14 +50,14 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
     return operator()(std::move(stmt));
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
     With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
     With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
     return Parent::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const LetStmt* op) {
+  Stmt VisitStmt_(const LetStmtNode* op) {
     Expr value = this->VisitExpr(op->value);
     if (!ir::HasSideEffect(value)) {
       // it is fine to discard the let binding
@@ -78,13 +78,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   }
 
   // eliminate useless stores
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = Parent::VisitStmt_(op);
-    op = stmt.as<Store>();
-    if (const Load* load = op->value.as<Load>()) {
+    op = stmt.as<StoreNode>();
+    if (const LoadNode* load = op->value.as<LoadNode>()) {
       if (load->buffer_var.same_as(op->buffer_var) &&
           Equal(load->index, op->index)) {
-        return Evaluate::make(0);
+        return EvaluateNode::make(0);
       }
     }
     return GetRef<Stmt>(op);
index 4d2330f..11452d3 100644 (file)
@@ -29,8 +29,8 @@ namespace tvm {
 namespace autotvm {
 
 // for loop
-void FeatureVisitor::VisitStmt_(const For* op) {
-  const auto *extent = op->extent.as<IntImm>();
+void FeatureVisitor::VisitStmt_(const ForNode* op) {
+  const auto *extent = op->extent.as<IntImmNode>();
   int64_t loop_extent = -1;
   if (extent != nullptr)
     loop_extent = extent->value;
@@ -57,11 +57,11 @@ void FeatureVisitor::VisitStmt_(const For* op) {
 }
 
 // parallel axis, virtual thread
-void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
+void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::thread_extent ||
       op->attr_key == attr::virtual_thread) {
     VarExpr var = op->node.as<tvm::IterVarNode>()->var;
-    const auto *extent = op->value.as<IntImm>();
+    const auto *extent = op->value.as<IntImmNode>();
     CHECK(extent);
 
     std::string name = var.get()->name_hint;
@@ -95,13 +95,13 @@ void FeatureVisitor::VisitStmt_(const AttrStmt* op) {
 }
 
 // memory access
-void FeatureVisitor::VisitExpr_(const Load* op) {
+void FeatureVisitor::VisitExpr_(const LoadNode* op) {
   EnterMem_(op->buffer_var, op->index);
   StmtExprVisitor::VisitExpr_(op);
   ExitMem_();
 }
 
-void FeatureVisitor::VisitStmt_(const Store* op) {
+void FeatureVisitor::VisitStmt_(const StoreNode* op) {
   EnterMem_(op->buffer_var, op->index);
   StmtExprVisitor::VisitStmt_(op);
   ExitMem_();
index 685becd..9f65fb4 100644 (file)
@@ -51,12 +51,12 @@ enum AnnotationType {
 class FeatureVisitor : public StmtExprVisitor {
  public:
   // for loop
-  void VisitStmt_(const For* op) final;
-  void VisitStmt_(const AttrStmt* op) final;
+  void VisitStmt_(const ForNode* op) final;
+  void VisitStmt_(const AttrStmtNode* op) final;
 
   // memory access
-  void VisitExpr_(const Load* op) final;
-  void VisitStmt_(const Store* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitStmt_(const StoreNode* op) final;
 
   using StmtExprVisitor::VisitStmt_;
   using StmtExprVisitor::VisitExpr_;
index 51b1354..0ee4b11 100644 (file)
@@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor {
     this->VisitExpr(expr);
   }
 
-  void VisitExpr_(const Variable* op) final {
+  void VisitExpr_(const VarNode* op) final {
     // TODO(lmzheng): handle more index types (multiple occurrence)
     if (pattern_map.count(op) == 0) {
       pattern_map[op] = TouchPattern();
@@ -60,16 +60,16 @@ class IndexParser: public ExprVisitor {
     }
   }
 
-  void VisitExpr_(const Mul* op) final {
-    if (op->a.as<Variable>()) {
-      if (const auto stride = op->b.as<IntImm>()) {
+  void VisitExpr_(const MulNode* op) final {
+    if (op->a.as<VarNode>()) {
+      if (const auto stride = op->b.as<IntImmNode>()) {
         next_stride_ = stride->value;
       }
     }
     ExprVisitor::VisitExpr_(op);
   }
 
-  std::unordered_map<const Variable*, TouchPattern> pattern_map;
+  std::unordered_map<const VarNode*, TouchPattern> pattern_map;
 
  private:
   int64_t next_stride_ = 1;
@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
     feature_row.push_back(Array<Expr>{std::string("_itervar_"), var});
 
     Array<Expr> attr{std::string("_attr_"),
-                     FloatImm::make(DataType::Float(32), trans(fea.length)),
-                     IntImm::make(DataType::Int(32), fea.nest_level),
-                     FloatImm::make(DataType::Float(32), trans(fea.topdown_product)),
-                     FloatImm::make(DataType::Float(32), trans(fea.bottomup_product)),
+                     FloatImmNode::make(DataType::Float(32), trans(fea.length)),
+                     IntImmNode::make(DataType::Int(32), fea.nest_level),
+                     FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
+                     FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
     };
     // one hot annotation
     for (int i = 0; i < kNum; i++) {
@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
 
     // arithmetic
     feature_row.push_back(Array<Expr>{std::string("_arith_"),
-                                      FloatImm::make(DataType::Float(32), trans(fea.add_ct)),
-                                      FloatImm::make(DataType::Float(32), trans(fea.mul_ct)),
-                                      FloatImm::make(DataType::Float(32), trans(fea.div_ct)),
+            FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
+            FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
+            FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
     });
 
     // touch map
@@ -281,14 +281,15 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<Expr> > > *re
     std::sort(bufs.begin(), bufs.end());
     for (auto k : bufs) {
       TouchPattern &v = fea.touch_feature[k];
-      feature_row.push_back(Array<Expr>{k,
-                                        FloatImm::make(DataType::Float(32), trans(v.stride)),
-                                        FloatImm::make(DataType::Float(32), trans(v.mod)),
-                                        FloatImm::make(DataType::Float(32), trans(v.count)),
-                                        FloatImm::make(DataType::Float(32), trans(v.reuse)),
-                                        FloatImm::make(DataType::Float(32), trans(v.thread_count)),
-                                        FloatImm::make(DataType::Float(32), trans(v.thread_reuse)),
-      });
+      feature_row.push_back(
+          Array<Expr>{k,
+                FloatImmNode::make(DataType::Float(32), trans(v.stride)),
+                FloatImmNode::make(DataType::Float(32), trans(v.mod)),
+                FloatImmNode::make(DataType::Float(32), trans(v.count)),
+                FloatImmNode::make(DataType::Float(32), trans(v.reuse)),
+                FloatImmNode::make(DataType::Float(32), trans(v.thread_count)),
+                FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)),
+                });
     }
 
     ret_feature->push_back(feature_row);
index 2bcf6b8..5265aad 100644 (file)
@@ -92,31 +92,31 @@ class TouchExtractor : public FeatureVisitor {
   }
 
   // arithmetic stats
-  void VisitExpr_(const Add* op) final {
+  void VisitExpr_(const AddNode* op) final {
     if (op->dtype.is_float())
       itervar_map[itervar_stack_.back()].add_ct++;
     FeatureVisitor::VisitExpr_(op);
   }
 
-  void VisitExpr_(const Sub* op) final {
+  void VisitExpr_(const SubNode* op) final {
     if (op->dtype.is_float())
       itervar_map[itervar_stack_.back()].add_ct++;
     FeatureVisitor::VisitExpr_(op);
   }
 
-  void VisitExpr_(const Mul* op) final {
+  void VisitExpr_(const MulNode* op) final {
     if (op->dtype.is_float())
       itervar_map[itervar_stack_.back()].mul_ct++;
     FeatureVisitor::VisitExpr_(op);
   }
 
-  void VisitExpr_(const Div* op) final {
+  void VisitExpr_(const DivNode* op) final {
     if (op->dtype.is_float())
       itervar_map[itervar_stack_.back()].div_ct++;
     FeatureVisitor::VisitExpr_(op);
   }
 
-  void VisitExpr_(const Mod* op) final {
+  void VisitExpr_(const ModNode* op) final {
     if (op->dtype.is_float())
       itervar_map[itervar_stack_.back()].div_ct++;
     FeatureVisitor::VisitExpr_(op);
index 38f0b95..77b1c9d 100644 (file)
@@ -65,39 +65,39 @@ Target CreateTarget(const std::string& target_name,
   std::string device_flag = "-device=";
   std::string keys_flag = "-keys=";
   for (auto& item : options) {
-    t->options_array.push_back(ir::StringImm::make(item));
+    t->options_array.push_back(ir::StringImmNode::make(item));
 
     if (item.find(libs_flag) == 0) {
       std::stringstream ss(item.substr(libs_flag.length()));
       std::string lib_item;
       while (std::getline(ss, lib_item, ',')) {
-        t->libs_array.push_back(ir::StringImm::make(lib_item));
+        t->libs_array.push_back(ir::StringImmNode::make(lib_item));
       }
     } else if (item.find(device_flag) == 0) {
       t->device_name = item.substr(device_flag.length());
-      t->keys_array.push_back(ir::StringImm::make(t->device_name));
+      t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
     } else if (item.find(keys_flag) == 0) {
       std::stringstream ss(item.substr(keys_flag.length()));
       std::string key_item;
       while (std::getline(ss, key_item, ',')) {
-        t->keys_array.push_back(ir::StringImm::make(key_item));
+        t->keys_array.push_back(ir::StringImmNode::make(key_item));
       }
     }
   }
 
   if (t->device_name.length() > 0) {
-    t->keys_array.push_back(ir::StringImm::make(t->device_name));
+    t->keys_array.push_back(ir::StringImmNode::make(t->device_name));
   }
   t->device_type = kDLCPU;
   t->thread_warp_size = 1;
   if (target_name == "c" && t->device_name == "micro_dev") {
     t->device_type = kDLMicroDev;
   } else if (target_name == "c" || target_name == "llvm") {
-    t->keys_array.push_back(ir::StringImm::make("cpu"));
+    t->keys_array.push_back(ir::StringImmNode::make("cpu"));
   } else if (target_name == "cuda" || target_name == "nvptx") {
     t->device_type = kDLGPU;
-    t->keys_array.push_back(ir::StringImm::make("cuda"));
-    t->keys_array.push_back(ir::StringImm::make("gpu"));
+    t->keys_array.push_back(ir::StringImmNode::make("cuda"));
+    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
     t->max_num_threads = 1024;
     t->thread_warp_size = 32;
   } else if (target_name == "rocm" || target_name == "opencl") {
@@ -107,8 +107,8 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLROCM;
     }
-    t->keys_array.push_back(ir::StringImm::make(target_name));
-    t->keys_array.push_back(ir::StringImm::make("gpu"));
+    t->keys_array.push_back(ir::StringImmNode::make(target_name));
+    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
     t->max_num_threads = 256;
     if (t->device_name == "intel_graphics") {
       t->thread_warp_size = 16;
@@ -119,20 +119,20 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLVulkan;
     }
-    t->keys_array.push_back(ir::StringImm::make(target_name));
-    t->keys_array.push_back(ir::StringImm::make("gpu"));
+    t->keys_array.push_back(ir::StringImmNode::make(target_name));
+    t->keys_array.push_back(ir::StringImmNode::make("gpu"));
     t->max_num_threads = 256;
   } else if (target_name == "sdaccel") {
     t->device_type = kDLOpenCL;
-    t->keys_array.push_back(ir::StringImm::make("sdaccel"));
-    t->keys_array.push_back(ir::StringImm::make("hls"));
+    t->keys_array.push_back(ir::StringImmNode::make("sdaccel"));
+    t->keys_array.push_back(ir::StringImmNode::make("hls"));
   } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
     t->device_type = kDLAOCL;
-    t->keys_array.push_back(ir::StringImm::make("aocl"));
-    t->keys_array.push_back(ir::StringImm::make("hls"));
+    t->keys_array.push_back(ir::StringImmNode::make("aocl"));
+    t->keys_array.push_back(ir::StringImmNode::make("hls"));
   } else if (target_name == "opengl") {
     t->device_type = kOpenGL;
-    t->keys_array.push_back(ir::StringImm::make("opengl"));
+    t->keys_array.push_back(ir::StringImmNode::make("opengl"));
   } else if (target_name == "stackvm") {
     t->device_type = kDLCPU;
   } else if (target_name == "ext_dev") {
@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("_TargetFromString")
 std::vector<std::string> TargetNode::keys() const {
   std::vector<std::string> result;
   for (auto& expr : keys_array) {
-    result.push_back(expr.as<ir::StringImm>()->value);
+    result.push_back(expr.as<ir::StringImmNode>()->value);
   }
   return result;
 }
@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
 std::vector<std::string> TargetNode::options() const {
   std::vector<std::string> result;
   for (auto& expr : options_array) {
-    result.push_back(expr.as<ir::StringImm>()->value);
+    result.push_back(expr.as<ir::StringImmNode>()->value);
   }
   return result;
 }
@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
 std::unordered_set<std::string> TargetNode::libs() const {
   std::unordered_set<std::string> result;
   for (auto& expr : libs_array) {
-    result.insert(expr.as<ir::StringImm>()->value);
+    result.insert(expr.as<ir::StringImmNode>()->value);
   }
   return result;
 }
@@ -348,7 +348,7 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
   bool has_any = false;
   if (!compact) {
     for (const auto& it : shape) {
-      if (it.as<Variable>()) {
+      if (it.as<VarNode>()) {
         has_any = true;
         break;
       }
@@ -860,7 +860,7 @@ TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
 
   std::vector<std::string> tags_vector;
   for (auto& tag : tags) {
-    tags_vector.push_back(tag.as<tvm::ir::StringImm>()->value);
+    tags_vector.push_back(tag.as<tvm::ir::StringImmNode>()->value);
   }
 
   generic_func
index a3f1459..ea9d7ba 100644 (file)
@@ -146,7 +146,7 @@ void CodeGenC::PrintSSAAssign(
 
 // Print a reference expression to a buffer.
 std::string CodeGenC::GetBufferRef(
-    DataType t, const Variable* buffer, Expr index) {
+    DataType t, const VarNode* buffer, Expr index) {
   std::ostringstream os;
   std::string vid = GetVarID(buffer);
   std::string scope;
@@ -265,13 +265,13 @@ std::string CodeGenC::GetStructRef(
 }
 
 
-bool CodeGenC::HandleTypeMatch(const Variable* buf_var, DataType t) const {
+bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
   auto it = handle_data_type_.find(buf_var);
   if (it == handle_data_type_.end()) return false;
   return it->second == t;
 }
 
-void CodeGenC::RegisterHandleType(const Variable* buf_var, DataType t) {
+void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) {
   auto it = handle_data_type_.find(buf_var);
   if (it == handle_data_type_.end()) {
     handle_data_type_[buf_var] = t;
@@ -296,11 +296,11 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
 }
 
 std::string CodeGenC::GetVecLoad(
-    DataType t, const Variable* buffer, Expr base) {
+    DataType t, const VarNode* buffer, Expr base) {
   return GetBufferRef(t, buffer, base);
 }
 
-void CodeGenC::PrintVecStore(const Variable* buffer,
+void CodeGenC::PrintVecStore(const VarNode* buffer,
                              DataType t, Expr base,
                              const std::string& value) {
   std::string ref = GetBufferRef(t, buffer, base);
@@ -321,7 +321,7 @@ void CodeGenC::BindThreadIndex(const IterVar& iv) {
   LOG(FATAL) << "not implemented";
 }
 
-void CodeGenC::PrintStorageSync(const Call* op) { // NOLINT(*)
+void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*)
 }
 
 void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
@@ -359,7 +359,7 @@ void CodeGenC::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
 }
 
 
-inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
   if (op->dtype == DataType::Int(32)) {
     std::ostringstream temp;
     temp << op->value;
@@ -372,7 +372,7 @@ inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOL
   }
 }
 
-inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const UIntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
   if (op->dtype == DataType::UInt(32)) {
     std::ostringstream temp;
     temp << op->value << "U";
@@ -385,7 +385,7 @@ inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NO
   }
 }
 
-inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
+inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
   switch (op->dtype.bits()) {
     case 64: case 32: {
       std::ostringstream temp;
@@ -405,16 +405,16 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N
   }
 }
 
-void CodeGenC::VisitExpr_(const IntImm* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
   PrintConst(op, os, this);
 }
-void CodeGenC::VisitExpr_(const UIntImm* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const UIntImmNode* op, std::ostream& os) {  // NOLINT(*)
   PrintConst(op, os, this);
 }
-void CodeGenC::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
   PrintConst(op, os, this);
 }
-void CodeGenC::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
   os << "\"" << op->value << "\"";
 }
 
@@ -442,7 +442,7 @@ inline void PrintBinaryExpr(const T* op,
   }
 }
 
-inline void PrintBinaryIntrinsic(const Call* op,
+inline void PrintBinaryIntrinsic(const CallNode* op,
                                   const char* opstr,
                                   std::ostream& os,  // NOLINT(*)
                                   CodeGenC* p) {
@@ -457,67 +457,67 @@ inline void PrintBinaryIntrinsic(const Call* op,
     p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os);
   }
 }
-void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
   std::stringstream value;
   this->PrintExpr(op->value, value);
   os << CastFromTo(value.str(), op->value.dtype(), op->dtype);
 }
-void CodeGenC::VisitExpr_(const Variable* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
   os << GetVarID(op);
 }
-void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "+", os, this);
 }
-void CodeGenC::VisitExpr_(const Sub* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "-", os, this);
 }
-void CodeGenC::VisitExpr_(const Mul* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "*", os, this);
 }
-void CodeGenC::VisitExpr_(const Div* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "/", os, this);
 }
-void CodeGenC::VisitExpr_(const Mod* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "%", os, this);
 }
-void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "min", os, this);
 }
-void CodeGenC::VisitExpr_(const Max* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "max", os, this);
 }
-void CodeGenC::VisitExpr_(const EQ* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const EQNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "==", os, this);
 }
-void CodeGenC::VisitExpr_(const NE* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const NENode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "!=", os, this);
 }
-void CodeGenC::VisitExpr_(const LT* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const LTNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "<", os, this);
 }
-void CodeGenC::VisitExpr_(const LE* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const LENode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "<=", os, this);
 }
-void CodeGenC::VisitExpr_(const GT* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const GTNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, ">", os, this);
 }
-void CodeGenC::VisitExpr_(const GE* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const GENode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, ">=", os, this);
 }
-void CodeGenC::VisitExpr_(const And* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const AndNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "&&", os, this);
 }
-void CodeGenC::VisitExpr_(const Or* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const OrNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "||", os, this);
 }
-void CodeGenC::VisitExpr_(const Not* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) {  // NOLINT(*)
   os << '!';
   PrintExpr(op->a, os);
 }
 
-void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
-  if (op->call_type == Call::Extern ||
-      op->call_type == Call::PureExtern) {
+void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
+  if (op->call_type == CallNode::Extern ||
+      op->call_type == CallNode::PureExtern) {
     os << op->name << "(";
     for (size_t i = 0; i < op->args.size(); i++) {
       this->PrintExpr(op->args[i], os);
@@ -526,20 +526,20 @@ void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
       }
     }
     os << ")";
-  } else if (op->is_intrinsic(Call::bitwise_and)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
     PrintBinaryIntrinsic(op, " & ", os, this);
-  } else if (op->is_intrinsic(Call::bitwise_xor)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
     PrintBinaryIntrinsic(op, " ^ ", os, this);
-  } else if (op->is_intrinsic(Call::bitwise_or)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
     PrintBinaryIntrinsic(op, " | ", os, this);
-  } else if (op->is_intrinsic(Call::bitwise_not)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
     CHECK_EQ(op->args.size(), 1U);
     os << "(~";
     this->PrintExpr(op->args[0], os);
     os << ')';
-  } else if (op->is_intrinsic(Call::shift_left)) {
+  } else if (op->is_intrinsic(CallNode::shift_left)) {
     PrintBinaryIntrinsic(op, " << ", os, this);
-  } else if (op->is_intrinsic(Call::shift_right)) {
+  } else if (op->is_intrinsic(CallNode::shift_right)) {
     PrintBinaryIntrinsic(op, " >> ", os, this);
   } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
     os << "(";
@@ -550,7 +550,7 @@ void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
     PrintExpr(op->args[2], os);
     os << ")";
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-    const Load *l = op->args[0].as<Load>();
+    const LoadNode *l = op->args[0].as<LoadNode>();
     CHECK(op->args.size() == 1 && l);
     os << "((";
     this->PrintType(l->dtype.element_of(), os);
@@ -562,28 +562,28 @@ void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
     CHECK_EQ(op->args.size(), 3U);
     os << GetStructRef(
         op->dtype, op->args[0], op->args[1],
-        op->args[2].as<IntImm>()->value);
+        op->args[2].as<IntImmNode>()->value);
   } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
     CHECK_EQ(op->args.size(), 1U);
     os << "(";
     this->PrintExpr(op->args[0], os);
     os << " == NULL)";
-  } else if (op->is_intrinsic(Call::reinterpret)) {
+  } else if (op->is_intrinsic(CallNode::reinterpret)) {
     // generate (*( TYPE *)(&(ARG)))
     os << "(*(";
     this->PrintType(op->dtype, os);
     os << " *)(&(";
     this->PrintExpr(op->args[0], os);
     os << ")))";
-  } else if (op->is_intrinsic(Call::isnan)) {
+  } else if (op->is_intrinsic(CallNode::isnan)) {
     os << "(";
     this->PrintExpr(op->args[0], os);
     os << " != ";
     this->PrintExpr(op->args[0], os);
     os << ")";
   } else {
-    if (op->call_type == Call::Intrinsic ||
-        op->call_type == Call::PureIntrinsic) {
+    if (op->call_type == CallNode::Intrinsic ||
+        op->call_type == CallNode::PureIntrinsic) {
       LOG(FATAL) << "Unresolved intrinsic " << op->name
                  << " with return type " << op->dtype;
     } else {
@@ -610,7 +610,7 @@ void CodeGenC::PrintVecBinaryOp(
   }
 }
 
-void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   int lanes = op->dtype.lanes();
   // delcare type.
   if (op->dtype.lanes() == 1) {
@@ -663,7 +663,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
   }
 }
 
-void CodeGenC::VisitStmt_(const Store* op) {
+void CodeGenC::VisitStmt_(const StoreNode* op) {
   DataType t = op->value.dtype();
   if (t.lanes() == 1) {
     std::string value = this->PrintExpr(op->value);
@@ -714,14 +714,14 @@ void CodeGenC::VisitStmt_(const Store* op) {
   }
 }
 
-void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
   std::string value = PrintExpr(op->value);
   CHECK(!var_idmap_.count(op->var.get()));
   var_idmap_[op->var.get()] = value;
   os << PrintExpr(op->body);
 }
 
-void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
   // constraint of current logic
   CHECK_EQ(op->base.dtype(), DataType::Int(32));
   os << "((int" << op->lanes << ")(";
@@ -733,15 +733,15 @@ void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
   os << "))";
 }
 
-void CodeGenC::VisitExpr_(const Shuffle* op, std::ostream& os) {
+void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
   LOG(FATAL) << "Shuffle: not supported ";
 }
 
-void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
   LOG(FATAL) << "Broadcast: not supported ";
 }
 
-void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
   os << "(";
   PrintExpr(op->condition, os);
   os << " ? ";
@@ -751,7 +751,7 @@ void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
   os << ")";
 }
 
-void CodeGenC::VisitStmt_(const LetStmt* op) {
+void CodeGenC::VisitStmt_(const LetStmtNode* op) {
   std::string value = PrintExpr(op->value);
   if (print_ssa_form_) {
     CHECK(!var_idmap_.count(op->var.get()));
@@ -776,7 +776,7 @@ void CodeGenC::VisitStmt_(const LetStmt* op) {
   PrintStmt(op->body);
 }
 
-void CodeGenC::VisitStmt_(const Allocate* op) {
+void CodeGenC::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   std::string vid = AllocVarID(op->buffer_var.get());
   if (op->new_expr.defined()) {
@@ -791,7 +791,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
     int32_t constant_size = op->constant_allocation_size();
     CHECK_GT(constant_size, 0)
         << "Can only handle constant size stack allocation for now";
-    const Variable* buffer = op->buffer_var.as<Variable>();
+    const VarNode* buffer = op->buffer_var.as<VarNode>();
     std::string scope = alloc_storage_scope_.at(buffer);
     PrintStorageScope(scope, stream);
     stream << ' ';
@@ -803,7 +803,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) {
   this->PrintStmt(op->body);
 }
 
-void CodeGenC::VisitStmt_(const AttrStmt* op) {
+void CodeGenC::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == ir::attr::thread_extent) {
     IterVar iv = Downcast<IterVar>(op->node);
     if (iv->thread_tag.length() != 0) {
@@ -812,21 +812,21 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
       }
     }
   } else if (op->attr_key == ir::attr::storage_scope) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
-    alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+    alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
   } else if (op->attr_key == ir::attr::volatile_scope) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     volatile_buf_.insert(v);
   }
   this->PrintStmt(op->body);
 }
 
-void CodeGenC::VisitStmt_(const AssertStmt* op) {
+void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
-  if (const auto* str = op->message.as<StringImm>()) {
+  if (const auto* str = op->message.as<StringImmNode>()) {
     // GLOG style check
     stream << "CHECK(" << cond << ") << \"" << str->value << "\";\n";
   } else {
@@ -835,7 +835,7 @@ void CodeGenC::VisitStmt_(const AssertStmt* op) {
   this->PrintStmt(op->body);
 }
 
-void CodeGenC::VisitStmt_(const For* op) {
+void CodeGenC::VisitStmt_(const ForNode* op) {
   std::string extent = PrintExpr(op->extent);
   PrintIndent();
   std::string vid = AllocVarID(op->loop_var.get());
@@ -852,7 +852,7 @@ void CodeGenC::VisitStmt_(const For* op) {
   stream << "}\n";
 }
 
-void CodeGenC::VisitStmt_(const IfThenElse* op) {
+void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
   if (cond[0] == '(' && cond[cond.length() - 1] == ')') {
@@ -881,9 +881,9 @@ void CodeGenC::VisitStmt_(const SeqStmtNode* op) {
   }
 }
 
-void CodeGenC::VisitStmt_(const Evaluate* op) {
+void CodeGenC::VisitStmt_(const EvaluateNode* op) {
   if (is_const(op->value)) return;
-  const Call* call = op->value.as<Call>();
+  const CallNode* call = op->value.as<CallNode>();
   if (call) {
     if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
       this->PrintStorageSync(call); return;
@@ -894,7 +894,7 @@ void CodeGenC::VisitStmt_(const Evaluate* op) {
           call->args[3].dtype(),
           call->args[0],
           call->args[1],
-          call->args[2].as<IntImm>()->value);
+          call->args[2].as<IntImmNode>()->value);
       this->PrintIndent();
       this->stream << ref << " = " << value << ";\n";
       return;
@@ -907,7 +907,7 @@ void CodeGenC::VisitStmt_(const Evaluate* op) {
   }
 }
 
-void CodeGenC::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
   PrintStmt(op->body);
 }
 
index eae1e49..593bbcd 100644 (file)
@@ -102,46 +102,46 @@ class CodeGenC :
    */
   virtual void InitFuncState(LoweredFunc f);
   // expression
-  void VisitExpr_(const Variable* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Load* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Let* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Call* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Add* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Sub* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const NE* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const LT* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const LE* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const GT* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const GE* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const And* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Or* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Cast* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Shuffle* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const FloatImm* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const StringImm* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const VarNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LoadNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LetNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const CallNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const AddNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const SubNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const MulNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const DivNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const ModNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const MinNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const MaxNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const EQNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const NENode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LTNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LENode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const GTNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const GENode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const AndNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const OrNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const CastNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const NotNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const SelectNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const RampNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const ShuffleNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const IntImmNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const UIntImmNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloatImmNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const StringImmNode* op, std::ostream& os) override;  // NOLINT(*)
   // statment
-  void VisitStmt_(const LetStmt* op) override;
-  void VisitStmt_(const Store* op) override;
-  void VisitStmt_(const For* op) override;
-  void VisitStmt_(const IfThenElse* op) override;
-  void VisitStmt_(const Allocate* op) override;
-  void VisitStmt_(const AttrStmt* op) override;
-  void VisitStmt_(const AssertStmt* op) override;
-  void VisitStmt_(const Evaluate* op) override;
+  void VisitStmt_(const LetStmtNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitStmt_(const AttrStmtNode* op) override;
+  void VisitStmt_(const AssertStmtNode* op) override;
+  void VisitStmt_(const EvaluateNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const ProducerConsumer* op) override;
+  void VisitStmt_(const ProducerConsumerNode* op) override;
   /*!
    * Print Type represetnation of type t.
    * \param t The type representation.
@@ -154,15 +154,15 @@ class CodeGenC :
    */
   virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*)
   virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*)
-  virtual void PrintStorageSync(const Call* op);  // NOLINT(*)
+  virtual void PrintStorageSync(const CallNode* op);  // NOLINT(*)
   // Binary vector op.
   virtual void PrintVecBinaryOp(
       const std::string&op, DataType op_type,
       Expr lhs, Expr rhs, std::ostream& os);  // NOLINT(*)
   // print vector load
-  virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base);
+  virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base);
   // print vector store
-  virtual void PrintVecStore(const Variable* buffer,
+  virtual void PrintVecStore(const VarNode* buffer,
                              DataType t, Expr base,
                              const std::string& value);  // NOLINT(*)
   // print load of single element
@@ -180,28 +180,28 @@ class CodeGenC :
       DataType t, const Expr& buffer, const Expr& index, int kind);
   // print reference to a buffer as type t in index.
   virtual std::string GetBufferRef(
-      DataType t, const Variable* buffer, Expr index);
+      DataType t, const VarNode* buffer, Expr index);
   /*!
    * \brief If buffer is allocated as type t.
    * \param buf_var The buffer variable.
    * \param t The type to be checked.
    */
-  bool HandleTypeMatch(const Variable* buf_var, DataType t) const;
+  bool HandleTypeMatch(const VarNode* buf_var, DataType t) const;
   /*!
    * \brief Register the data type of buf_var
    * \param buf_var The buffer variable.
    * \param t The type to be checked.
    */
-  void RegisterHandleType(const Variable* buf_var, DataType t);
+  void RegisterHandleType(const VarNode* buf_var, DataType t);
   // override
   void PrintSSAAssign(
       const std::string& target, const std::string& src, DataType t) final;
   /*! \brief restrict keyword */
   std::string restrict_keyword_{""};
   /*! \brief the storage scope of allocation */
-  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
+  std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
   /*! \brief the data type of allocated buffers */
-  std::unordered_map<const Variable*, DataType> handle_data_type_;
+  std::unordered_map<const VarNode*, DataType> handle_data_type_;
   /*! \brief reserves common C keywords */
   void ReserveKeywordsAsUnique();
 
@@ -209,7 +209,7 @@ class CodeGenC :
   /*! \brief whether to print in SSA form */
   bool print_ssa_form_{false};
   /*! \brief set of volatile buf access */
-  std::unordered_set<const Variable*> volatile_buf_;
+  std::unordered_set<const VarNode*> volatile_buf_;
 };
 
 }  // namespace codegen
index 5066182..85751c2 100644 (file)
@@ -142,7 +142,7 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Cannot convert type " << t << " to C type";
 }
 
-void CodeGenCHost::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
   std::string v = PrintExpr(op->value);
   os << "((";
   PrintType(op->dtype, os);
@@ -194,11 +194,11 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar
   this->stream << "}\n";
 }
 
-void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*)
   if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
     std::string stack_name = GetUniqueName("stack");
-    const std::string& type = op->args[0].as<StringImm>()->value;
-    const IntImm* num = op->args[1].as<IntImm>();
+    const std::string& type = op->args[0].as<StringImmNode>()->value;
+    const IntImmNode* num = op->args[1].as<IntImmNode>();
     CHECK(num != nullptr);
     static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
     size_t unit = sizeof(TVMValue);
@@ -218,10 +218,10 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
     this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
     os << stack_name;
   } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
-    const StringImm* s = op->args[0].as<StringImm>();
+    const StringImmNode* s = op->args[0].as<StringImmNode>();
     CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
-    int64_t begin = op->args[3].as<IntImm>()->value;
-    int64_t end = op->args[4].as<IntImm>()->value;
+    int64_t begin = op->args[3].as<IntImmNode>()->value;
+    int64_t end = op->args[4].as<IntImmNode>()->value;
     int64_t num_args = end - begin;
     CHECK_GE(num_args, 0);
     std::string func_name = s->value;
@@ -237,14 +237,14 @@ void CodeGenCHost::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
   }
 }
 
-void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*)
+void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
   if (emit_asserts_) {
     std::string cond = PrintExpr(op->condition);
     PrintIndent();
     stream << "if (!(" << cond << ")) {\n";
     int assert_if_scope = this->BeginScope();
     PrintIndent();
-    stream << "TVMAPISetLastError(\"" << op->message.as<StringImm>()->value << "\");\n";
+    stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n";
     PrintIndent();
     stream << "return -1;\n";
     this->EndScope(assert_if_scope);
@@ -254,11 +254,11 @@ void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*)
   this->PrintStmt(op->body);
 }
 
-void CodeGenCHost::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) {  // NOLINT(*)
   PrintTernaryCondExpr(op, "<", os);
 }
 
-void CodeGenCHost::VisitExpr_(const Max *op, std::ostream& os) {  // NOLINT(*)
+void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) {  // NOLINT(*)
   PrintTernaryCondExpr(op, ">", os);
 }
 
index 44f8385..43fe98d 100644 (file)
@@ -42,14 +42,14 @@ class CodeGenCHost final : public CodeGenC {
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
 
   // overload visitor functions
-  void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const Call *op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*)
   // overload min and max to use the ternary operator, so we don't rely on the
   // standard library implementations
-  void VisitExpr_(const Min *op, std::ostream& os) final;  // NOLINT(*)
-  void VisitExpr_(const Max *op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const MinNode *op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const MaxNode *op, std::ostream& os) final;  // NOLINT(*)
 
-  void VisitStmt_(const AssertStmt *op) final; // NOLINT(*)
+  void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
 
  private:
   std::string module_name_;
index 06b542a..53a008d 100644 (file)
@@ -93,7 +93,7 @@ std::string CodeGenCUDA::Finish() {
   return CodeGenC::Finish();
 }
 
-void CodeGenCUDA::VisitStmt_(const ir::For* op) {
+void CodeGenCUDA::VisitStmt_(const ir::ForNode* op) {
   CHECK(is_const_int(op->min, 0));
   if (op->for_type == ir::ForType::Unrolled) {
     PrintIndent();
@@ -265,8 +265,8 @@ void CodeGenCUDA::PrintVecElemStore(
   }
 }
 
-void CodeGenCUDA::PrintStorageSync(const Call* op) {
-  const std::string& sync = op->args[0].as<StringImm>()->value;
+void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
+  const std::string& sync = op->args[0].as<StringImmNode>()->value;
   if (sync == "warp") {
     // DO nothing.
   } else if (sync == "shared") {
@@ -314,7 +314,7 @@ void CodeGenCUDA::PrintStorageScope(
   }
 }
 
-void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
+void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
   if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
     need_mma_h_ = true;
     CHECK_EQ(op->args.size(), 6U);
@@ -348,7 +348,7 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
     this->PrintExpr(op->args[4], os);
     os << "], ";
     this->PrintExpr(op->args[6], os);
-    if (const StringImm *str = op->args[7].as<StringImm>()) {
+    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
       os << ", nvcuda::wmma::mem_" << str->value;
     } else {
       LOG(FATAL) << "Invalid parameters";
@@ -369,20 +369,20 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
   }
 }
 
-void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
+void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::fragment_shape) {
-    const Variable* buffer = op->node.as<Variable>();
-    const StringImm* shape_str = op->value.as<StringImm>();
+    const VarNode* buffer = op->node.as<VarNode>();
+    const StringImmNode* shape_str = op->value.as<StringImmNode>();
     fragment_shapes[buffer] = shape_str->value;
   } else if (op->attr_key == attr::fragment_layout) {
-    const Variable* buffer = op->node.as<Variable>();
-    const StringImm* layout_str = op->value.as<StringImm>();
+    const VarNode* buffer = op->node.as<VarNode>();
+    const StringImmNode* layout_str = op->value.as<StringImmNode>();
     fragment_layouts[buffer] = layout_str->value;
   }
   CodeGenC::VisitStmt_(op);
 }
 
-void CodeGenCUDA::VisitStmt_(const Allocate* op) {
+void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   std::string vid = AllocVarID(op->buffer_var.get());
   if (op->new_expr.defined()) {
@@ -397,7 +397,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
     int32_t constant_size = op->constant_allocation_size();
     CHECK_GT(constant_size, 0)
       << "Can only handle constant size stack allocation for now";
-    const Variable* buffer = op->buffer_var.as<Variable>();
+    const VarNode* buffer = op->buffer_var.as<VarNode>();
     std::string scope = alloc_storage_scope_.at(buffer);
     if (scope.find("wmma.") == 0) {
       if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
@@ -425,9 +425,9 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
   this->PrintStmt(op->body);
 }
 
-void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
+void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) {
   if (is_const(op->value)) return;
-  const Call* call = op->value.as<Call>();
+  const CallNode* call = op->value.as<CallNode>();
   if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) {
     PrintIndent();
     stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
@@ -442,7 +442,7 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
   }
 }
 
-void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
+void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   os << "((make_int" << op->lanes << ")(";
   for (int i = 0; i < op->lanes; i++) {
     os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
@@ -452,7 +452,7 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
   os << "))";
 }
 
-void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
   if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) {
     // make_int8x4
     const int64_t *p = as_const_int(op->value);
@@ -474,7 +474,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLIN
   os << ')';
 }
 
-void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
+void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
   std::vector<std::string> to_shuffle(op->vectors.size());
   for (int i = 0, e = op->vectors.size(); i < e; ++i) {
     CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
@@ -492,7 +492,7 @@ void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
   os << ')';
 }
 
-inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
+inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
   switch (op->dtype.bits()) {
     case 64: case 32: {
       std::ostringstream temp;
@@ -523,12 +523,12 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { /
 }
 
 
-void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
+void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
   PrintConst(op, os, this);
 }
 
 void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
-    const Variable* variable, std::ostream &os) {
+    const VarNode* variable, std::ostream &os) {
   std::stringstream type;
   PrintType(t, type);
   std::string shape_str = fragment_shapes[variable];
@@ -550,7 +550,7 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
 }
 
 int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
-                                         const Variable* variable, int32_t size) {
+                                         const VarNode* variable, int32_t size) {
   std::string shape_str = fragment_shapes[variable];
   size_t m, n, k;
   size_t last_pos = 0, pos = 0;
index 74d6fba..fc2e6ae 100644 (file)
@@ -43,8 +43,8 @@ class CodeGenCUDA final : public CodeGenC {
     return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
-  void VisitStmt_(const ir::For* op) final;
-  void PrintStorageSync(const Call* op) final;
+  void VisitStmt_(const ir::ForNode* op) final;
+  void PrintStorageSync(const CallNode* op) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintVecBinaryOp(
       const std::string&op, DataType t,
@@ -56,14 +56,14 @@ class CodeGenCUDA final : public CodeGenC {
       const std::string& vec, DataType t, int i, const std::string& value) final;
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   // overload visitor
-  void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const FloatImm *op, std::ostream& os) final;
-  void VisitExpr_(const Call *op, std::ostream& os) final;
-  void VisitStmt_(const Evaluate *op) final;
-  void VisitStmt_(const Allocate *op) final;
-  void VisitStmt_(const AttrStmt *op) final;
+  void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
+  void VisitExpr_(const CallNode *op, std::ostream& os) final;
+  void VisitStmt_(const EvaluateNode *op) final;
+  void VisitStmt_(const AllocateNode *op) final;
+  void VisitStmt_(const AttrStmtNode *op) final;
 
  private:
   // Whether global barrier is needed.
@@ -81,13 +81,13 @@ class CodeGenCUDA final : public CodeGenC {
   // whether need mma.h
   bool need_mma_h_{false};
 
-  std::unordered_map<const Variable*, std::string> fragment_shapes;
-  std::unordered_map<const Variable*, std::string> fragment_layouts;
-  friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
+  std::unordered_map<const VarNode*, std::string> fragment_shapes;
+  std::unordered_map<const VarNode*, std::string> fragment_layouts;
+  friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
   void PrintWmmaScope(
-      const std::string& scope, DataType t, const Variable* variable, std::ostream& os);
+      const std::string& scope, DataType t, const VarNode* variable, std::ostream& os);
   int32_t GetWmmaFragmentSize(
-      const std::string &scope, const Variable* variable, int32_t size);
+      const std::string &scope, const VarNode* variable, int32_t size);
 };
 
 }  // namespace codegen
index b239578..4e92fcb 100644 (file)
@@ -196,8 +196,8 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
 }
 
-void CodeGenMetal::PrintStorageSync(const Call* op) {
-  const std::string& sync = op->args[0].as<StringImm>()->value;
+void CodeGenMetal::PrintStorageSync(const CallNode* op) {
+  const std::string& sync = op->args[0].as<StringImmNode>()->value;
   if (sync == "warp") {
     this->PrintIndent();
     this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n";
@@ -234,7 +234,7 @@ void CodeGenMetal::PrintStorageScope(
   }
 }
 
-void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
   std::string v = PrintExpr(op->value);
   PrintType(op->dtype, os);
   os << "(";
@@ -245,8 +245,8 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLI
   os << ')';
 }
 
-void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
-  if (op->is_intrinsic(Call::reinterpret)) {
+void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
+  if (op->is_intrinsic(CallNode::reinterpret)) {
     // generate as_type<TYPE>(ARG)
     os << "(as_type<";
     this->PrintType(op->dtype, os);
index 728e3e0..d9c5e95 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -40,7 +40,7 @@ class CodeGenMetal final : public CodeGenC {
   void PrintArgUnionDecl();
   void InitFuncState(LoweredFunc f) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
-  void PrintStorageSync(const Call* op) final;  // NOLINT(*)
+  void PrintStorageSync(const CallNode* op) final;  // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   // print load of single element
@@ -50,10 +50,10 @@ class CodeGenMetal final : public CodeGenC {
   void PrintVecElemStore(
       const std::string& vec, DataType t, int i, const std::string& value) final;
   // overload visitor
-  void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
 
   // overload visitor
-  void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
 
  private:
   int thread_index_bits_{32};
index e466e28..8914db8 100644 (file)
@@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
 }
 
-void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
+void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
                                  Expr base, std::ostream& os) {  // NOLINT(*)
   if (!HandleTypeMatch(buffer, t.element_of())) {
     os << '(';
@@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t,
   PrintExpr(base, os);
 }
 std::string CodeGenOpenCL::GetVecLoad(
-    DataType t, const Variable* buffer, Expr base) {
+    DataType t, const VarNode* buffer, Expr base) {
   std::ostringstream os;
   os << "vload" << t.lanes() << "(0, ";
   PrintVecAddr(buffer, t, base, os);
@@ -168,7 +168,7 @@ std::string CodeGenOpenCL::GetVecLoad(
   return os.str();
 }
 
-void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
+void CodeGenOpenCL::PrintVecStore(const VarNode* buffer,
                                   DataType t, Expr base,
                                   const std::string& value) {
   this->PrintIndent();
@@ -177,8 +177,8 @@ void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
   stream << ");\n";
 }
 
-void CodeGenOpenCL::PrintStorageSync(const Call* op) {
-  const std::string& sync = op->args[0].as<StringImm>()->value;
+void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
+  const std::string& sync = op->args[0].as<StringImmNode>()->value;
   if (sync == "warp") {
     this->PrintIndent();
     this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
@@ -215,7 +215,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType
   return os.str();
 }
 
-void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
   std::string v = PrintExpr(op->value);
   os << "((";
   PrintType(op->dtype, os);
@@ -227,7 +227,7 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOL
   os << "))";
 }
 
-void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) {  // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const CallNode *op, std::ostream& os) {  // NOLINT(*)
   /* Return type of ternary expression is not always same as its sub-expressions,
    * add a cast */
   if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
@@ -238,7 +238,7 @@ void CodeGenOpenCL::VisitExpr_(const Call *op, std::ostream& os) {  // NOLINT(*)
   CodeGenC::VisitExpr_(op, os);
 }
 
-void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
   /* Return type of ternary expression is not always same as its sub-expressions,
    * add a cast */
   os << "(";
@@ -247,7 +247,7 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(
   CodeGenC::VisitExpr_(op, os);
 }
 
-void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*)
+void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*)
   if (std::isinf(op->value)) {
     if (op->value < 0) {
       os << "-";
index 36324eb..a606a3a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -42,23 +42,23 @@ class CodeGenOpenCL final : public CodeGenC {
   void InitFuncState(LoweredFunc f) final;
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
-  void PrintStorageSync(const Call* op) final;  // NOLINT(*)
+  void PrintStorageSync(const CallNode* op) final;  // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
-  std::string GetVecLoad(DataType t, const Variable* buffer,
+  std::string GetVecLoad(DataType t, const VarNode* buffer,
                          Expr base) final;
-  void PrintVecStore(const Variable* buffer,
+  void PrintVecStore(const VarNode* buffer,
                      DataType t, Expr base,
                      const std::string& value) final;  // NOLINT(*)
   // the address of load/store
-  void PrintVecAddr(const Variable* buffer, DataType t,
+  void PrintVecAddr(const VarNode* buffer, DataType t,
                     Expr base, std::ostream& os);  // NOLINT(*)
   std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*)
 
   // overload visitor
-  void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*)
-  void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
 
  private:
   // whether enable fp16 and fp64 extension
index 29fcf85..5666de3 100644 (file)
@@ -188,13 +188,13 @@ void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
   this->stream << "}\n";
 }
 
-void CodeGenOpenGL::VisitStmt_(const Store* op) {
+void CodeGenOpenGL::VisitStmt_(const StoreNode* op) {
   LOG(FATAL) << "Store statement not supported in OpenGL."
              << " Texture store should be a Call statement.";
 }
 
 // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
-std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
+std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) {
   std::ostringstream os;
   os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
   PrintExpr(index, os);
@@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
 // Print a reference expression to a buffer.
 // Format: texelFetch(buffer, index, 0).r
 std::string CodeGenOpenGL::GetBufferRef(
-    DataType t, const Variable* buffer, Expr index) {
+    DataType t, const VarNode* buffer, Expr index) {
   CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
   CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
 
@@ -242,34 +242,34 @@ void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) {
 
 // Codegen for immediate values
 
-void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) {
   CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
   CodeGenC::VisitExpr_(op, os);
 }
 
-void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const UIntImmNode* op, std::ostream& os) {
   CHECK_EQ(op->dtype, DataType::UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
   CodeGenC::VisitExpr_(op, os);
 }
 
-void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {
   CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
   CodeGenC::VisitExpr_(op, os);
 }
 
-void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
+void CodeGenOpenGL::VisitExpr_(const StringImmNode*, std::ostream& os) {
   LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
 }
 
-void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
-  auto call = op->value.as<Call>();
-  if (call == nullptr || call->name != Call::glsl_texture_store) {
+void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
+  auto call = op->value.as<CallNode>();
+  if (call == nullptr || call->name != CallNode::glsl_texture_store) {
     // Fallback to normal logic.
     CodeGenC::VisitStmt_(op);
   }
 
   CHECK_EQ(call->args.size(), 2);
-  auto buffer = call->args[0].as<Variable>();
+  auto buffer = call->args[0].as<VarNode>();
   auto value = call->args[1];
 
   // Doesn't support store to vector.
index 46e87a8..bb69365 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -43,24 +43,24 @@ class CodeGenOpenGL final : public CodeGenC {
 
   void InitFuncState(LoweredFunc f) final;
   void BindThreadIndex(const IterVar& iv) final;
-  void VisitStmt_(const Store* op) final;
-  std::string TexelFetch(const Variable* buffer, Expr index);
-  std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final;
+  void VisitStmt_(const StoreNode* op) final;
+  std::string TexelFetch(const VarNode* buffer, Expr index);
+  std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final;
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
 
   // Codegen for immediate values
-  void VisitExpr_(const IntImm* op, std::ostream& os) final;  // NOLINT(*)
-  void VisitExpr_(const UIntImm* op, std::ostream& os) final;  // NOLINT(*)
-  void VisitExpr_(const FloatImm* op, std::ostream& os) final;  // NOLINT(*)
-  void VisitExpr_(const StringImm* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const IntImmNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const UIntImmNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const StringImmNode* op, std::ostream& os) final;  // NOLINT(*)
 
   // Match glsl_texture_store Call.
-  void VisitStmt_(const Evaluate* op) final;  // NOLINT(*)
+  void VisitStmt_(const EvaluateNode* op) final;  // NOLINT(*)
 
  private:
-  const Variable* output_{nullptr};
-  std::unordered_set<const Variable*> inputs_;
-  const Variable* output_iter_var_{nullptr};
+  const VarNode* output_{nullptr};
+  std::unordered_set<const VarNode*> inputs_;
+  const VarNode* output_iter_var_{nullptr};
   std::unordered_map<std::string, runtime::OpenGLShader> shaders_;
   std::string thread_extent_var_;
 };
index 2b11d11..aa3b6ef 100644 (file)
@@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) {
   return e.vid;
 }
 
-std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
+std::string CodeGenSourceBase::AllocVarID(const VarNode* v) {
   CHECK(!var_idmap_.count(v))
       << "Need input to be in SSA form dup " << v->name_hint;
   std::string key = v->name_hint;
@@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const Variable* v) {
   return vid;
 }
 
-std::string CodeGenSourceBase::GetVarID(const Variable* v) const {
+std::string CodeGenSourceBase::GetVarID(const VarNode* v) const {
   auto it = var_idmap_.find(v);
   CHECK(it != var_idmap_.end())
       << "Find undefined Variable " << v->name_hint;
index 7fd0eef..b39ee46 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -66,13 +66,13 @@ class CodeGenSourceBase {
    * \param v The variable.
    * \return the variable name.
    */
-  std::string AllocVarID(const Variable* v);
+  std::string AllocVarID(const VarNode* v);
   /*!
    * \brief Get a variable name.
    * \param v The variable.
    * \return the variable name.
    */
-  std::string GetVarID(const Variable* v) const;
+  std::string GetVarID(const VarNode* v) const;
   /*!
    * \brief Get the SSA ID corresponds to src
    *  If necessary, generate new assignment
@@ -110,7 +110,7 @@ class CodeGenSourceBase {
   /*! \brief the stream to be printed */
   std::ostringstream stream;
   /*! \brief name of each variable */
-  std::unordered_map<const Variable*, std::string> var_idmap_;
+  std::unordered_map<const VarNode*, std::string> var_idmap_;
 
  private:
   /*! \brief assignment map of ssa */
index d12e54d..e7231a1 100644 (file)
@@ -98,7 +98,7 @@ inline void PrintBinaryExpr(const T* op,
   os << ')';
 }
 
-void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
+void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) {  // NOLINT(*)
   const char *opstr = "std::min";
   if (op->dtype.is_float()) {
     switch (op->dtype.bits()) {
@@ -112,7 +112,7 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(
   PrintBinaryExpr(op, opstr, os, this);
 }
 
-void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) {  // NOLINT(*)
+void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) {  // NOLINT(*)
   const char *opstr = "std::max";
   if (op->dtype.is_float()) {
     switch (op->dtype.bits()) {
index e678edb..e406cb5 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -38,8 +38,8 @@ class CodeGenVivadoHLS final : public CodeGenC {
   void PrintType(DataType t, std::ostream& os);
   void AddFunction(LoweredFunc f);
   void PreFunctionBody(LoweredFunc f);
-  void VisitExpr_(const Min *op, std::ostream& os);
-  void VisitExpr_(const Max *op, std::ostream& os);
+  void VisitExpr_(const MinNode *op, std::ostream& os);
+  void VisitExpr_(const MaxNode *op, std::ostream& os);
 };
 
 }  // namespace codegen
index 219b485..571ec52 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -54,7 +54,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt")
 .set_body([](const TVMArgs& args, TVMRetValue* rv){
     Expr e = args[0];
-    const Call* call = e.as<Call>();
+    const CallNode* call = e.as<CallNode>();
     CHECK(call != nullptr);
 
     auto one = make_const(call->args[0].dtype(), 1);
@@ -67,7 +67,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
 .set_body([](const TVMArgs& args, TVMRetValue* rv){
     Expr e = args[0];
-    const Call* call = e.as<Call>();
+    const CallNode* call = e.as<CallNode>();
     CHECK(call != nullptr);
 
     auto one = make_const(call->args[0].dtype(), 1);
index f64887e..a0665bf 100644 (file)
@@ -61,12 +61,12 @@ struct Direct {
 template<typename T>
 inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
   Expr e = args[0];
-  const Call* call = e.as<Call>();
+  const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   std::string name = T()(call->dtype, call->name);
   if (name.length() != 0) {
-    *rv = Call::make(
-        call->dtype, name, call->args, Call::PureExtern);
+    *rv = CallNode::make(
+        call->dtype, name, call->args, CallNode::PureExtern);
   } else {
     *rv = e;
   }
index a2b3685..397f9d3 100644 (file)
@@ -70,7 +70,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
     function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
   }
 
-  void VisitStmt_(const Allocate* op) final {
+  void VisitStmt_(const AllocateNode* op) final {
     CHECK(!is_zero(op->condition));
     llvm::Value* buf = nullptr;
     if (op->new_expr.defined()) {
@@ -153,8 +153,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
     return builder_->CreateCall(f, {});
   }
 
-  llvm::Value* CreateStorageSync(const Call* op) final {
-    const std::string& sync = op->args[0].as<StringImm>()->value;
+  llvm::Value* CreateStorageSync(const CallNode* op) final {
+    const std::string& sync = op->args[0].as<StringImmNode>()->value;
     if (sync == "warp") {
       return nullptr;
     } else if (sync == "shared") {
@@ -234,7 +234,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
   Array<Expr> bitcode_files = (*find_rocm_bitcodes)();
 
   for (auto &bitcode : bitcode_files) {
-    std::string path = bitcode.as<StringImm>()->value;
+    std::string path = bitcode.as<StringImmNode>()->value;
     llvm::SMDiagnostic err;
     std::unique_ptr<llvm::Module> mlib = llvm::parseIRFile(path, err, *ctx);
     if (mlib.get() == nullptr) {
index 39d5114..fdc1b42 100644 (file)
@@ -39,25 +39,25 @@ class CodeGenARM final : public CodeGenCPU {
     native_vector_bits_ = 16 * 8;
     CodeGenCPU::InitTarget(tm);
   }
-  llvm::Value* CreateIntrinsic(const Call* op) override;
+  llvm::Value* CreateIntrinsic(const CallNode* op) override;
 
  private:
-  Expr ARMPopcount(const Call* op);
+  Expr ARMPopcount(const CallNode* op);
 };
 
-llvm::Value* CodeGenARM::CreateIntrinsic(const Call* op) {
+llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
   if (op->is_intrinsic("llvm_intrin")) {
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
-        op->args[0].as<UIntImm>()->value);
+        op->args[0].as<UIntImmNode>()->value);
     if (id == ::llvm::Intrinsic::ctpop) {
       Expr e = ARMPopcount(op);
-      return CodeGenCPU::CreateIntrinsic(e.as<Call>());
+      return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
     }
   }
   return CodeGenCPU::CreateIntrinsic(op);
 }
 
-Expr CodeGenARM::ARMPopcount(const Call *call) {
+Expr CodeGenARM::ARMPopcount(const CallNode *call) {
   using namespace ir;
   const Expr& e = call->args[2];
   ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop;
@@ -68,10 +68,10 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
   if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
      (total_size != 128 && total_size != 64)) {
     Array<Expr> vcnt_args;
-    vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
-    vcnt_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+    vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
+    vcnt_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
     vcnt_args.push_back(e);
-    return ir::Call::make(call->dtype,  "llvm_intrin", vcnt_args, Call::PureIntrinsic);
+    return ir::CallNode::make(call->dtype,  "llvm_intrin", vcnt_args, CallNode::PureIntrinsic);
   }
 
   // Popcount lowering rule:
@@ -90,40 +90,44 @@ Expr CodeGenARM::ARMPopcount(const Call *call) {
   // Interpret input as vector of 8bit values
   Expr input8 = reinterpret(uint8_type, e);
   // Popcount 8bit->8bit
-  const Call* c0 = input8.as<Call>();
+  const CallNode* c0 = input8.as<CallNode>();
   CHECK(c0 != nullptr);
   Array<Expr> vcnt8_args;
-  vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), ctpop_id));
-  vcnt8_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+  vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), ctpop_id));
+  vcnt8_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt8_args.push_back(input8);
-  Expr vcnt8 = ir::Call::make(uint8_type,  "llvm_intrin", vcnt8_args, Call::PureIntrinsic);
+  Expr vcnt8 = ir::CallNode::make(
+    uint8_type,  "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic);
 
   // Accumulation 8->16bit
   Array<Expr> vcnt16_args;
-  vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
-  vcnt16_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+  vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
+  vcnt16_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt16_args.push_back(vcnt8);
-  Expr vcnt16 = ir::Call::make(uint16_type, "llvm_intrin", vcnt16_args, Call::PureIntrinsic);
+  Expr vcnt16 = ir::CallNode::make(
+    uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic);
   if (call->dtype.bits() == 16) {
     return vcnt16;
   }
 
   // Accumulation 16->32bit
   Array<Expr> vcnt32_args;
-  vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
-  vcnt32_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+  vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
+  vcnt32_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt32_args.push_back(vcnt16);
-  Expr vcnt32 = ir::Call::make(uint32_type,  "llvm_intrin", vcnt32_args, Call::PureIntrinsic);
+  Expr vcnt32 = ir::CallNode::make(
+    uint32_type,  "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic);
   if (call->dtype.bits() == 32) {
     return vcnt32;
   }
 
   // Accumulation 32->64bit
   Array<Expr> vcnt64_args;
-  vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), vpaddlu_id));
-  vcnt64_args.push_back(ir::UIntImm::make(DataType::UInt(32), 1));
+  vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), vpaddlu_id));
+  vcnt64_args.push_back(ir::UIntImmNode::make(DataType::UInt(32), 1));
   vcnt64_args.push_back(vcnt32);
-  return ir::Call::make(call->dtype,  "llvm_intrin", vcnt64_args, Call::PureIntrinsic);
+  return ir::CallNode::make(
+    call->dtype,  "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic);
 }
 
 TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm")
index 9f1a292..0622269 100644 (file)
@@ -319,7 +319,7 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(
   }
 }
 
-llvm::Value* CodeGenCPU::CreateCallExtern(const Call* op) {
+llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
   std::vector<llvm::Value*> arg_values(op->args.size());
   for (size_t i = 0; i < op->args.size(); ++i) {
     arg_values[i] = MakeValue(op->args[i]);
@@ -417,7 +417,7 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
   return end_block;
 }
 
-void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
+void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
   // There are two reasons why we create another function for compute_scope
   // - Make sure the generated compute function is clearly separately(though it can get inlined)
   // - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
@@ -436,12 +436,12 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
   llvm::Function* fcompute =
       llvm::Function::Create(ftype,
                              llvm::Function::PrivateLinkage,
-                             op->value.as<StringImm>()->value,
+                             op->value.as<StringImmNode>()->value,
                              module_.get());
   BasicBlock* compute_call_end = CheckCallSuccess(
       builder_->CreateCall(fcompute, arg_values));
   // setup compute fuinction.
-  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
   size_t idx = 0;
   for (auto it = fcompute->arg_begin();
        it != fcompute->arg_end(); ++it, ++idx) {
@@ -497,7 +497,7 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* nu
 
 void CodeGenCPU::UnpackClosureData(llvm::Value* cdata,
                                    const Array<Var>& vfields,
-                                   std::unordered_map<const Variable*, llvm::Value*>* vmap) {
+                                   std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
   for (size_t i = 0; i < vfields.size(); ++i) {
     (*vmap)[vfields[i].get()] =
         builder_->CreateLoad(builder_->CreateInBoundsGEP(
@@ -528,7 +528,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
   llvm::Value* penv = &(*it++);
   cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
   // setup new variable map, swap it with current var context.
-  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
   UnpackClosureData(cdata, vfields, &new_vmap);
   // setup parallel env
   ParallelEnv par_env;
@@ -594,7 +594,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
   auto it = f->arg_begin();
   cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
   // setup new variable map, swap it with current var context.
-  std::unordered_map<const Variable*, llvm::Value*> new_vmap;
+  std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
   UnpackClosureData(cdata, vfields, &new_vmap);
   CHECK(parallel_env_.penv == nullptr);
   std::swap(function_, f);
@@ -673,7 +673,7 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
                            llvm::Value **ret_tcode, const DataType &r_type,
                            const int64_t begin, const int64_t end) {
   using llvm::BasicBlock;
-  std::string func_name = args[0].as<StringImm>()->value;
+  std::string func_name = args[0].as<StringImmNode>()->value;
   llvm::Value *handle = GetPackedFuncHandle(func_name);
   // call the function
   int64_t nargs = end - begin;
@@ -701,24 +701,24 @@ CodeGenCPU::MakeCallPacked(const Array<Expr> &args, llvm::Value **rvalue,
   return end_block;
 }
 
-llvm::Value *CodeGenCPU::CreateCallPacked(const Call *op) {
+llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) {
   CHECK_EQ(op->args.size(), 5U);
   llvm::Value *rvalue = nullptr;
   llvm::Value *ret_tcode = nullptr;
   MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype,
-                 op->args[3].as<IntImm>()->value,
-                 op->args[4].as<IntImm>()->value);
+                 op->args[3].as<IntImmNode>()->value,
+                 op->args[4].as<IntImmNode>()->value);
   return rvalue;
 }
 
-llvm::Value *CodeGenCPU::CreateCallTracePacked(const Call *op) {
+llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
   using llvm::BasicBlock;
   CHECK_EQ(op->args.size(), 6U);
   llvm::Value *rvalue = nullptr;
   llvm::Value *ret_tcode = nullptr;
   BasicBlock *end_block = MakeCallPacked(
-      op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImm>()->value,
-      op->args[4].as<IntImm>()->value);
+      op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
+      op->args[4].as<IntImmNode>()->value);
   // Get traced value.
   llvm::Value *traced_value = MakeValue(op->args[5]);
   // The update_block handles case when we need to update the return value.
@@ -786,7 +786,7 @@ void CodeGenCPU::AddStartupFunction() {
   }
 }
 
-llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
+llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
   if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
     return CreateCallPacked(op);
   } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) {
@@ -798,7 +798,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
     return ConstInt32(-1);
   } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
     CHECK_EQ(op->args.size(), 3U);
-    int kind = op->args[2].as<IntImm>()->value;
+    int kind = op->args[2].as<IntImmNode>()->value;
     llvm::Value* ref = this->CreateStructRefPtr(
         op->dtype, MakeValue(op->args[0]),
         MakeValue(op->args[1]), kind);
@@ -809,7 +809,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
     }
   } else if (op->is_intrinsic(intrinsic::tvm_struct_set)) {
     CHECK_EQ(op->args.size(), 4U);
-    int kind = op->args[2].as<IntImm>()->value;
+    int kind = op->args[2].as<IntImmNode>()->value;
     llvm::Value* value = MakeValue(op->args[3]);
     llvm::Value* ref = this->CreateStructRefPtr(
         op->args[3].dtype(), MakeValue(op->args[0]),
@@ -823,7 +823,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
     return ConstInt32(0);
   } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
     CHECK_EQ(op->args.size(), 2U);
-    const std::string& type = op->args[0].as<StringImm>()->value;
+    const std::string& type = op->args[0].as<StringImmNode>()->value;
     return WithFunctionEntry([&]() -> llvm::AllocaInst* {
         const int64_t* pval = as_const_int(op->args[1]);
         CHECK(pval) << "require stack alloca to contain constant value";
@@ -846,13 +846,13 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
   }
 }
 
-void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
+void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
   using llvm::BasicBlock;
   llvm::Value* cond = MakeValue(op->condition);
   std::ostringstream os;
   os << "Assert fail: " << op->condition;
-  if (op->message.as<StringImm>()) {
-    os << ", " << op->message.as<StringImm>()->value;
+  if (op->message.as<StringImmNode>()) {
+    os << ", " << op->message.as<StringImmNode>()->value;
   }
   llvm::Value* msg = GetConstString(os.str());
   BasicBlock* fail_block = BasicBlock::Create(
@@ -869,9 +869,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmt* op) {
   CodeGenLLVM::VisitStmt_(op);
 }
 
-void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
+void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == ir::attr::coproc_uop_scope) {
-    this->CreateStaticInit(op->value.as<StringImm>()->value, op->body);
+    this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
   } else  if (op->attr_key == ir::attr::compute_scope) {
     this->CreateComputeScope(op);
   } else if (attr::IsPragmaKey(op->attr_key)) {
@@ -893,7 +893,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
           RuntimeTVMParallelBarrier(),
           {MakeValue(parallel_env_.task_id),  parallel_env_.penv});
     } else if (op->attr_key == ir::attr::pragma_import_llvm) {
-      const StringImm* value = op->value.as<StringImm>();
+      const StringImmNode* value = op->value.as<StringImmNode>();
       CHECK(value != nullptr);
       this->HandleImport(value->value);
       this->VisitStmt(op->body);
@@ -906,7 +906,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmt* op) {
   }
 }
 
-void CodeGenCPU::VisitStmt_(const For* op) {
+void CodeGenCPU::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
   if (op->for_type == ForType::Serial ||
       op->for_type == ForType::Unrolled) {
@@ -914,7 +914,7 @@ void CodeGenCPU::VisitStmt_(const For* op) {
   } else if (op->for_type == ForType::Parallel) {
     if (parallel_env_.penv == nullptr) {
       CreateParallelLaunch(
-          For::make(
+          ForNode::make(
               op->loop_var, op->min, op->extent,
               op->for_type, op->device_api, op->body), 0);
     } else {
@@ -936,8 +936,8 @@ void CodeGenCPU::VisitStmt_(const For* op) {
                         op->body);
       } else {
         Expr step = (op->extent + num_task - make_const(t, 1)) / num_task;
-        Expr begin = Min::make(task_id * step, op->extent);
-        Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent);
+        Expr begin = MinNode::make(task_id * step, op->extent);
+        Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
         CreateSerialFor(MakeValue(begin),
                         MakeValue(end),
                         ConstInt32(1),
index b9e1275..46f3f96 100644 (file)
@@ -45,11 +45,11 @@ class CodeGenCPU : public CodeGenLLVM {
   void AddFunction(const LoweredFunc& f) override;
   void AddMainFunction(const std::string& entry_func_name) override;
   std::unique_ptr<llvm::Module> Finish() override;
-  void VisitStmt_(const AssertStmt* op) override;
-  void VisitStmt_(const AttrStmt* op) override;
-  void VisitStmt_(const For* op) override;
-  llvm::Value* CreateIntrinsic(const Call* op) override;
-  llvm::Value* CreateCallExtern(const Call* op) override;
+  void VisitStmt_(const AssertStmtNode* op) override;
+  void VisitStmt_(const AttrStmtNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+  llvm::Value* CreateIntrinsic(const CallNode* op) override;
+  llvm::Value* CreateCallExtern(const CallNode* op) override;
 
  protected:
   void AddStartupFunction() final;
@@ -99,22 +99,22 @@ class CodeGenCPU : public CodeGenLLVM {
   llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
   void UnpackClosureData(llvm::Value*cdata,
                          const Array<Var>& fields,
-                         std::unordered_map<const Variable*, llvm::Value*>* vmap);
+                         std::unordered_map<const VarNode*, llvm::Value*>* vmap);
   // Make packed call.
   llvm::BasicBlock *MakeCallPacked(const Array<Expr> &args,
                                    llvm::Value **rvalue,
                                    llvm::Value **ret_tcode, const DataType &r_type,
                                    const int64_t begin, const int64_t end);
   // create call into tvm packed function.
-  llvm::Value* CreateCallPacked(const Call* op);
+  llvm::Value* CreateCallPacked(const CallNode* op);
   // Create trace call into tvm packed function.
-  llvm::Value* CreateCallTracePacked(const Call *op);
+  llvm::Value* CreateCallTracePacked(const CallNode *op);
   // Create static initialization
   void CreateStaticInit(const std::string& init_fname, const Stmt& body);
   // Create parallel launch
   void CreateParallelLaunch(const Stmt& body, int num_task);
   // Create a new compute scope.
-  void CreateComputeScope(const AttrStmt* op);
+  void CreateComputeScope(const AttrStmtNode* op);
   // Check if the call to packed function is successful
   // if not directly finalize function and pass on return code.
   // return the end block after the check
index b0d86a9..e2ba19a 100644 (file)
@@ -229,7 +229,7 @@ llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) {
   return nullptr;
 }
 
-llvm::Value* CodeGenLLVM::CreateStorageSync(const Call* op) {
+llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) {
   LOG(FATAL) << "not implemented";
   return nullptr;
 }
@@ -333,7 +333,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const {
 // This trick comes from Halide's CodeGen_LLVM
 //
 void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
-                               const Variable* buffer,
+                               const VarNode* buffer,
                                Expr index,
                                DataType type) {
   if (alias_var_set_.count(buffer) != 0) {
@@ -348,7 +348,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
   // create meta-data for alias analysis
   // Use a group of binary tree ranges of memory banks.
   if (index.defined()) {
-    const Ramp* ramp = index.as<Ramp>();
+    const RampNode* ramp = index.as<RampNode>();
     if (ramp) {
       int base, stride;
       if (arith::GetConstInt(ramp->base, &base) &&
@@ -388,7 +388,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
 }
 
 void CodeGenLLVM::GetAlignment(DataType t,
-                               const Variable* buf_var,
+                               const VarNode* buf_var,
                                const Expr& index,
                                int* p_alignment,
                                int* p_native_bits) {
@@ -633,13 +633,13 @@ llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
   return builder_->CreateInBoundsGEP(buffer, index);
 }
 
-llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
+llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {
   auto it = var_map_.find(v);
   CHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint;
   return it->second;
 }
 
-llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
+llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) {
   std::vector<llvm::Value*> arg_value;
   std::vector<llvm::Type*> arg_type;
   for (size_t i = 0; i < op->args.size(); ++i) {
@@ -658,11 +658,11 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
   return call;
 }
 
-llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
+llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
-        op->args[0].as<UIntImm>()->value);
+        op->args[0].as<UIntImmNode>()->value);
     const uint64_t *num_signature = as_const_uint(op->args[1]);
     CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
                          << "but " << op->args[1] << " got!\n";
@@ -681,17 +681,17 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     llvm::Function* f = llvm::Intrinsic::getDeclaration(
         module_.get(), id, sig_type);
     return builder_->CreateCall(f, arg_value);
-  } else if (op->is_intrinsic(Call::bitwise_and)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
     return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
-  } else if (op->is_intrinsic(Call::bitwise_or)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
     return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1]));
-  } else if (op->is_intrinsic(Call::bitwise_not)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
     return builder_->CreateNot(MakeValue(op->args[0]));
-  } else if (op->is_intrinsic(Call::bitwise_xor)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
     return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
-  } else if (op->is_intrinsic(Call::shift_left)) {
+  } else if (op->is_intrinsic(CallNode::shift_left)) {
     return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
-  } else if (op->is_intrinsic(Call::shift_right)) {
+  } else if (op->is_intrinsic(CallNode::shift_right)) {
     if (op->args[0].dtype().is_int()) {
       return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1]));
     } else {
@@ -700,9 +700,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     return CreateStorageSync(op);
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-    const Load *l = op->args[0].as<Load>();
+    const LoadNode *l = op->args[0].as<LoadNode>();
     CHECK(op->args.size() == 1 && l);
-    const Ramp *r = l->index.as<Ramp>();
+    const RampNode *r = l->index.as<RampNode>();
     llvm::Value* ptr;
     unsigned addrspace;
     if (!r) {
@@ -718,7 +718,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
           ptr->getType())->getAddressSpace();
     }
     return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace));
-  } else if (op->is_intrinsic(Call::reinterpret) && is_zero(op->args[0])) {
+  } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) {
     return llvm::Constant::getNullValue(t_void_p_);
   } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
     return builder_->CreateIsNull(MakeValue(op->args[0]));
@@ -746,10 +746,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     value->addIncoming(then_value, then_value_block);
     value->addIncoming(else_value, else_value_block);
     return value;
-  } else if (op->is_intrinsic(Call::reinterpret)) {
+  } else if (op->is_intrinsic(CallNode::reinterpret)) {
     llvm::Type * target = LLVMType(op->dtype);
     return builder_->CreateBitCast(MakeValue(op->args[0]), target);
-  } else if (op->is_intrinsic(Call::isnan)) {
+  } else if (op->is_intrinsic(CallNode::isnan)) {
     // TODO(hgt312): set fast math flag
     llvm::Value* a = MakeValue(op->args[0]);
     return builder_->CreateFCmpUNO(a, a);
@@ -778,7 +778,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
 
 void CodeGenLLVM::Scalarize(const Expr& e,
                             std::function<void(int i, llvm::Value* v)> f) {
-  if (const Ramp* ramp = e.as<Ramp>()) {
+  if (const RampNode* ramp = e.as<RampNode>()) {
     for (int i = 0; i < ramp->dtype.lanes(); ++i) {
       Expr offset = ramp->base + (ramp->stride * i);
       f(i, MakeValue(offset));
@@ -793,32 +793,32 @@ void CodeGenLLVM::Scalarize(const Expr& e,
 
 
 // Visitors
-llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) {
   return GetVarValue(op);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
   return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
 }
-llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
   return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImmNode* op) {
   return llvm::ConstantInt::get(LLVMType(op->dtype), op->value);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
   return llvm::ConstantFP::get(LLVMType(op->dtype), op->value);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
   return GetConstString(op->value);
 }
 
 #define DEFINE_CODEGEN_BINARY_OP(Op)                                    \
   llvm::Value* CodeGenLLVM::Create ## Op(                               \
-      DataType t, llvm::Value* a, llvm::Value *b) {                         \
+      DataType t, llvm::Value* a, llvm::Value *b) {                     \
     if (t.is_int()) {                                                   \
       if (t.bits() >= 32) {                                             \
         return builder_->CreateNSW ## Op (a, b);                        \
@@ -836,8 +836,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) {
       return builder_->CreateF ## Op (a, b);                            \
     }                                                                   \
   }                                                                     \
-  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
-    return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b));  \
+  llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) {          \
+    return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
   }
 
 DEFINE_CODEGEN_BINARY_OP(Add);
@@ -846,7 +846,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
 
 #define DEFINE_CODEGEN_CMP_OP(Op)                                       \
   llvm::Value* CodeGenLLVM::Create ## Op(                               \
-      DataType t, llvm::Value* a, llvm::Value* b) {                         \
+      DataType t, llvm::Value* a, llvm::Value* b) {                     \
     if (t.is_int()) {                                                   \
       return builder_->CreateICmpS ## Op (a, b);                        \
     } else if (t.is_uint()) {                                           \
@@ -856,7 +856,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
       return builder_->CreateFCmpO ## Op (a, b);                        \
     }                                                                   \
 }                                                                       \
-  llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) {                  \
+  llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) {          \
     return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
   }
 
@@ -865,7 +865,7 @@ DEFINE_CODEGEN_CMP_OP(LE);
 DEFINE_CODEGEN_CMP_OP(GT);
 DEFINE_CODEGEN_CMP_OP(GE);
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
   if (op->dtype.is_int()) {
@@ -878,7 +878,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) {
   }
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
   if (op->dtype.is_int()) {
@@ -891,19 +891,19 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) {
   }
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
   return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
   return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
   if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
@@ -913,7 +913,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
   }
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
   if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
@@ -923,33 +923,33 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
   }
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
   return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
   return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
   return builder_->CreateNot(MakeValue(op->a));
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
   return builder_->CreateSelect(
       MakeValue(op->condition),
       MakeValue(op->true_value),
       MakeValue(op->false_value));
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
   CHECK(!var_map_.count(op->var.get()));
   var_map_[op->var.get()] = MakeValue(op->value);
   analyzer_->Bind(op->var, op->value);
   return MakeValue(op->body);
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
   DataType t = op->dtype;
   bool is_volatile = volatile_buf_.count(op->buffer_var.get());
   llvm::Value* buffer = MakeValue(op->buffer_var);
@@ -966,7 +966,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
     // vector load
     unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
       buffer->getType())->getAddressSpace();
-    if (const Ramp* ramp = op->index.as<Ramp>()) {
+    if (const RampNode* ramp = op->index.as<RampNode>()) {
       if (is_one(ramp->stride)) {
         int alignment, native_bits;
         GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
@@ -994,12 +994,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
   return ret;
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
-  if (op->call_type == Call::Intrinsic ||
-      op->call_type == Call::PureIntrinsic) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  if (op->call_type == CallNode::Intrinsic ||
+      op->call_type == CallNode::PureIntrinsic) {
     return CreateIntrinsic(op);
-  } else if (op->call_type == Call::Extern ||
-             op->call_type == Call::PureExtern) {
+  } else if (op->call_type == CallNode::Extern ||
+             op->call_type == CallNode::PureExtern) {
     return CreateCallExtern(op);
   } else {
     LOG(FATAL) << "Unknown call type " <<
@@ -1009,7 +1009,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
   }
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
   llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->dtype));
   for (int i = 0; i < op->lanes; ++i) {
     vec = builder_->CreateInsertElement(
@@ -1019,7 +1019,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
   return vec;
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
   std::vector<llvm::Value *> vecs(op->vectors.size());
   int total_lanes = 0;
   for (int i = 0, e = op->vectors.size(); i < e; ++i) {
@@ -1039,11 +1039,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
   return res;
 }
 
-llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
+llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
   return CreateBroadcast(MakeValue(op->value), op->lanes);
 }
 
-void CodeGenLLVM::VisitStmt_(const Store* op) {
+void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
   CHECK(is_one(op->predicate));
   DataType t = op->value.dtype();
   bool is_volatile = volatile_buf_.count(op->buffer_var.get());
@@ -1062,7 +1062,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
     // vector store
     unsigned addrspace = llvm::dyn_cast<llvm::PointerType>(
         buffer->getType())->getAddressSpace();
-    if (const Ramp* ramp = op->index.as<Ramp>()) {
+    if (const RampNode* ramp = op->index.as<RampNode>()) {
       if (is_one(ramp->stride)) {
         int alignment, native_bits;
         GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits);
@@ -1089,7 +1089,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {
   this->Scalarize(op->index, f);
 }
 
-void CodeGenLLVM::VisitStmt_(const For* op) {
+void CodeGenLLVM::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
   analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
   if (op->for_type == ForType::Unrolled) {
@@ -1103,7 +1103,7 @@ void CodeGenLLVM::VisitStmt_(const For* op) {
 }
 
 
-void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
+void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
   using llvm::BasicBlock;
   llvm::Value* cond = MakeValue(op->condition);
   BasicBlock* then_block = BasicBlock::Create(
@@ -1130,7 +1130,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
 }
 
 
-void CodeGenLLVM::VisitStmt_(const Allocate* op) {
+void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   llvm::Value* buf = nullptr;
   if (op->new_expr.defined()) {
@@ -1170,7 +1170,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
   this->VisitStmt(op->body);
 }
 
-void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
+void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::thread_extent) {
     IterVar iv = Downcast<IterVar>(op->node);
     if (iv->thread_tag.length() != 0) {
@@ -1180,29 +1180,29 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
       }
     }
   } else if (op->attr_key == ir::attr::storage_scope) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     alloc_storage_info_[v].scope =
-        runtime::StorageScope::make(op->value.as<StringImm>()->value);
+        runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == ir::attr::storage_alignment) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     alloc_storage_info_[v].alignment =
-        static_cast<int>(op->value.as<IntImm>()->value);
+        static_cast<int>(op->value.as<IntImmNode>()->value);
   } else if (op->attr_key == ir::attr::volatile_scope) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     volatile_buf_.insert(v);
   }
   this->VisitStmt(op->body);
 }
 
-void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
+void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
   With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
   this->VisitStmt(op->body);
 }
 
-void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
+void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
   CHECK(!var_map_.count(op->var.get()));
   if (op->var.dtype().is_handle()) {
     if (!is_restricted_) {
@@ -1220,11 +1220,11 @@ void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
   }
 }
 
-void CodeGenLLVM::VisitStmt_(const Evaluate* op) {
+void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
   MakeValue(op->value);
 }
 
-void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) {
   this->VisitStmt(op->body);
 }
 }  // namespace codegen
index 076ffb2..67ca7c1 100644 (file)
@@ -103,46 +103,46 @@ class CodeGenLLVM :
     return llvm::ConstantInt::getSigned(t_int32_, value);
   }
   // override codegen
-  llvm::Value* VisitExpr_(const Variable* op) override;
-  llvm::Value* VisitExpr_(const Cast* op) override;
-  llvm::Value* VisitExpr_(const IntImm* op) override;
-  llvm::Value* VisitExpr_(const UIntImm* op) override;
-  llvm::Value* VisitExpr_(const FloatImm* op) override;
-  llvm::Value* VisitExpr_(const StringImm* op) override;
-  llvm::Value* VisitExpr_(const Add* op) override;
-  llvm::Value* VisitExpr_(const Sub* op) override;
-  llvm::Value* VisitExpr_(const Mul* op) override;
-  llvm::Value* VisitExpr_(const Div* op) override;
-  llvm::Value* VisitExpr_(const Mod* op) override;
-  llvm::Value* VisitExpr_(const Min* op) override;
-  llvm::Value* VisitExpr_(const Max* op) override;
-  llvm::Value* VisitExpr_(const LT* op) override;
-  llvm::Value* VisitExpr_(const LE* op) override;
-  llvm::Value* VisitExpr_(const GT* op) override;
-  llvm::Value* VisitExpr_(const GE* op) override;
-  llvm::Value* VisitExpr_(const EQ* op) override;
-  llvm::Value* VisitExpr_(const NE* op) override;
-  llvm::Value* VisitExpr_(const And* op) override;
-  llvm::Value* VisitExpr_(const Or* op) override;
-  llvm::Value* VisitExpr_(const Not* op) override;
-  llvm::Value* VisitExpr_(const Select* op) override;
-  llvm::Value* VisitExpr_(const Let* op) override;
-  llvm::Value* VisitExpr_(const Load* op) override;
-  llvm::Value* VisitExpr_(const Call* op) override;
-  llvm::Value* VisitExpr_(const Ramp* op) override;
-  llvm::Value* VisitExpr_(const Shuffle* op) override;
-  llvm::Value* VisitExpr_(const Broadcast* op) override;
+  llvm::Value* VisitExpr_(const VarNode* op) override;
+  llvm::Value* VisitExpr_(const CastNode* op) override;
+  llvm::Value* VisitExpr_(const IntImmNode* op) override;
+  llvm::Value* VisitExpr_(const UIntImmNode* op) override;
+  llvm::Value* VisitExpr_(const FloatImmNode* op) override;
+  llvm::Value* VisitExpr_(const StringImmNode* op) override;
+  llvm::Value* VisitExpr_(const AddNode* op) override;
+  llvm::Value* VisitExpr_(const SubNode* op) override;
+  llvm::Value* VisitExpr_(const MulNode* op) override;
+  llvm::Value* VisitExpr_(const DivNode* op) override;
+  llvm::Value* VisitExpr_(const ModNode* op) override;
+  llvm::Value* VisitExpr_(const MinNode* op) override;
+  llvm::Value* VisitExpr_(const MaxNode* op) override;
+  llvm::Value* VisitExpr_(const LTNode* op) override;
+  llvm::Value* VisitExpr_(const LENode* op) override;
+  llvm::Value* VisitExpr_(const GTNode* op) override;
+  llvm::Value* VisitExpr_(const GENode* op) override;
+  llvm::Value* VisitExpr_(const EQNode* op) override;
+  llvm::Value* VisitExpr_(const NENode* op) override;
+  llvm::Value* VisitExpr_(const AndNode* op) override;
+  llvm::Value* VisitExpr_(const OrNode* op) override;
+  llvm::Value* VisitExpr_(const NotNode* op) override;
+  llvm::Value* VisitExpr_(const SelectNode* op) override;
+  llvm::Value* VisitExpr_(const LetNode* op) override;
+  llvm::Value* VisitExpr_(const LoadNode* op) override;
+  llvm::Value* VisitExpr_(const CallNode* op) override;
+  llvm::Value* VisitExpr_(const RampNode* op) override;
+  llvm::Value* VisitExpr_(const ShuffleNode* op) override;
+  llvm::Value* VisitExpr_(const BroadcastNode* op) override;
   // stmt
-  void VisitStmt_(const Store* op) override;
-  void VisitStmt_(const For* op) override;
-  void VisitStmt_(const IfThenElse* op) override;
-  void VisitStmt_(const Allocate* op) override;
-  void VisitStmt_(const AttrStmt* op) override;
-  void VisitStmt_(const AssertStmt* op) override;
-  void VisitStmt_(const LetStmt* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitStmt_(const AttrStmtNode* op) override;
+  void VisitStmt_(const AssertStmtNode* op) override;
+  void VisitStmt_(const LetStmtNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const Evaluate* op) override;
-  void VisitStmt_(const ProducerConsumer* op) override;
+  void VisitStmt_(const EvaluateNode* op) override;
+  void VisitStmt_(const ProducerConsumerNode* op) override;
 
  protected:
   /*! \brief The storage information */
@@ -173,13 +173,13 @@ class CodeGenLLVM :
     return res;
   }
   // create intrinstic given call
-  virtual llvm::Value* CreateIntrinsic(const Call* op);
+  virtual llvm::Value* CreateIntrinsic(const CallNode* op);
   // create extern function call
-  virtual llvm::Value* CreateCallExtern(const Call* op);
+  virtual llvm::Value* CreateCallExtern(const CallNode* op);
   // Get the corresponding thread index
   virtual llvm::Value* GetThreadIndex(const IterVar& iv);
   // Get the corresponding thread index
-  virtual llvm::Value* CreateStorageSync(const Call* op);
+  virtual llvm::Value* CreateStorageSync(const CallNode* op);
   // apply optimization on the module.
   virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
   // Scalarize by iterating elements of e.
@@ -211,19 +211,19 @@ class CodeGenLLVM :
   void InitFuncState();
   // Get alignment given index.
   void GetAlignment(
-      DataType t, const Variable* buf_var, const Expr& index,
+      DataType t, const VarNode* buf_var, const Expr& index,
       int* p_alignment, int* p_native_bits);
   // Get constant string
   llvm::Value* GetConstString(const std::string& str);
   // do a scalarize call with f
   llvm::Value* CreateScalarizedCall(
-      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
+      const CallNode* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
   // handle module import
   void HandleImport(const std::string& code);
   // cast operatpr
   llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value);
   // comparison op
-  llvm::Value* GetVarValue(const Variable* v) const;
+  llvm::Value* GetVarValue(const VarNode* v) const;
   llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b);
@@ -245,7 +245,7 @@ class CodeGenLLVM :
                        llvm::Value* stride,
                        const VarExpr& loop_var, const Stmt& body);
   // add alias information.
-  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, DataType type);
+  void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, Expr index, DataType type);
   // The IRBuilder.
   using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
   // The current function
@@ -280,9 +280,9 @@ class CodeGenLLVM :
   /*! \brief native vector bits of current targetx*/
   int native_vector_bits_{0};
   /*! \brief the storage scope of allocation */
-  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
+  std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
   // The definition of local variable.
-  std::unordered_map<const Variable*, llvm::Value*> var_map_;
+  std::unordered_map<const VarNode*, llvm::Value*> var_map_;
   // global strings
   std::unordered_map<std::string, llvm::Constant*> str_map_;
   // Whether current function is restricted
@@ -290,9 +290,9 @@ class CodeGenLLVM :
   // The analyzer information
   std::unique_ptr<arith::Analyzer> analyzer_;
   // set of var that are not restricted(can alias)
-  std::unordered_set<const Variable*> alias_var_set_;
+  std::unordered_set<const VarNode*> alias_var_set_;
   // set of volatile buffer.
-  std::unordered_set<const Variable*> volatile_buf_;
+  std::unordered_set<const VarNode*> volatile_buf_;
   /*! \brief Helper struct for debug infos. */
   struct DebugInfo {
     std::unique_ptr<llvm::DIBuilder> di_builder_;
index a0caf65..877bbba 100644 (file)
@@ -46,7 +46,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
               llvm::ValueAsMetadata::get(ConstInt32(1)) }));
   }
 
-  void VisitStmt_(const Allocate* op) final {
+  void VisitStmt_(const AllocateNode* op) final {
     CHECK(!is_zero(op->condition));
     llvm::Value* buf = nullptr;
     if (op->new_expr.defined()) {
@@ -129,8 +129,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
     return builder_->CreateCall(f, {});
   }
 
-  llvm::Value* CreateStorageSync(const Call* op) final {
-    const std::string& sync = op->args[0].as<StringImm>()->value;
+  llvm::Value* CreateStorageSync(const CallNode* op) final {
+    const std::string& sync = op->args[0].as<StringImmNode>()->value;
     if (sync == "warp") {
       // TODO(tqchen) warp sync in CUDA9
       return nullptr;
index d613883..03656cc 100644 (file)
@@ -65,14 +65,14 @@ bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature)
 
 class CodeGenX86_64 final : public CodeGenCPU {
  public:
-  llvm::Value* VisitExpr_(const Cast* op) override;
+  llvm::Value* VisitExpr_(const CastNode* op) override;
 
  private:
   llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
                                 const std::vector<llvm::Value*>& args);
 };
 
-llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
+llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
   // LLVM does not automatically generate the correct instruction sequences for
   // half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
   // vcvtph2ps), so we explicitly generate them ourselves.
@@ -90,22 +90,23 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
           ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
           LLVMType(DataType::Float(32, from.lanes())),
           {
-            MakeValue(ir::Call::make(
-                DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
-                ir::Call::PureIntrinsic)),
+            MakeValue(ir::CallNode::make(
+                DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
+                ir::CallNode::PureIntrinsic)),
                 MakeValue(
-                    ir::Broadcast::make(ir::FloatImm::make(DataType::Float(32), 0), from.lanes())),
-                /*mask=*/MakeValue(ir::IntImm::make(DataType::Int(16), -1)),
-                /*rounding-mode=*/MakeValue(ir::IntImm::make(DataType::Int(32), 4)),
+                    ir::BroadcastNode::make(
+                      ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
+                /*mask=*/MakeValue(ir::IntImmNode::make(DataType::Int(16), -1)),
+                /*rounding-mode=*/MakeValue(ir::IntImmNode::make(DataType::Int(32), 4)),
           });
     }
 
     if (from.lanes() >= 8 && has_f16c) {
       return CallVectorIntrin(
           ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
-          {MakeValue(ir::Call::make(
-              DataType::Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
-              ir::Call::PureIntrinsic))});
+          {MakeValue(ir::CallNode::make(
+              DataType::Int(16, from.lanes()), ir::CallNode::reinterpret, {op->value},
+              ir::CallNode::PureIntrinsic))});
     }
   }
 
index da07ff3..10774ec 100644 (file)
@@ -64,21 +64,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
-  const ir::Call* call = e.as<ir::Call>();
+  const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
   const Expr& x = call->args[0];
   Expr one = make_const(x.dtype(), 1);
   Expr two = make_const(x.dtype(), 2);
   Expr neg_two = make_const(x.dtype(), -2);
 
-  Expr exp_neg2x = ir::Call::make(
-      x.dtype(), "exp", {neg_two * x}, ir::Call::PureIntrinsic);
-  Expr exp_pos2x = ir::Call::make(
-      x.dtype(), "exp", {two * x}, ir::Call::PureIntrinsic);
+  Expr exp_neg2x = ir::CallNode::make(
+      x.dtype(), "exp", {neg_two * x}, ir::CallNode::PureIntrinsic);
+  Expr exp_pos2x = ir::CallNode::make(
+      x.dtype(), "exp", {two * x}, ir::CallNode::PureIntrinsic);
 
   Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
   Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
-  *rv = ir::Select::make(
+  *rv = ir::SelectNode::make(
       x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
 });
 
index 0d65576..a870385 100644 (file)
@@ -39,34 +39,34 @@ namespace codegen {
 template<unsigned id, int num_signature>
 inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
-  const ir::Call* call = e.as<ir::Call>();
+  const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
   Array<Expr> cargs;
   // intrin id.
-  cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
-  cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
+  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
+  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
 
   for (Expr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv = ir::Call::make(
-      call->dtype, "llvm_intrin", cargs, ir::Call::PureIntrinsic);
+  *rv = ir::CallNode::make(
+      call->dtype, "llvm_intrin", cargs, ir::CallNode::PureIntrinsic);
 }
 
 template<unsigned id, int num_signature>
 inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
-  const ir::Call* call = e.as<ir::Call>();
+  const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
   Array<Expr> cargs;
   // intrin id.
-  cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
-  cargs.push_back(ir::UIntImm::make(DataType::UInt(32), num_signature));
+  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
+  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), num_signature));
   for (Expr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv = ir::Call::make(
-      call->dtype, "llvm_intrin", cargs, ir::Call::Intrinsic);
+  *rv = ir::CallNode::make(
+      call->dtype, "llvm_intrin", cargs, ir::CallNode::Intrinsic);
 }
 
 }  // namespace codegen
index 2f0e5c5..00824bb 100644 (file)
@@ -35,14 +35,14 @@ namespace codegen {
 inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) {
   Expr e = args[0];
   using namespace ir;
-  const Call* call = e.as<Call>();
+  const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64.";
   std::ostringstream intrinsic_name;
   intrinsic_name << "__nv_" << call->name;
   if (call->dtype.bits() == 32) intrinsic_name << "f";
-  *rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
-                   Call::PureExtern);
+  *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
+                   CallNode::PureExtern);
 }
 
 namespace llvm {
index 380f9a9..09de88f 100644 (file)
@@ -35,12 +35,12 @@ namespace codegen {
 inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) {
   Expr e = args[0];
   using namespace ir;
-  const Call* call = e.as<Call>();
+  const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   std::ostringstream intrinsic_name;
   intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits();
-  *rv = Call::make(call->dtype, intrinsic_name.str(), call->args,
-                   Call::PureExtern);
+  *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args,
+                   CallNode::PureExtern);
 }
 
 namespace llvm {
index 0709965..254e436 100644 (file)
@@ -106,8 +106,8 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
   return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
 }
 
-spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
-  const std::string& sync = op->args[0].as<StringImm>()->value;
+spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
+  const std::string& sync = op->args[0].as<StringImmNode>()->value;
   spirv::Value value;
   if (sync == "warp") {
     return value;
@@ -126,154 +126,154 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) {
   return value;
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) {
   auto it = var_map_.find(op);
   CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
   return it->second;
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const IntImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
   return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const UIntImmNode* op) {
   return builder_->UIntImm(builder_->GetSType(op->dtype), op->value);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
   return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) {
   LOG(FATAL) << "StringImm is not supported in Device code";
   return spirv::Value();
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) {
   return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) {
   return builder_->Add(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Sub* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) {
   return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Mul* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) {
   return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Div* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) {
   return builder_->Div(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Mod* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) {
   return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Min* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) {
   spirv::Value a = MakeValue(op->a);
   spirv::Value b = MakeValue(op->b);
   return builder_->Select(builder_->LT(a, b), a, b);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Max* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) {
   spirv::Value a = MakeValue(op->a);
   spirv::Value b = MakeValue(op->b);
   return builder_->Select(builder_->GT(a, b), a, b);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const LT* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LTNode* op) {
   return builder_->LT(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const LE* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LENode* op) {
   return builder_->LE(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const GT* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const GTNode* op) {
   return builder_->GT(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const GE* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const GENode* op) {
   return builder_->GE(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const EQ* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const EQNode* op) {
   return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const NE* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const NENode* op) {
   return builder_->NE(MakeValue(op->a), MakeValue(op->b));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const And* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const AndNode* op) {
   spirv::Value a = MakeValue(op->a);
   spirv::Value b = MakeValue(op->b);
   return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Or* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const OrNode* op) {
   spirv::Value a = MakeValue(op->a);
   spirv::Value b = MakeValue(op->b);
   return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Not* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) {
   spirv::Value a = MakeValue(op->a);
   return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
   return builder_->Select(MakeValue(op->condition),
                           MakeValue(op->true_value),
                           MakeValue(op->false_value));
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
   CHECK(!var_map_.count(op->var.get()));
   var_map_[op->var.get()] = MakeValue(op->value);
   analyzer_->Bind(op->var, op->value);
   return MakeValue(op->body);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
   if (op->is_intrinsic("spirv_glsl450")) {
     CHECK_GE(op->args.size(), 2U);
-    uint32_t inst_id = op->args[0].as<UIntImm>()->value;
+    uint32_t inst_id = op->args[0].as<UIntImmNode>()->value;
     std::vector<spirv::Value> values;
     for (size_t i = 1; i < op->args.size(); ++i) {
       values.push_back(MakeValue(op->args[i]));
     }
     return builder_->CallGLSL450(
         builder_->GetSType(op->dtype), inst_id, values);
-  } else if (op->is_intrinsic(Call::bitwise_and)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
     CHECK_EQ(op->args.size(), 2U);
     spirv::Value a = MakeValue(op->args[0]);
     spirv::Value b = MakeValue(op->args[1]);
     return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
-  } else if (op->is_intrinsic(Call::bitwise_xor)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
     CHECK_EQ(op->args.size(), 2U);
     spirv::Value a = MakeValue(op->args[0]);
     spirv::Value b = MakeValue(op->args[1]);
     return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
-  } else if (op->is_intrinsic(Call::bitwise_or)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
     CHECK_EQ(op->args.size(), 2U);
     spirv::Value a = MakeValue(op->args[0]);
     spirv::Value b = MakeValue(op->args[1]);
     return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
-  } else if (op->is_intrinsic(Call::bitwise_not)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
     CHECK_EQ(op->args.size(), 1U);
     spirv::Value a = MakeValue(op->args[0]);
     return builder_->MakeValue(spv::OpNot, a.stype, a);
-  } else if (op->is_intrinsic(Call::shift_left)) {
+  } else if (op->is_intrinsic(CallNode::shift_left)) {
     CHECK_EQ(op->args.size(), 2U);
     spirv::Value a = MakeValue(op->args[0]);
     spirv::Value b = MakeValue(op->args[1]);
     return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
-  } else if (op->is_intrinsic(Call::shift_right)) {
+  } else if (op->is_intrinsic(CallNode::shift_right)) {
     CHECK_EQ(op->args.size(), 2U);
     spirv::Value a = MakeValue(op->args[0]);
     spirv::Value b = MakeValue(op->args[1]);
@@ -282,7 +282,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
     } else {
       return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
     }
-  } else if (op->is_intrinsic(Call::reinterpret)) {
+  } else if (op->is_intrinsic(CallNode::reinterpret)) {
     return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
                                MakeValue(op->args[0]));
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
@@ -319,12 +319,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
         builder_->GetSType(op->dtype),
         MakeValue(op->args[0]));
   } else {
-    if (op->call_type == Call::Intrinsic ||
-        op->call_type == Call::PureIntrinsic) {
+    if (op->call_type == CallNode::Intrinsic ||
+        op->call_type == CallNode::PureIntrinsic) {
       LOG(FATAL) << "Unresolved intrinsic " << op->name
                  << " with return type " << op->dtype;
-    } else if (op->call_type == Call::Extern ||
-               op->call_type == Call::PureExtern) {
+    } else if (op->call_type == CallNode::Extern ||
+               op->call_type == CallNode::PureExtern) {
       LOG(FATAL) << "Unresolved extern " << op->name
                  << " with return type " << op->dtype;
     } else {
@@ -334,7 +334,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Call* op) {
   }
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
   std::vector<spirv::Value> values;
   spirv::Value base = MakeValue(op->base);
   for (int i = 0; i < op->lanes; ++i) {
@@ -349,7 +349,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
   return builder_->Concat(values);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Broadcast* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
   std::vector<spirv::Value> values;
   spirv::Value v = MakeValue(op->value);
   for (int i = 0; i < op->lanes; i++) {
@@ -358,7 +358,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Broadcast* op) {
   return builder_->Concat(values);
 }
 
-spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
+spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
   CHECK(is_one(op->predicate));
   auto it = storage_info_.find(op->buffer_var.get());
   CHECK(it != storage_info_.end());
@@ -396,7 +396,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
       this->Scalarize(op->index, f);
       return builder_->Concat(values);
     } else {
-      if (const Ramp* ramp = op->index.as<Ramp>()) {
+      if (const RampNode* ramp = op->index.as<RampNode>()) {
         if (is_one(ramp->stride)) {
           CHECK_EQ(ramp->lanes, op->dtype.lanes());
           arith::ModularSet me = analyzer_->modular_set(ramp->base);
@@ -419,7 +419,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) {
 
 void CodeGenSPIRV::Scalarize(const Expr& e,
                              std::function<void(int i, spirv::Value v)> f) {
-  if (const Ramp* ramp = e.as<Ramp>()) {
+  if (const RampNode* ramp = e.as<RampNode>()) {
     for (int i = 0; i < ramp->dtype.lanes(); ++i) {
       Expr offset = ramp->base + ramp->stride * i;
       f(i, MakeValue(offset));
@@ -434,7 +434,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e,
   }
 }
 
-void CodeGenSPIRV::VisitStmt_(const Store* op) {
+void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
   CHECK(is_one(op->predicate));
   auto it = storage_info_.find(op->buffer_var.get());
   CHECK(it != storage_info_.end());
@@ -474,7 +474,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
       };
       this->Scalarize(op->index, f);
     } else {
-      if (const Ramp* ramp = op->index.as<Ramp>()) {
+      if (const RampNode* ramp = op->index.as<RampNode>()) {
         if (is_one(ramp->stride)) {
           CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
           arith::ModularSet me = analyzer_->modular_set(ramp->base);
@@ -494,7 +494,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) {
   }
 }
 
-void CodeGenSPIRV::VisitStmt_(const For* op) {
+void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
   analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
   spirv::Value init_value = MakeValue(op->min);
@@ -540,7 +540,7 @@ void CodeGenSPIRV::VisitStmt_(const For* op) {
   builder_->StartLabel(merge_label);
 }
 
-void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) {
+void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
   spirv::Value cond = MakeValue(op->condition);
   spirv::Label then_label = builder_->NewLabel();
   spirv::Label merge_label = builder_->NewLabel();
@@ -573,7 +573,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElse* op) {
   builder_->StartLabel(merge_label);
 }
 
-void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
+void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   CHECK(!op->new_expr.defined());
   CHECK(!op->dtype.is_handle());
@@ -603,7 +603,7 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) {
   this->VisitStmt(op->body);
 }
 
-void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
+void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::thread_extent) {
     IterVar iv = Downcast<IterVar>(op->node);
     if (iv->thread_tag.length() != 0) {
@@ -613,24 +613,24 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
       }
     }
   } else if (op->attr_key == ir::attr::storage_scope) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     storage_info_[v].scope =
-        runtime::StorageScope::make(op->value.as<StringImm>()->value);
+        runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == ir::attr::volatile_scope) {
-    const Variable* v = op->node.as<Variable>();
+    const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     storage_info_[v].is_volatile = true;
   }
   this->VisitStmt(op->body);
 }
 
-void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) {
+void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
   With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
   this->VisitStmt(op->body);
 }
 
-void CodeGenSPIRV::VisitStmt_(const LetStmt* op) {
+void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
   CHECK(!var_map_.count(op->var.get()));
   CHECK(!op->var.dtype().is_handle());
   var_map_[op->var.get()] = MakeValue(op->value);
@@ -644,11 +644,11 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
   }
 }
 
-void CodeGenSPIRV::VisitStmt_(const Evaluate* op) {
+void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) {
   MakeValue(op->value);
 }
 
-void CodeGenSPIRV::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenSPIRV::VisitStmt_(const ProducerConsumerNode* op) {
   this->VisitStmt(op->body);
 }
 
index 5cd88c9..b72cd5b 100644 (file)
@@ -62,45 +62,45 @@ class CodeGenSPIRV:
     return VisitExpr(e);
   }
   // override codegen
-  spirv::Value VisitExpr_(const Variable* op) override;
-  spirv::Value VisitExpr_(const Cast* op) override;
-  spirv::Value VisitExpr_(const IntImm* op) override;
-  spirv::Value VisitExpr_(const UIntImm* op) override;
-  spirv::Value VisitExpr_(const FloatImm* op) override;
-  spirv::Value VisitExpr_(const StringImm* op) override;
-  spirv::Value VisitExpr_(const Add* op) override;
-  spirv::Value VisitExpr_(const Sub* op) override;
-  spirv::Value VisitExpr_(const Mul* op) override;
-  spirv::Value VisitExpr_(const Div* op) override;
-  spirv::Value VisitExpr_(const Mod* op) override;
-  spirv::Value VisitExpr_(const Min* op) override;
-  spirv::Value VisitExpr_(const Max* op) override;
-  spirv::Value VisitExpr_(const LT* op) override;
-  spirv::Value VisitExpr_(const LE* op) override;
-  spirv::Value VisitExpr_(const GT* op) override;
-  spirv::Value VisitExpr_(const GE* op) override;
-  spirv::Value VisitExpr_(const EQ* op) override;
-  spirv::Value VisitExpr_(const NE* op) override;
-  spirv::Value VisitExpr_(const And* op) override;
-  spirv::Value VisitExpr_(const Or* op) override;
-  spirv::Value VisitExpr_(const Not* op) override;
-  spirv::Value VisitExpr_(const Select* op) override;
-  spirv::Value VisitExpr_(const Let* op) override;
-  spirv::Value VisitExpr_(const Call* op) override;
-  spirv::Value VisitExpr_(const Ramp* op) override;
-  spirv::Value VisitExpr_(const Broadcast* op) override;
-  spirv::Value VisitExpr_(const Load* op) override;
+  spirv::Value VisitExpr_(const VarNode* op) override;
+  spirv::Value VisitExpr_(const CastNode* op) override;
+  spirv::Value VisitExpr_(const IntImmNode* op) override;
+  spirv::Value VisitExpr_(const UIntImmNode* op) override;
+  spirv::Value VisitExpr_(const FloatImmNode* op) override;
+  spirv::Value VisitExpr_(const StringImmNode* op) override;
+  spirv::Value VisitExpr_(const AddNode* op) override;
+  spirv::Value VisitExpr_(const SubNode* op) override;
+  spirv::Value VisitExpr_(const MulNode* op) override;
+  spirv::Value VisitExpr_(const DivNode* op) override;
+  spirv::Value VisitExpr_(const ModNode* op) override;
+  spirv::Value VisitExpr_(const MinNode* op) override;
+  spirv::Value VisitExpr_(const MaxNode* op) override;
+  spirv::Value VisitExpr_(const LTNode* op) override;
+  spirv::Value VisitExpr_(const LENode* op) override;
+  spirv::Value VisitExpr_(const GTNode* op) override;
+  spirv::Value VisitExpr_(const GENode* op) override;
+  spirv::Value VisitExpr_(const EQNode* op) override;
+  spirv::Value VisitExpr_(const NENode* op) override;
+  spirv::Value VisitExpr_(const AndNode* op) override;
+  spirv::Value VisitExpr_(const OrNode* op) override;
+  spirv::Value VisitExpr_(const NotNode* op) override;
+  spirv::Value VisitExpr_(const SelectNode* op) override;
+  spirv::Value VisitExpr_(const LetNode* op) override;
+  spirv::Value VisitExpr_(const CallNode* op) override;
+  spirv::Value VisitExpr_(const RampNode* op) override;
+  spirv::Value VisitExpr_(const BroadcastNode* op) override;
+  spirv::Value VisitExpr_(const LoadNode* op) override;
   // stmt
-  void VisitStmt_(const Store* op) override;
-  void VisitStmt_(const For* op) override;
-  void VisitStmt_(const IfThenElse* op) override;
-  void VisitStmt_(const Allocate* op) override;
-  void VisitStmt_(const AttrStmt* op) override;
-  void VisitStmt_(const AssertStmt* op) override;
-  void VisitStmt_(const LetStmt* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitStmt_(const AttrStmtNode* op) override;
+  void VisitStmt_(const AssertStmtNode* op) override;
+  void VisitStmt_(const LetStmtNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const Evaluate* op) override;
-  void VisitStmt_(const ProducerConsumer* op) override;
+  void VisitStmt_(const EvaluateNode* op) override;
+  void VisitStmt_(const ProducerConsumerNode* op) override;
 
  protected:
   /*! \brief The storage information */
@@ -129,7 +129,7 @@ class CodeGenSPIRV:
   void InitFuncState();
   // Get the thread index
   spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
-  spirv::Value CreateStorageSync(const Call* op);
+  spirv::Value CreateStorageSync(const CallNode* op);
   void Scalarize(const Expr& e,
                  std::function<void(int i, spirv::Value v)> f);
   // The builder
@@ -139,9 +139,9 @@ class CodeGenSPIRV:
   // Likely branch
   uint32_t weight_likely_branch_{128};
   // the storage scope of allocation
-  std::unordered_map<const Variable*, StorageInfo> storage_info_;
+  std::unordered_map<const VarNode*, StorageInfo> storage_info_;
   // The definition of local variable.
-  std::unordered_map<const Variable*, spirv::Value> var_map_;
+  std::unordered_map<const VarNode*, spirv::Value> var_map_;
   // The analyzer.
   std::unique_ptr<arith::Analyzer> analyzer_;
 };
index 7a347e5..69d2014 100644 (file)
@@ -35,17 +35,17 @@ using namespace runtime;
 template<unsigned id>
 inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   Expr e = targs[0];
-  const ir::Call* call = e.as<ir::Call>();
+  const ir::CallNode* call = e.as<ir::CallNode>();
   CHECK(call != nullptr);
   Array<Expr> cargs;
   // intrin id.
-  cargs.push_back(ir::UIntImm::make(DataType::UInt(32), id));
+  cargs.push_back(ir::UIntImmNode::make(DataType::UInt(32), id));
 
   for (Expr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv = ir::Call::make(
-      call->dtype, "spirv_glsl450", cargs, ir::Call::PureIntrinsic);
+  *rv = ir::CallNode::make(
+      call->dtype, "spirv_glsl450", cargs, ir::CallNode::PureIntrinsic);
 }
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
index 23bb008..3da083b 100644 (file)
@@ -81,7 +81,7 @@ int CodeGenStackVM::GetStrID(const std::string& key) {
   return sid;
 }
 
-int CodeGenStackVM::AllocVarID(const Variable* v) {
+int CodeGenStackVM::AllocVarID(const VarNode* v) {
   CHECK(!var_idmap_.count(v));
   int vid = static_cast<int>(vm_.heap_size);
   CHECK_EQ(vm_.heap_size, var_idmap_.size());
@@ -91,17 +91,17 @@ int CodeGenStackVM::AllocVarID(const Variable* v) {
   return vid;
 }
 
-int CodeGenStackVM::GetVarID(const Variable* v) const {
+int CodeGenStackVM::GetVarID(const VarNode* v) const {
   auto it = var_idmap_.find(v);
   CHECK(it != var_idmap_.end())
       << "Find undefined Variable " << v->name_hint;
   return it->second;
 }
 
-void CodeGenStackVM::VisitExpr_(const Load* op) {
+void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
   this->Push(op->buffer_var);
   StackVM::OpCode code = StackVM::GetLoad(op->dtype);
-  if (const IntImm* index = op->index.as<IntImm>()) {
+  if (const IntImmNode* index = op->index.as<IntImmNode>()) {
     this->PushOp(code, index->value);
   } else {
     this->Push(op->index);
@@ -112,10 +112,10 @@ void CodeGenStackVM::VisitExpr_(const Load* op) {
   }
 }
 
-void CodeGenStackVM::VisitStmt_(const Store* op) {
+void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
   this->Push(op->buffer_var);
   StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
-  if (const IntImm* index = op->index.as<IntImm>()) {
+  if (const IntImmNode* index = op->index.as<IntImmNode>()) {
     this->Push(op->value);
     this->PushOp(code, index->value);
   } else {
@@ -128,7 +128,7 @@ void CodeGenStackVM::VisitStmt_(const Store* op) {
   }
 }
 
-void CodeGenStackVM::VisitStmt_(const Allocate* op) {
+void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
   CHECK(!is_zero(op->condition));
   int vid = AllocVarID(op->buffer_var.get());
   if (op->new_expr.defined()) {
@@ -141,22 +141,22 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
   }
 }
 
-void CodeGenStackVM::VisitExpr_(const Call* op) {
+void CodeGenStackVM::VisitExpr_(const CallNode* op) {
   if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-    const Load *l = op->args[0].as<Load>();
+    const LoadNode *l = op->args[0].as<LoadNode>();
     CHECK(op->args.size() == 1 && l);
     this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get()));
     this->Push(l->index);
     this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes());
     this->PushOp(StackVM::MUL_I64);
     this->PushOp(StackVM::ADDR_ADD);
-  } else if (op->is_intrinsic(Call::reinterpret)) {
+  } else if (op->is_intrinsic(CallNode::reinterpret)) {
     this->Push(op->args[0]);
   } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
     CHECK_EQ(op->args.size(), 3U);
-    int kind = op->args[2].as<IntImm>()->value;
+    int kind = op->args[2].as<IntImmNode>()->value;
     this->Push(op->args[0]);
-    const IntImm* index = op->args[1].as<IntImm>();
+    const IntImmNode* index = op->args[1].as<IntImmNode>();
     CHECK(index != nullptr);
     StackVM::Code code;
     code.op_code = StackVM::TVM_STRUCT_GET;
@@ -167,12 +167,12 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
     vm_.code.push_back(code);
   } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
     CHECK_GE(op->args.size(), 5U);
-    const StringImm* s = op->args[0].as<StringImm>();
+    const StringImmNode* s = op->args[0].as<StringImmNode>();
     CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
     this->Push(op->args[1]);
     this->Push(op->args[2]);
-    int begin = op->args[3].as<IntImm>()->value;
-    int end = op->args[4].as<IntImm>()->value;
+    int begin = op->args[3].as<IntImmNode>()->value;
+    int end = op->args[4].as<IntImmNode>()->value;
     // find the fuction id.
     const std::string& func_name = s->value;
     auto it = extern_fun_idmap_.find(func_name);
@@ -196,8 +196,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
     vm_.code.push_back(code);
   } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
     CHECK_EQ(op->args.size(), 2U);
-    const std::string& type = op->args[0].as<StringImm>()->value;
-    const IntImm* num = op->args[1].as<IntImm>();
+    const std::string& type = op->args[0].as<StringImmNode>()->value;
+    const IntImmNode* num = op->args[1].as<IntImmNode>();
     CHECK(num != nullptr);
     static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
     // static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
@@ -268,59 +268,59 @@ void CodeGenStackVM::PushCast(DataType dst, DataType src) {
   }
 }
 
-void CodeGenStackVM::VisitExpr_(const StringImm* op) {
+void CodeGenStackVM::VisitExpr_(const StringImmNode* op) {
   int sid = this->GetStrID(op->value);
   this->PushOp(StackVM::PUSH_I64, sid);
 }
 
-void CodeGenStackVM::VisitExpr_(const IntImm* op) {
+void CodeGenStackVM::VisitExpr_(const IntImmNode* op) {
   CHECK(op->value >= std::numeric_limits<int>::min() &&
         op->value <= std::numeric_limits<int>::max())
       << "Int constant exceed bound";
     this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
 }
 
-void CodeGenStackVM::VisitExpr_(const UIntImm* op) {
+void CodeGenStackVM::VisitExpr_(const UIntImmNode* op) {
   CHECK(op->value <= std::numeric_limits<int>::max())
       << "Int constant exceed bound";
   this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
 }
 
-void CodeGenStackVM::VisitExpr_(const FloatImm* op) {
+void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
   LOG(FATAL) << "Float Imm is not supported";
 }
 
-void CodeGenStackVM::VisitExpr_(const Variable* op) {
+void CodeGenStackVM::VisitExpr_(const VarNode* op) {
   int vid = this->GetVarID(op);
   this->PushOp(StackVM::LOAD_HEAP, vid);
 }
 
-void CodeGenStackVM::VisitExpr_(const Cast* op) {
+void CodeGenStackVM::VisitExpr_(const CastNode* op) {
   this->Push(op->value);
   PushCast(op->dtype, op->value.dtype());
 }
 
-void CodeGenStackVM::VisitExpr_(const Add* op) {
+void CodeGenStackVM::VisitExpr_(const AddNode* op) {
   PushBinary(StackVM::ADD_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const Sub* op) {
+void CodeGenStackVM::VisitExpr_(const SubNode* op) {
   PushBinary(StackVM::SUB_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const Mul* op) {
+void CodeGenStackVM::VisitExpr_(const MulNode* op) {
   PushBinary(StackVM::MUL_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const Div* op) {
+void CodeGenStackVM::VisitExpr_(const DivNode* op) {
   PushBinary(StackVM::DIV_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const Mod* op) {
+void CodeGenStackVM::VisitExpr_(const ModNode* op) {
   PushBinary(StackVM::MOD_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const Min* op) {
+void CodeGenStackVM::VisitExpr_(const MinNode* op) {
   this->Push(op->a);
   this->Push(op->b);
   this->PushOp(StackVM::PUSH_VALUE, -1);
@@ -329,7 +329,7 @@ void CodeGenStackVM::VisitExpr_(const Min* op) {
   this->PushOp(StackVM::SELECT);
 }
 
-void CodeGenStackVM::VisitExpr_(const Max* op) {
+void CodeGenStackVM::VisitExpr_(const MaxNode* op) {
   this->Push(op->a);
   this->Push(op->b);
   this->PushOp(StackVM::PUSH_VALUE, 0);
@@ -338,34 +338,34 @@ void CodeGenStackVM::VisitExpr_(const Max* op) {
   this->PushOp(StackVM::SELECT);
 }
 
-void CodeGenStackVM::VisitExpr_(const EQ* op) {
+void CodeGenStackVM::VisitExpr_(const EQNode* op) {
   PushBinary(StackVM::EQ_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const LE* op) {
+void CodeGenStackVM::VisitExpr_(const LENode* op) {
   PushBinary(StackVM::LE_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const NE* op) {
+void CodeGenStackVM::VisitExpr_(const NENode* op) {
   PushBinary(StackVM::EQ_I64, op->a, op->b);
   this->PushOp(StackVM::NOT);
 }
 
-void CodeGenStackVM::VisitExpr_(const LT* op) {
+void CodeGenStackVM::VisitExpr_(const LTNode* op) {
   PushBinary(StackVM::LT_I64, op->a, op->b);
 }
 
-void CodeGenStackVM::VisitExpr_(const GE* op) {
+void CodeGenStackVM::VisitExpr_(const GENode* op) {
   PushBinary(StackVM::LT_I64, op->a, op->b);
   this->PushOp(StackVM::NOT);
 }
 
-void CodeGenStackVM::VisitExpr_(const GT* op) {
+void CodeGenStackVM::VisitExpr_(const GTNode* op) {
   PushBinary(StackVM::LE_I64, op->a, op->b);
   this->PushOp(StackVM::NOT);
 }
 
-void CodeGenStackVM::VisitExpr_(const And* op) {
+void CodeGenStackVM::VisitExpr_(const AndNode* op) {
   this->Push(op->a);
   int64_t pc_jump = this->GetPC();
   int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
@@ -375,7 +375,7 @@ void CodeGenStackVM::VisitExpr_(const And* op) {
   this->SetOperand(opr_index, diff);
 }
 
-void CodeGenStackVM::VisitExpr_(const Or* op) {
+void CodeGenStackVM::VisitExpr_(const OrNode* op) {
   this->Push(op->a);
   int64_t pc_jump = this->GetPC();
   int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
@@ -384,16 +384,16 @@ void CodeGenStackVM::VisitExpr_(const Or* op) {
   this->SetOperand(opr_index, diff);
 }
 
-void CodeGenStackVM::VisitExpr_(const Not* op) {
+void CodeGenStackVM::VisitExpr_(const NotNode* op) {
   this->Push(op->a);
   this->PushOp(StackVM::NOT);
 }
 
-void CodeGenStackVM::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenStackVM::VisitStmt_(const ProducerConsumerNode* op) {
   this->Push(op->body);
 }
 
-void CodeGenStackVM::VisitStmt_(const For* op) {
+void CodeGenStackVM::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
   int vid = this->AllocVarID(op->loop_var.get());
   this->PushOp(StackVM::PUSH_I64, 0);
@@ -423,21 +423,21 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
   }
 }
 
-void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
+void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) {
   if (is_const(ev->value)) return;
-  const Call* op = ev->value.as<Call>();
+  const CallNode* op = ev->value.as<CallNode>();
   if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
     CHECK_EQ(op->args.size(), 4U);
     this->Push(op->args[0]);
     this->Push(op->args[3]);
-    const IntImm* index = op->args[1].as<IntImm>();
+    const IntImmNode* index = op->args[1].as<IntImmNode>();
     CHECK(index != nullptr);
     StackVM::Code code;
     code.op_code = StackVM::TVM_STRUCT_SET;
     vm_.code.push_back(code);
     code.v_int = index->value;
     vm_.code.push_back(code);
-    code.v_int = op->args[2].as<IntImm>()->value;
+    code.v_int = op->args[2].as<IntImmNode>()->value;
     vm_.code.push_back(code);
   } else {
     this->Push(ev->value);
@@ -445,7 +445,7 @@ void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
   }
 }
 
-void CodeGenStackVM::VisitStmt_(const IfThenElse* op) {
+void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
   this->Push(op->condition);
   int64_t label_ejump = this->GetPC();
   int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
@@ -467,30 +467,30 @@ void CodeGenStackVM::VisitStmt_(const IfThenElse* op) {
   }
 }
 
-void CodeGenStackVM::VisitStmt_(const LetStmt* op) {
+void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) {
   this->Push(op->value);
   int64_t vid = this->AllocVarID(op->var.get());
   this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
   this->Push(op->body);
 }
 
-void CodeGenStackVM::VisitExpr_(const Ramp* op) {
+void CodeGenStackVM::VisitExpr_(const RampNode* op) {
   LOG(FATAL) << "Ramp is not supported";
 }
 
-void CodeGenStackVM::VisitExpr_(const Broadcast* op) {
+void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) {
   LOG(FATAL) << "Broadcast is not supported";
 }
 
-void CodeGenStackVM::VisitExpr_(const Select* op) {
+void CodeGenStackVM::VisitExpr_(const SelectNode* op) {
   this->Push(op->true_value);
   this->Push(op->false_value);
   this->Push(op->condition);
   this->PushOp(StackVM::SELECT);
 }
 
-void CodeGenStackVM::VisitStmt_(const AssertStmt* op) {
-  if (const auto* str = op->message.as<StringImm>()) {
+void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) {
+  if (const auto* str = op->message.as<StringImmNode>()) {
     int sid = this->GetStrID(str->value);
     this->Push(op->condition);
     this->PushOp(StackVM::ASSERT, sid);
@@ -498,11 +498,11 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt* op) {
   this->Push(op->body);
 }
 
-void CodeGenStackVM::VisitStmt_(const AttrStmt* op) {
+void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) {
   this->Push(op->body);
 }
 
-void CodeGenStackVM::VisitExpr_(const Let* op) {
+void CodeGenStackVM::VisitExpr_(const LetNode* op) {
   this->Push(op->value);
   int64_t vid = this->AllocVarID(op->var.get());
   this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
index 7a4c0ab..1f00ecc 100644 (file)
@@ -96,13 +96,13 @@ class CodeGenStackVM
    * \param v The variable.
    * \return the heap index of the var.
    */
-  int AllocVarID(const Variable* v);
+  int AllocVarID(const VarNode* v);
   /*!
    * \brief Get a variable name.
    * \param v The variable.
    * \return the heap index of the var.
    */
-  int GetVarID(const Variable* v) const;
+  int GetVarID(const VarNode* v) const;
   // Push binary operator
   void PushBinary(StackVM::OpCode op_int64,
                   const Expr& a,
@@ -111,52 +111,52 @@ class CodeGenStackVM
   void PushCast(DataType dst, DataType src);
   // overloadable functions
   // expression
-  void VisitExpr_(const Variable* op) final;
-  void VisitExpr_(const Load* op) final;
-  void VisitExpr_(const Let* op) final;
-  void VisitExpr_(const Call* op) final;
-  void VisitExpr_(const Add* op) final;
-  void VisitExpr_(const Sub* op) final;
-  void VisitExpr_(const Mul* op) final;
-  void VisitExpr_(const Div* op) final;
-  void VisitExpr_(const Mod* op) final;
-  void VisitExpr_(const Min* op) final;
-  void VisitExpr_(const Max* op) final;
-  void VisitExpr_(const EQ* op) final;
-  void VisitExpr_(const NE* op) final;
-  void VisitExpr_(const LT* op) final;
-  void VisitExpr_(const LE* op) final;
-  void VisitExpr_(const GT* op) final;
-  void VisitExpr_(const GE* op) final;
-  void VisitExpr_(const And* op) final;
-  void VisitExpr_(const Or* op) final;
-  void VisitExpr_(const Cast* op) final;
-  void VisitExpr_(const Not* op) final;
-  void VisitExpr_(const Select* op) final;
-  void VisitExpr_(const Ramp* op) final;
-  void VisitExpr_(const Broadcast* op) final;
-  void VisitExpr_(const IntImm* op) final;
-  void VisitExpr_(const UIntImm* op) final;
-  void VisitExpr_(const FloatImm* op) final;
-  void VisitExpr_(const StringImm* op) final;
+  void VisitExpr_(const VarNode* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitExpr_(const LetNode* op) final;
+  void VisitExpr_(const CallNode* op) final;
+  void VisitExpr_(const AddNode* op) final;
+  void VisitExpr_(const SubNode* op) final;
+  void VisitExpr_(const MulNode* op) final;
+  void VisitExpr_(const DivNode* op) final;
+  void VisitExpr_(const ModNode* op) final;
+  void VisitExpr_(const MinNode* op) final;
+  void VisitExpr_(const MaxNode* op) final;
+  void VisitExpr_(const EQNode* op) final;
+  void VisitExpr_(const NENode* op) final;
+  void VisitExpr_(const LTNode* op) final;
+  void VisitExpr_(const LENode* op) final;
+  void VisitExpr_(const GTNode* op) final;
+  void VisitExpr_(const GENode* op) final;
+  void VisitExpr_(const AndNode* op) final;
+  void VisitExpr_(const OrNode* op) final;
+  void VisitExpr_(const CastNode* op) final;
+  void VisitExpr_(const NotNode* op) final;
+  void VisitExpr_(const SelectNode* op) final;
+  void VisitExpr_(const RampNode* op) final;
+  void VisitExpr_(const BroadcastNode* op) final;
+  void VisitExpr_(const IntImmNode* op) final;
+  void VisitExpr_(const UIntImmNode* op) final;
+  void VisitExpr_(const FloatImmNode* op) final;
+  void VisitExpr_(const StringImmNode* op) final;
   // statment
-  void VisitStmt_(const LetStmt* op) final;
-  void VisitStmt_(const Store* op) final;
-  void VisitStmt_(const For* op) final;
-  void VisitStmt_(const IfThenElse* op) final;
-  void VisitStmt_(const Allocate* op) final;
-  void VisitStmt_(const AttrStmt* op) final;
-  void VisitStmt_(const AssertStmt* op) final;
-  void VisitStmt_(const Evaluate* op) final;
+  void VisitStmt_(const LetStmtNode* op) final;
+  void VisitStmt_(const StoreNode* op) final;
+  void VisitStmt_(const ForNode* op) final;
+  void VisitStmt_(const IfThenElseNode* op) final;
+  void VisitStmt_(const AllocateNode* op) final;
+  void VisitStmt_(const AttrStmtNode* op) final;
+  void VisitStmt_(const AssertStmtNode* op) final;
+  void VisitStmt_(const EvaluateNode* op) final;
   void VisitStmt_(const SeqStmtNode* op) final;
-  void VisitStmt_(const ProducerConsumer* op) final;
+  void VisitStmt_(const ProducerConsumerNode* op) final;
 
  private:
   bool debug_{false};
   /*! \brief The vm to be generated */
   StackVM vm_;
   /*! \brief id of each variable */
-  std::unordered_map<const Variable*, int> var_idmap_;
+  std::unordered_map<const VarNode*, int> var_idmap_;
   /*! \brief id of each string */
   std::unordered_map<std::string, int> str_idmap_;
   /*! \brief id of each global function */
index 00b2c23..7e3d44f 100644 (file)
@@ -76,18 +76,18 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) {
   os << t.bits();
 }
 
-void CodeGenHybrid::VisitExpr_(const IntImm* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) {  // NOLINT(*)
   os << op->value;
 }
-void CodeGenHybrid::VisitExpr_(const UIntImm* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const UIntImmNode* op, std::ostream& os) {  // NOLINT(*)
   PrintType(op->dtype, os);
   os << "(" << op->value << ")";
 }
-void CodeGenHybrid::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
   PrintType(op->dtype, os);
   os << "(" << std::setprecision(20) << op->value << ")";
 }
-void CodeGenHybrid::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*)
   os << "'" << op->value << "'";
 }
 
@@ -114,7 +114,7 @@ inline void PrintBinaryExpr(const T* op,
   }
 }
 
-inline void PrintBinaryIntrinsitc(const Call* op,
+inline void PrintBinaryIntrinsitc(const CallNode* op,
                                   const char* opstr,
                                   std::ostream& os,  // NOLINT(*)
                                   CodeGenHybrid* p) {
@@ -127,7 +127,7 @@ inline void PrintBinaryIntrinsitc(const Call* op,
   os << ')';
 }
 
-void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
   if (op->dtype == op->value.dtype()) {
     PrintExpr(op->value, stream);
   } else {
@@ -138,77 +138,77 @@ void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) {  // NOLINT(*)
   }
 }
 
-void CodeGenHybrid::VisitExpr_(const Variable* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) {  // NOLINT(*)
   os << GetVarID(op);
 }
-void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "+", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const Sub* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "-", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const Mul* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "*", os, this);
 }
 
-void CodeGenHybrid::VisitExpr_(const Div* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
   if (op->dtype.is_int())
     PrintBinaryExpr(op, "//", os, this);
   else
     PrintBinaryExpr(op, "/", os, this);
 }
 
-void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) {  // NOLINT(*)
   if (op->dtype.is_int())
     PrintBinaryExpr(op, "//", os, this);
   else
     PrintBinaryExpr(op, "/", os, this);
 }
 
-void CodeGenHybrid::VisitExpr_(const Mod* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "%", os, this);
 }
 
-void CodeGenHybrid::VisitExpr_(const FloorMod* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "%", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const Min* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "min", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const Max* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "max", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const EQ* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const EQNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "==", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const NE* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const NENode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "!=", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const LT* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LTNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "<", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const LE* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LENode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "<=", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const GT* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const GTNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, ">", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const GE* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const GENode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, ">=", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const And* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const AndNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "&&", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const Or* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const OrNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "||", os, this);
 }
-void CodeGenHybrid::VisitExpr_(const Not* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) {  // NOLINT(*)
   os << "not ";
   PrintExpr(op->a, os);
 }
 
-void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
-  if (op->call_type == Call::Halide) {
+void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
+  if (op->call_type == CallNode::Halide) {
     os << GetTensorID(op->func, op->value_index);
     os << "[";
     for (size_t i = 0; i < op->args.size(); ++i) {
@@ -218,17 +218,17 @@ void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
       os << idx.str();
     }
     os << "]";
-  } else if (op->is_intrinsic(Call::bitwise_and)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_and)) {
     PrintBinaryIntrinsitc(op, "&", os, this);
-  } else if (op->is_intrinsic(Call::bitwise_xor)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_xor)) {
     PrintBinaryIntrinsitc(op, "^", os, this);
-  } else if (op->is_intrinsic(Call::bitwise_or)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_or)) {
     PrintBinaryIntrinsitc(op, "|", os, this);
-  } else if (op->is_intrinsic(Call::shift_left)) {
+  } else if (op->is_intrinsic(CallNode::shift_left)) {
     PrintBinaryIntrinsitc(op, "<<", os, this);
-  } else if (op->is_intrinsic(Call::shift_right)) {
+  } else if (op->is_intrinsic(CallNode::shift_right)) {
     PrintBinaryIntrinsitc(op, ">>", os, this);
-  } else if (op->is_intrinsic(Call::bitwise_not)) {
+  } else if (op->is_intrinsic(CallNode::bitwise_not)) {
     CHECK_EQ(op->args.size(), 1U);
     os << "(~";
     PrintExpr(op->args[0], os);
@@ -251,31 +251,31 @@ void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) {  // NOLINT(*)
   }
 }
 
-void CodeGenHybrid::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Phase 0 has no Load(s)!";
 }
 
-void CodeGenHybrid::VisitStmt_(const Store* op) {
+void CodeGenHybrid::VisitStmt_(const StoreNode* op) {
   LOG(FATAL) << "Phase 0 has no Store(s)!";
 }
 
-void CodeGenHybrid::VisitExpr_(const Let* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Phase 0 has no Let(s)!";
 }
 
-void CodeGenHybrid::VisitStmt_(const Allocate* op) {
+void CodeGenHybrid::VisitStmt_(const AllocateNode* op) {
   LOG(FATAL) << "Phase 0 has no Allocate(s)!";
 }
 
-void CodeGenHybrid::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Ramp to be supported yet";
 }
 
-void CodeGenHybrid::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
   LOG(FATAL) << "Broadcast: not supported ";
 }
 
-void CodeGenHybrid::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(*)
+void CodeGenHybrid::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
   PrintExpr(op->true_value, os);
   os << " if ";
   PrintExpr(op->condition, os);
@@ -284,13 +284,13 @@ void CodeGenHybrid::VisitExpr_(const Select* op, std::ostream& os) {  // NOLINT(
   os << "\n";
 }
 
-void CodeGenHybrid::VisitStmt_(const LetStmt* op) {
+void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) {
   std::string value = PrintExpr(op->value);
   stream << GetVarID(op->var.get()) << " = " << value << ";\n";
   PrintStmt(op->body);
 }
 
-void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
+void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == ir::attr::thread_extent) {
     auto iter_var = op->node.as<IterVarNode>();
     CHECK(iter_var);
@@ -305,7 +305,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
     indent_ -= tab_;
   } else if (op->attr_key == ir::attr::realize_scope) {
     auto v = Downcast<FunctionRef>(op->node);
-    alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+    alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
     PrintStmt(op->body);
   } else {
     // For now we ignore the unsupported AttrStmt
@@ -313,7 +313,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) {
   }
 }
 
-void CodeGenHybrid::VisitStmt_(const Realize* op) {
+void CodeGenHybrid::VisitStmt_(const RealizeNode* op) {
   CHECK(alloc_storage_scope_.count(op->func));
   if (!alloc_storage_scope_[op->func].empty()) {
     PrintIndent();
@@ -331,7 +331,7 @@ void CodeGenHybrid::VisitStmt_(const Realize* op) {
   PrintStmt(op->body);
 }
 
-void CodeGenHybrid::VisitStmt_(const AssertStmt* op) {
+void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) {
   PrintIndent();
   stream << "assert ";
   PrintExpr(op->condition, stream);
@@ -341,7 +341,7 @@ void CodeGenHybrid::VisitStmt_(const AssertStmt* op) {
   PrintStmt(op->body);
 }
 
-void CodeGenHybrid::VisitStmt_(const Provide* op) {
+void CodeGenHybrid::VisitStmt_(const ProvideNode* op) {
   PrintIndent();
   stream << GetTensorID(op->func, op->value_index);
   stream << "[";
@@ -354,7 +354,7 @@ void CodeGenHybrid::VisitStmt_(const Provide* op) {
   stream << "\n";
 }
 
-void CodeGenHybrid::VisitStmt_(const For* op) {
+void CodeGenHybrid::VisitStmt_(const ForNode* op) {
   std::string extent = PrintExpr(op->extent);
   PrintIndent();
   std::string vid = GetVarID(op->loop_var.get());
@@ -367,12 +367,12 @@ void CodeGenHybrid::VisitStmt_(const For* op) {
 bool is_noop(const Stmt &stmt) {
   if (!stmt.defined())
     return true;
-  if (auto eval = stmt.as<Evaluate>())
+  if (auto eval = stmt.as<EvaluateNode>())
     return is_const(eval->value);
   return false;
 }
 
-void CodeGenHybrid::VisitStmt_(const IfThenElse* op) {
+void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) {
   std::string cond = PrintExpr(op->condition);
   PrintIndent();
   stream << "if " << cond << ":\n";
@@ -395,14 +395,14 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
   }
 }
 
-void CodeGenHybrid::VisitStmt_(const Evaluate* op) {
+void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
   if (is_const(op->value)) return;
   std::string str = PrintExpr(op->value);
   if (!str.empty())
     stream << str << "\n";
 }
 
-void CodeGenHybrid::VisitStmt_(const ProducerConsumer* op) {
+void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) {
   PrintStmt(op->body);
 }
 
@@ -410,7 +410,7 @@ void CodeGenHybrid::PrintIndent() {
   stream << std::string(indent_, ' ');
 }
 
-std::string CodeGenHybrid::GetVarID(const Variable *v) {
+std::string CodeGenHybrid::GetVarID(const VarNode *v) {
   if (binds_.count(v))
     return binds_[v];
   auto key = std::make_pair(static_cast<const Object*>(v), 0);
@@ -489,7 +489,7 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt,
     if (auto tensor = inputs[i].as<TensorNode>()) {
       stream << GetTensorID(tensor->op, tensor->value_index);
     } else {
-      auto var = inputs[i].as<Variable>();
+      auto var = inputs[i].as<VarNode>();
       CHECK(var) << "Input should either be a tensor or a variable!";
       stream << GetVarID(var);
     }
index 27c97c7..01696b2 100644 (file)
@@ -90,49 +90,49 @@ class CodeGenHybrid :
     return os.str();
   }
   // expression
-  void VisitExpr_(const Variable* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Load* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Let* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Call* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Add* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Sub* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const FloorDiv* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const FloorMod* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const NE* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const LT* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const LE* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const GT* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const GE* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const And* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Or* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Cast* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const FloatImm* op, std::ostream& os) override;  // NOLINT(*)
-  void VisitExpr_(const StringImm* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const VarNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LoadNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LetNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const CallNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const AddNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const SubNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const MulNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const DivNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const ModNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloorDivNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloorModNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const MinNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const MaxNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const EQNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const NENode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LTNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const LENode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const GTNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const GENode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const AndNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const OrNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const CastNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const NotNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const SelectNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const RampNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const IntImmNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const UIntImmNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const FloatImmNode* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const StringImmNode* op, std::ostream& os) override;  // NOLINT(*)
   // statment
-  void VisitStmt_(const LetStmt* op) override;
-  void VisitStmt_(const Store* op) override;
-  void VisitStmt_(const Provide* op) override;
-  void VisitStmt_(const For* op) override;
-  void VisitStmt_(const IfThenElse* op) override;
-  void VisitStmt_(const Allocate* op) override;
-  void VisitStmt_(const Realize* op) override;
-  void VisitStmt_(const AttrStmt* op) override;
-  void VisitStmt_(const AssertStmt* op) override;
-  void VisitStmt_(const Evaluate* op) override;
+  void VisitStmt_(const LetStmtNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ProvideNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitStmt_(const RealizeNode* op) override;
+  void VisitStmt_(const AttrStmtNode* op) override;
+  void VisitStmt_(const AssertStmtNode* op) override;
+  void VisitStmt_(const EvaluateNode* op) override;
   void VisitStmt_(const SeqStmtNode* op) override;
-  void VisitStmt_(const ProducerConsumer* op) override;
+  void VisitStmt_(const ProducerConsumerNode* op) override;
   /*!
    * \brief Print Type represetnation of type t.
    * \param t The type representation.
@@ -154,7 +154,7 @@ class CodeGenHybrid :
    *        Values are the corresponding IDs.*/
   std::map<std::pair<const Object *, int>, std::string> id_map_;
   /*! \brief Variables (keys) binded to the threads (values). */
-  std::map<const Variable *, std::string> binds_;
+  std::map<const VarNode *, std::string> binds_;
   /*!
    * \brief Find an unallocated name for the given prefix.
    * \param prefix The given prefix.
@@ -166,7 +166,7 @@ class CodeGenHybrid :
    * \brief Get or allocate the ID for the given variable.
    * \param v The given variable.
    */
-  std::string GetVarID(const Variable *v);
+  std::string GetVarID(const VarNode *v);
   /*!
    * \brief Get or allocate the ID for the given tensor.
    * \param func The tensor to allocate a name.
index 51b355e..34ee4b3 100644 (file)
@@ -76,33 +76,33 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
   virtual R VisitAttrDefault_(const Object* node, Args... args) = 0;
   virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::UIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
   // deep comparison of symbolic integer expressions.
-  virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::LE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::EQ* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::NE* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT;
-  virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
+  virtual R VisitAttr_(const ir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
 
  private:
   // initialize the vtable.
@@ -112,32 +112,32 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
     // Set dispatch
     ATTR_FUNCTOR_DISPATCH(StrMapNode);
     ATTR_FUNCTOR_DISPATCH(ArrayNode);
-    ATTR_FUNCTOR_DISPATCH(IntImm);
-    ATTR_FUNCTOR_DISPATCH(UIntImm);
-    ATTR_FUNCTOR_DISPATCH(FloatImm);
-    ATTR_FUNCTOR_DISPATCH(StringImm);
-    ATTR_FUNCTOR_DISPATCH(Variable);
-    ATTR_FUNCTOR_DISPATCH(Add);
-    ATTR_FUNCTOR_DISPATCH(Sub);
-    ATTR_FUNCTOR_DISPATCH(Mul);
-    ATTR_FUNCTOR_DISPATCH(Div);
-    ATTR_FUNCTOR_DISPATCH(Mod);
-    ATTR_FUNCTOR_DISPATCH(FloorDiv);
-    ATTR_FUNCTOR_DISPATCH(FloorMod);
-    ATTR_FUNCTOR_DISPATCH(Min);
-    ATTR_FUNCTOR_DISPATCH(Max);
-    ATTR_FUNCTOR_DISPATCH(GE);
-    ATTR_FUNCTOR_DISPATCH(GT);
-    ATTR_FUNCTOR_DISPATCH(LE);
-    ATTR_FUNCTOR_DISPATCH(LT);
-    ATTR_FUNCTOR_DISPATCH(EQ);
-    ATTR_FUNCTOR_DISPATCH(NE);
-    ATTR_FUNCTOR_DISPATCH(And);
-    ATTR_FUNCTOR_DISPATCH(Or);
-    ATTR_FUNCTOR_DISPATCH(Not);
-    ATTR_FUNCTOR_DISPATCH(Cast);
-    ATTR_FUNCTOR_DISPATCH(Call);
-    ATTR_FUNCTOR_DISPATCH(Select);
+    ATTR_FUNCTOR_DISPATCH(IntImmNode);
+    ATTR_FUNCTOR_DISPATCH(UIntImmNode);
+    ATTR_FUNCTOR_DISPATCH(FloatImmNode);
+    ATTR_FUNCTOR_DISPATCH(StringImmNode);
+    ATTR_FUNCTOR_DISPATCH(VarNode);
+    ATTR_FUNCTOR_DISPATCH(AddNode);
+    ATTR_FUNCTOR_DISPATCH(SubNode);
+    ATTR_FUNCTOR_DISPATCH(MulNode);
+    ATTR_FUNCTOR_DISPATCH(DivNode);
+    ATTR_FUNCTOR_DISPATCH(ModNode);
+    ATTR_FUNCTOR_DISPATCH(FloorDivNode);
+    ATTR_FUNCTOR_DISPATCH(FloorModNode);
+    ATTR_FUNCTOR_DISPATCH(MinNode);
+    ATTR_FUNCTOR_DISPATCH(MaxNode);
+    ATTR_FUNCTOR_DISPATCH(GENode);
+    ATTR_FUNCTOR_DISPATCH(GTNode);
+    ATTR_FUNCTOR_DISPATCH(LENode);
+    ATTR_FUNCTOR_DISPATCH(LTNode);
+    ATTR_FUNCTOR_DISPATCH(EQNode);
+    ATTR_FUNCTOR_DISPATCH(NENode);
+    ATTR_FUNCTOR_DISPATCH(AndNode);
+    ATTR_FUNCTOR_DISPATCH(OrNode);
+    ATTR_FUNCTOR_DISPATCH(NotNode);
+    ATTR_FUNCTOR_DISPATCH(CastNode);
+    ATTR_FUNCTOR_DISPATCH(CallNode);
+    ATTR_FUNCTOR_DISPATCH(SelectNode);
     return vtable;
   }
 };
@@ -156,31 +156,31 @@ class AttrsEqualHandler :
   bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
   bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::IntImm* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Mul* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::LE* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::EQ* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::NE* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::IntImmNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::UIntImmNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::FloatImmNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::StringImmNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::SubNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::MulNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::DivNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::ModNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::FloorDivNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::FloorModNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::MinNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::MaxNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::GENode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::GTNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::LTNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::LENode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::EQNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::NENode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::AndNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::OrNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::NotNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::CastNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::CallNode* lhs, const ObjectRef& other) final;
+  bool VisitAttr_(const ir::SelectNode* lhs, const ObjectRef& other) final;
 };
 
 class AttrsHashHandler :
@@ -197,33 +197,33 @@ class AttrsHashHandler :
 
  protected:
   size_t VisitAttrDefault_(const Object* lhs) final;
-  size_t VisitAttr_(const ir::IntImm* lhs) final;
-  size_t VisitAttr_(const ir::UIntImm* lhs) final;
-  size_t VisitAttr_(const ir::FloatImm* lhs) final;
-  size_t VisitAttr_(const ir::StringImm* lhs) final;
+  size_t VisitAttr_(const ir::IntImmNode* lhs) final;
+  size_t VisitAttr_(const ir::UIntImmNode* lhs) final;
+  size_t VisitAttr_(const ir::FloatImmNode* lhs) final;
+  size_t VisitAttr_(const ir::StringImmNode* lhs) final;
   size_t VisitAttr_(const ArrayNode* lhs) final;
   size_t VisitAttr_(const StrMapNode* lhs) final;
-  size_t VisitAttr_(const ir::Add* op) final;
-  size_t VisitAttr_(const ir::Sub* op) final;
-  size_t VisitAttr_(const ir::Mul* op) final;
-  size_t VisitAttr_(const ir::Div* op) final;
-  size_t VisitAttr_(const ir::Mod* op) final;
-  size_t VisitAttr_(const ir::FloorDiv* op) final;
-  size_t VisitAttr_(const ir::FloorMod* op) final;
-  size_t VisitAttr_(const ir::Min* op) final;
-  size_t VisitAttr_(const ir::Max* op) final;
-  size_t VisitAttr_(const ir::GE* op) final;
-  size_t VisitAttr_(const ir::GT* op) final;
-  size_t VisitAttr_(const ir::LE* op) final;
-  size_t VisitAttr_(const ir::LT* op) final;
-  size_t VisitAttr_(const ir::EQ* op) final;
-  size_t VisitAttr_(const ir::NE* op) final;
-  size_t VisitAttr_(const ir::And* op) final;
-  size_t VisitAttr_(const ir::Or* op) final;
-  size_t VisitAttr_(const ir::Not* op) final;
-  size_t VisitAttr_(const ir::Cast* op) final;
-  size_t VisitAttr_(const ir::Call* op) final;
-  size_t VisitAttr_(const ir::Select* op) final;
+  size_t VisitAttr_(const ir::AddNode* op) final;
+  size_t VisitAttr_(const ir::SubNode* op) final;
+  size_t VisitAttr_(const ir::MulNode* op) final;
+  size_t VisitAttr_(const ir::DivNode* op) final;
+  size_t VisitAttr_(const ir::ModNode* op) final;
+  size_t VisitAttr_(const ir::FloorDivNode* op) final;
+  size_t VisitAttr_(const ir::FloorModNode* op) final;
+  size_t VisitAttr_(const ir::MinNode* op) final;
+  size_t VisitAttr_(const ir::MaxNode* op) final;
+  size_t VisitAttr_(const ir::GENode* op) final;
+  size_t VisitAttr_(const ir::GTNode* op) final;
+  size_t VisitAttr_(const ir::LENode* op) final;
+  size_t VisitAttr_(const ir::LTNode* op) final;
+  size_t VisitAttr_(const ir::EQNode* op) final;
+  size_t VisitAttr_(const ir::NENode* op) final;
+  size_t VisitAttr_(const ir::AndNode* op) final;
+  size_t VisitAttr_(const ir::OrNode* op) final;
+  size_t VisitAttr_(const ir::NotNode* op) final;
+  size_t VisitAttr_(const ir::CastNode* op) final;
+  size_t VisitAttr_(const ir::CallNode* op) final;
+  size_t VisitAttr_(const ir::SelectNode* op) final;
   /*!
    * \brief alias of dmlc::HashCombine
    * \param lhs The first hash value.
index d69e3e2..6264e0f 100644 (file)
@@ -90,29 +90,29 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& ot
   return lhs == other.get();
 }
 
-bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<IntImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<IntImmNode>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<UIntImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const UIntImmNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<UIntImmNode>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<FloatImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<FloatImmNode>()) {
     return lhs->value == rhs->value;
   }
   return false;
 }
 
-bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<StringImm>()) {
+bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<StringImmNode>()) {
     return lhs->value == rhs->value;
   }
   return false;
@@ -151,34 +151,34 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other
     }                                                                   \
   }                                                                     \
 
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);
-
-bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<Not>()) {
+TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
+TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
+
+bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<NotNode>()) {
     return Equal(lhs->a, rhs->a);
   } else {
     return false;
   }
 }
 
-bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<Cast>()) {
+bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<CastNode>()) {
     if (lhs->dtype != rhs->dtype) return false;
     return Equal(lhs->value, rhs->value);
   } else {
@@ -186,8 +186,8 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) {
   }
 }
 
-bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<Call>()) {
+bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<CallNode>()) {
     return
         lhs->name == rhs->name &&
         lhs->dtype == rhs->dtype &&
@@ -198,8 +198,8 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) {
   }
 }
 
-bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<Select>()) {
+bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
+  if (const auto* rhs = other.as<SelectNode>()) {
     return
         Equal(lhs->condition, rhs->condition) &&
         Equal(lhs->true_value, rhs->true_value) &&
@@ -220,19 +220,19 @@ size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
   }
 }
 
-size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
   return std::hash<int64_t>()(op->value);
 }
 
-size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const UIntImmNode* op) {
   return std::hash<uint64_t>()(op->value);
 }
 
-size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
   return std::hash<double>()(op->value);
 }
 
-size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
+size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
   return std::hash<std::string>()(op->value);
 }
 
@@ -265,31 +265,31 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
     return Combine(key, Combine(Hash(op->a), Hash(op->b)));             \
   }                                                                     \
 
-TVM_DEFINE_ATTRS_BINOP_HASH(Add);
-TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
-TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
-TVM_DEFINE_ATTRS_BINOP_HASH(Div);
-TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod);
-TVM_DEFINE_ATTRS_BINOP_HASH(Max);
-TVM_DEFINE_ATTRS_BINOP_HASH(Min);
-TVM_DEFINE_ATTRS_BINOP_HASH(GE);
-TVM_DEFINE_ATTRS_BINOP_HASH(GT);
-TVM_DEFINE_ATTRS_BINOP_HASH(LE);
-TVM_DEFINE_ATTRS_BINOP_HASH(LT);
-TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
-TVM_DEFINE_ATTRS_BINOP_HASH(NE);
-TVM_DEFINE_ATTRS_BINOP_HASH(And);
-TVM_DEFINE_ATTRS_BINOP_HASH(Or);
-
-size_t AttrsHashHandler::VisitAttr_(const Not* op) {
-  static size_t key = std::hash<std::string>()(Not::_type_key);
+TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
+TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
+TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
+TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
+TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
+
+size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
+  static size_t key = std::hash<std::string>()(NotNode::_type_key);
   return Combine(key, Hash(op->a));
 }
 
-size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
-  static size_t key = std::hash<std::string>()(Cast::_type_key);
+size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
+  static size_t key = std::hash<std::string>()(CastNode::_type_key);
   AttrsHash hasher;
   size_t res = key;
   res = Combine(res, hasher(op->dtype));
@@ -297,8 +297,8 @@ size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
   return res;
 }
 
-size_t AttrsHashHandler::VisitAttr_(const Call* op) {
-  static size_t key = std::hash<std::string>()(Call::_type_key);
+size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
+  static size_t key = std::hash<std::string>()(CallNode::_type_key);
   AttrsHash hasher;
   size_t res = key;
   res = Combine(res, hasher(op->name));
@@ -307,8 +307,8 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) {
   return res;
 }
 
-size_t AttrsHashHandler::VisitAttr_(const Select* op) {
-  static size_t key = std::hash<std::string>()(Select::_type_key);
+size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
+  static size_t key = std::hash<std::string>()(SelectNode::_type_key);
   size_t res = key;
   res = Combine(res, Hash(op->condition));
   res = Combine(res, Hash(op->true_value));
index 22efa1d..d96033d 100644 (file)
@@ -31,8 +31,8 @@
 namespace tvm {
 
 // TODO(tqchen): change to floormod/div
-using IndexMod = ir::FloorMod;
-using IndexDiv = ir::FloorDiv;
+using IndexMod = ir::FloorModNode;
+using IndexDiv = ir::FloorDivNode;
 
 Array<Expr> SimplifyArray(Array<Expr> array) {
   for (size_t i = 0; i < array.size(); ++i) {
@@ -65,7 +65,7 @@ inline std::vector<const Expr*> ExprSplitAddition(const Expr &expr) {
   while (!split_buffer.empty()) {
     const Expr* top_ele = split_buffer.top();
     split_buffer.pop();
-    auto expr_add_match = top_ele->as<Add>();
+    auto expr_add_match = top_ele->as<AddNode>();
     if (expr_add_match) {
       split_buffer.push(&expr_add_match->b);
       split_buffer.push(&expr_add_match->a);
@@ -88,13 +88,13 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
                                               const Expr &mod_l_expr,
                                               const Expr &mod_r_expr) {
   using namespace ir;
-  const Mul* mult_ptr = mult_expr.as<Mul>();
+  const MulNode* mult_ptr = mult_expr.as<MulNode>();
   if (!mult_ptr) return std::make_pair(false, Expr());
   Expr mult_outer = mult_ptr->b;
   const Expr* inner = &(mult_ptr->a);
   // 1. Calculate the outer multiplier
   while (true) {
-    mult_ptr = inner->as<Mul>();
+    mult_ptr = inner->as<MulNode>();
     if (mult_ptr) {
       inner = &(mult_ptr->a);
       mult_outer = mult_ptr->b * mult_outer;
@@ -113,8 +113,8 @@ inline std::pair<bool, Expr> MergeMulModInner(const Expr &mult_expr,
   Expr no_opt_sum;  // Sum of the exprs that cannot be optimized
   while (true) {
     auto inner_div_ptr = search_ptr->as<IndexDiv>();
-    auto inner_mult_ptr = search_ptr->as<Mul>();
-    auto inner_add_ptr = search_ptr->as<Add>();
+    auto inner_mult_ptr = search_ptr->as<MulNode>();
+    auto inner_add_ptr = search_ptr->as<AddNode>();
     if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) {
       return std::make_pair(false, Expr());
     } else if (inner_div_ptr) {
@@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector<const Expr*>& eles,
   *has_mod = false;
   for (const Expr* ele : eles) {
     auto mod_ptr = ele->as<IndexMod>();
-    auto mult_ptr = ele->as<Mul>();
+    auto mult_ptr = ele->as<MulNode>();
     if (mod_ptr) {
       *has_mod = true;
       mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b)));
@@ -252,7 +252,7 @@ inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
   if (n->strides.size() == 0) {
     // Scalar case
     if (n->shape.size() == 0 && index.size() == 1) {
-      auto is_int = index[0].as<IntImm>();
+      auto is_int = index[0].as<IntImmNode>();
       CHECK(is_int && is_int->value == 0);
       base = base + index[0];
     } else {
@@ -285,7 +285,7 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, DataType dtype)
     offset = offset * make_const(offset.dtype(), dtype.lanes());
   }
   if (dtype.lanes() != 1) {
-    return ir::Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
+    return ir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
   } else {
     return offset;
   }
@@ -299,13 +299,13 @@ Expr Buffer::vload(Array<Expr> begin, DataType dtype) const {
       << "Cannot load " << dtype
       << " from buffer of " << n->dtype;
   if (dtype == DataType::Bool()) {
-    return ir::Cast::make(
+    return ir::CastNode::make(
         DataType::Bool(),
-        ir::Load::make(
+        ir::LoadNode::make(
             DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)),
             const_true()));
   } else {
-    return ir::Load::make(
+    return ir::LoadNode::make(
         dtype, n->data, BufferOffset(n, begin, dtype),
         const_true(dtype.lanes()));
   }
@@ -320,12 +320,12 @@ Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
       << "Cannot load " << dtype
       << " from buffer of " << n->dtype;
   if (value.dtype() == DataType::Bool()) {
-    return ir::Store::make(n->data,
-                           ir::Cast::make(DataType::Int(8), value),
+    return ir::StoreNode::make(n->data,
+                           ir::CastNode::make(DataType::Int(8), value),
                            BufferOffset(n, begin, DataType::Int(8)),
                            const_true());
   } else {
-    return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
+    return ir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
                            const_true(dtype.lanes()));
   }
 }
@@ -391,7 +391,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
     int highest_dim = 0;
     extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
   } else {
-    extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr()) - offset;
+    extent = arith::ComputeReduce<ir::MulNode>(self->shape, Expr()) - offset;
   }
   Expr elem_offset = self->elem_offset + offset;
   if (content_lanes > 1) {
@@ -405,8 +405,8 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E
   Array<Expr> acc_args{
     e_dtype, self->data, elem_offset,
         extent, make_const(DataType::Int(32), access_mask)};
-  return ir::Call::make(
-      ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::Call::Intrinsic);
+  return ir::CallNode::make(
+      ptr_type, ir::intrinsic::tvm_access_ptr, acc_args, ir::CallNode::Intrinsic);
 }
 
 Buffer BufferNode::make(Var data,
index c4a6b35..c30f344 100644 (file)
@@ -72,7 +72,7 @@ Layout::Layout(const Array<IterVar>& axes) {
   node->axes = axes;
   std::ostringstream repr;
   for (const IterVar& axis : axes) {
-    if (const auto* factor = axis->dom->extent.as<IntImm>()) {
+    if (const auto* factor = axis->dom->extent.as<IntImmNode>()) {
       CHECK_GT(factor->value, 0);
       repr << factor->value;
     }
@@ -186,7 +186,7 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
   if (!this->defined()) return -1;
   for (const IterVar& itvar : operator->()->axes) {
     if (sub == LayoutAxis::Get(itvar)) {
-      const auto* factor = itvar->dom->extent.as<IntImm>();
+      const auto* factor = itvar->dom->extent.as<IntImmNode>();
       CHECK(factor);
       return factor->value;
     }
@@ -251,7 +251,7 @@ inline Array<Expr> TransformIndex(const Array<Expr>& src_index,
                                   const Array<IterVar>& src_axis,
                                   const Array<Expr>& transform_rule) {
   Array<Expr> result;
-  std::unordered_map<const Variable*, Expr> bind_map;
+  std::unordered_map<const VarNode*, Expr> bind_map;
   for (size_t i = 0; i < src_index.size(); ++i) {
     bind_map[src_axis[i]->var.get()] = src_index[i];
   }
@@ -287,18 +287,18 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
   // for major-axis, bind the corresponding size
   // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
   // e.g., (C * 16 + c) / 32
-  std::unordered_map<const Variable*, Expr> bind_map;
+  std::unordered_map<const VarNode*, Expr> bind_map;
   std::unordered_set<size_t> symbolic_var_set;
   for (size_t i = 0; i < src_shape.size(); ++i) {
     Expr orig_shape = src_shape[i];
     IterVar orig_axis = src_axis[i];
-    if (orig_shape.as<ir::Any>()) {
+    if (orig_shape.as<ir::AnyNode>()) {
       symbolic_var_set.insert(i);
     }
     if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
       if (orig_shape.defined()) {
-        const auto* orig_shape_const = orig_shape.as<IntImm>();
-        const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImm>();
+        const auto* orig_shape_const = orig_shape.as<IntImmNode>();
+        const auto* orig_axis_extent = orig_axis->dom->extent.as<IntImmNode>();
         if (orig_shape_const) {
           CHECK_EQ(orig_shape_const->value, orig_axis_extent->value)
             << "Input shape mismatch at index " << i << ". Expected "
@@ -322,7 +322,7 @@ inline Array<Expr> TransformShape(const Array<Expr>& src_shape,
       result.push_back(axis->dom->extent);
     } else {
       if (symbolic_var_set.count(i)) {
-        result.push_back(ir::Any::make());
+        result.push_back(ir::AnyNode::make());
       } else {
         result.push_back(ir::Simplify(ir::Substitute(rule, bind_map)));
       }
index eed6938..58a97ed 100644 (file)
 namespace tvm {
 
 Expr::Expr(int32_t value)
-    : Expr(IntImm::make(DataType::Int(32), value)) {}
+    : Expr(IntImmNode::make(DataType::Int(32), value)) {}
 
 Expr::Expr(float value)
-    : Expr(ir::FloatImm::make(DataType::Float(32), value)) {}
+    : Expr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
 
 Expr::Expr(std::string str)
-    : Expr(ir::StringImm::make(str)) {}
+    : Expr(ir::StringImmNode::make(str)) {}
 
 Var::Var(std::string name_hint, DataType t)
-    : Var(Variable::make(t, name_hint)) {}
+    : Var(VarNode::make(t, name_hint)) {}
 
-Var Variable::make(DataType t, std::string name_hint) {
-  ObjectPtr<Variable> node = make_object<Variable>();
+Var VarNode::make(DataType t, std::string name_hint) {
+  ObjectPtr<VarNode> node = make_object<VarNode>();
   node->dtype = t;
   node->name_hint = std::move(name_hint);
   return Var(node);
@@ -54,10 +54,10 @@ Range::Range(Expr begin, Expr end)
           is_zero(begin) ? end : (end - begin))) {
 }
 
-Integer IntImm::make(DataType t, int64_t value) {
+Integer IntImmNode::make(DataType t, int64_t value) {
   CHECK(t.is_int() && t.is_scalar())
       << "ValueError: IntImm can only take scalar.";
-  ObjectPtr<IntImm> node = make_object<IntImm>();
+  ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
   node->dtype = t;
   node->value = value;
   return Integer(node);
@@ -98,8 +98,8 @@ Var var(std::string name_hint, DataType t) {
 }
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<IntImm>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const IntImm*>(node.get());
+.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const IntImmNode*>(node.get());
     if (op->dtype == DataType::Int(32)) {
       p->stream << op->value;
     } else {
index 1166e7e..34fac72 100644 (file)
@@ -32,7 +32,7 @@ namespace tvm {
 // simple cast that only checks if type matches and cast
 inline Expr SimpleCast(const DataType& t, Expr value) {
   if (value.dtype() == t) return value;
-  return ir::Cast::make(t, value);
+  return ir::CastNode::make(t, value);
 }
 
 // The public function with a quick checking path.
@@ -41,9 +41,9 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) {  // NOLINT(*)
   DataType ltype = lhs.dtype();
   DataType rtype = rhs.dtype();
   if (ltype.lanes() == 1 && rtype.lanes() != 1) {
-    lhs = ir::Broadcast::make(lhs, rtype.lanes());
+    lhs = ir::BroadcastNode::make(lhs, rtype.lanes());
   } else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
-    rhs = ir::Broadcast::make(rhs, ltype.lanes());
+    rhs = ir::BroadcastNode::make(rhs, ltype.lanes());
   } else {
     CHECK(ltype.lanes() == rtype.lanes())
         << "Cannot match type " << ltype << " vs " << rtype;
@@ -85,27 +85,27 @@ Expr max_value(const DataType& dtype) {
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
     if (dtype.bits() == 64) {
-      return IntImm::make(dtype, std::numeric_limits<int64_t>::max());
+      return IntImmNode::make(dtype, std::numeric_limits<int64_t>::max());
     } else if (dtype.bits() < 64) {
       int64_t val = 1;
       val = (val << (dtype.bits() - 1)) - 1;
-      return IntImm::make(dtype, val);
+      return IntImmNode::make(dtype, val);
     }
   } else if (dtype.is_uint()) {
     if (dtype.bits() == 64) {
-      return UIntImm::make(dtype, std::numeric_limits<uint64_t>::max());
+      return UIntImmNode::make(dtype, std::numeric_limits<uint64_t>::max());
     } else if (dtype.bits() < 64) {
       uint64_t val = 1;
       val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
-      return UIntImm::make(dtype, val);
+      return UIntImmNode::make(dtype, val);
     }
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
-      return FloatImm::make(dtype, std::numeric_limits<double>::max());
+      return FloatImmNode::make(dtype, std::numeric_limits<double>::max());
     } else if (dtype.bits() == 32) {
-      return FloatImm::make(dtype, std::numeric_limits<float>::max());
+      return FloatImmNode::make(dtype, std::numeric_limits<float>::max());
     } else if (dtype.bits() == 16) {
-      return FloatImm::make(dtype, 65504.0);
+      return FloatImmNode::make(dtype, 65504.0);
     }
   }
   LOG(FATAL) << "Cannot decide max_value for type" << dtype;
@@ -117,21 +117,21 @@ Expr min_value(const DataType& dtype) {
   CHECK_EQ(dtype.lanes(), 1);
   if (dtype.is_int()) {
     if (dtype.bits() == 64) {
-      return IntImm::make(dtype, std::numeric_limits<int64_t>::lowest());
+      return IntImmNode::make(dtype, std::numeric_limits<int64_t>::lowest());
     } else if (dtype.bits() < 64) {
       int64_t val = 1;
       val = -(val << (dtype.bits() - 1));
-      return IntImm::make(dtype, val);
+      return IntImmNode::make(dtype, val);
     }
   } else if (dtype.is_uint()) {
-    return UIntImm::make(dtype, 0);
+    return UIntImmNode::make(dtype, 0);
   } else if (dtype.is_float()) {
     if (dtype.bits() == 64) {
-      return FloatImm::make(dtype, std::numeric_limits<double>::lowest());
+      return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest());
     } else if (dtype.bits() == 32) {
-      return FloatImm::make(dtype, std::numeric_limits<float>::lowest());
+      return FloatImmNode::make(dtype, std::numeric_limits<float>::lowest());
     } else if (dtype.bits() == 16) {
-      return FloatImm::make(dtype, -65504.0);
+      return FloatImmNode::make(dtype, -65504.0);
     }
   }
   LOG(FATAL) << "Cannot decide min_value for type" << dtype;
@@ -153,9 +153,9 @@ inline bool ConstPowerHelper(ValueType val, int *shift) {
 }
 
 bool is_const_power_of_two_integer(const Expr& x, int* shift) {
-  if (const auto* op = x.as<ir::IntImm>()) {
+  if (const auto* op = x.as<ir::IntImmNode>()) {
     return ConstPowerHelper(op->value, shift);
-  } else if (const auto* op = x.as<ir::UIntImm>()) {
+  } else if (const auto* op = x.as<ir::UIntImmNode>()) {
     return ConstPowerHelper(op->value, shift);
   } else {
     return false;
@@ -163,85 +163,86 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
 }
 
 Expr cast(const DataType& t, Expr value) {
-  using ir::IntImm;
-  using ir::UIntImm;
-  using ir::FloatImm;
+  using ir::IntImmNode;
+  using ir::UIntImmNode;
+  using ir::FloatImmNode;
   if (value.dtype() == t) return value;
   // const fold IntImm as they are used in index computations
   if (t.lanes() == 1) {
-    if (const IntImm* op = value.as<IntImm>()) {
+    if (const IntImmNode* op = value.as<IntImmNode>()) {
       return make_const(t, op->value);
-    } else if (const UIntImm* op = value.as<UIntImm>()) {
+    } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
       return make_const(t, op->value);
-    } else if (const FloatImm* op = value.as<FloatImm>()) {
+    } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
       return make_const(t, op->value);
     }
-    return ir::Cast::make(t, value);
+    return ir::CastNode::make(t, value);
   } else {
     if (value.dtype().lanes() == 1) {
       // manually unroll cast
       DataType vtype = t.element_of();
       if (value.dtype() != vtype) {
-        if (const IntImm* op = value.as<IntImm>()) {
+        if (const IntImmNode* op = value.as<IntImmNode>()) {
           value = make_const(vtype, op->value);
-        } else if (const UIntImm* op = value.as<UIntImm>()) {
+        } else if (const UIntImmNode* op = value.as<UIntImmNode>()) {
           return make_const(t, op->value);
-        } else if (const FloatImm* op = value.as<FloatImm>()) {
+        } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
           value = make_const(vtype, op->value);
         } else {
-          value = ir::Cast::make(vtype, value);
+          value = ir::CastNode::make(vtype, value);
         }
       }
-      return ir::Broadcast::make(value, t.lanes());
+      return ir::BroadcastNode::make(value, t.lanes());
     } else {
       CHECK(value.dtype().lanes() == t.lanes());
-      return ir::Cast::make(t, value);
+      return ir::CastNode::make(t, value);
     }
   }
 }
 
 Expr reinterpret(const DataType& t, Expr value) {
   if (value.dtype() == t) return value;
-  return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    t, ir::CallNode::reinterpret, { value }, ir::CallNode::PureIntrinsic);
 }
 
 Expr operator+(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Add>(a, b);
+  Expr ret = arith::TryConstFold<ir::AddNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Add::make(a, b);
+  return ir::AddNode::make(a, b);
 }
 
 // negation
 Expr operator-(Expr a) {
-  using ir::IntImm;
-  using ir::FloatImm;
-  const IntImm* pa = a.as<IntImm>();
-  const FloatImm* fa = a.as<FloatImm>();
-  if (pa) return ir::IntImm::make(a.dtype(), -pa->value);
-  if (fa) return ir::FloatImm::make(a.dtype(), -fa->value);
+  using ir::IntImmNode;
+  using ir::FloatImmNode;
+  const IntImmNode* pa = a.as<IntImmNode>();
+  const FloatImmNode* fa = a.as<FloatImmNode>();
+  if (pa) return ir::IntImmNode::make(a.dtype(), -pa->value);
+  if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
   return make_zero(a.dtype()) - a;
 }
 
 Expr operator-(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Sub>(a, b);
+  Expr ret = arith::TryConstFold<ir::SubNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Sub::make(a, b);
+  return ir::SubNode::make(a, b);
 }
 
 Expr operator*(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Mul>(a, b);
+  Expr ret = arith::TryConstFold<ir::MulNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Mul::make(a, b);
+  return ir::MulNode::make(a, b);
 }
 
 Expr div(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Div>(a, b);
+  Expr ret = arith::TryConstFold<ir::DivNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Div::make(a, b);
+  return ir::DivNode::make(a, b);
 }
 
 Expr truncdiv(Expr a, Expr b) {
@@ -252,9 +253,9 @@ Expr truncdiv(Expr a, Expr b) {
 
 Expr truncmod(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Mod>(a, b);
+  Expr ret = arith::TryConstFold<ir::ModNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Mod::make(a, b);
+  return ir::ModNode::make(a, b);
 }
 
 Expr operator/(Expr a, Expr b) {
@@ -278,18 +279,18 @@ Expr floordiv(Expr a, Expr b) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
   CHECK(b.dtype().is_int() || b.dtype().is_uint());
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
+  Expr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::FloorDiv::make(a, b);
+  return ir::FloorDivNode::make(a, b);
 }
 
 Expr floormod(Expr a, Expr b) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
   CHECK(b.dtype().is_int() || b.dtype().is_uint());
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
+  Expr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::FloorMod::make(a, b);
+  return ir::FloorModNode::make(a, b);
 }
 
 Expr min(Expr a, Expr b) {
@@ -301,9 +302,9 @@ Expr min(Expr a, Expr b) {
   if (is_pos_inf(b)) return a;
   if (is_neg_inf(b)) return b;
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Min>(a, b);
+  Expr ret = arith::TryConstFold<ir::MinNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Min::make(a, b);
+  return ir::MinNode::make(a, b);
 }
 
 Expr max(Expr a, Expr b) {
@@ -315,184 +316,194 @@ Expr max(Expr a, Expr b) {
   if (is_pos_inf(b)) return b;
   if (is_neg_inf(b)) return a;
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::Max>(a, b);
+  Expr ret = arith::TryConstFold<ir::MaxNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Max::make(a, b);
+  return ir::MaxNode::make(a, b);
 }
 
 Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
-  using ir::IntImm;
-  using ir::UIntImm;
+  using ir::IntImmNode;
+  using ir::UIntImmNode;
   CHECK(cond.dtype() == DataType::Bool(1))
       << "if_then_else only accept the condition to be boolean type.";
   BinaryOpMatchTypes(true_value, false_value);
-  if (const UIntImm* op = cond.as<UIntImm>()) {
+  if (const UIntImmNode* op = cond.as<UIntImmNode>()) {
     if (op->value != 0) {
       return true_value;
     } else {
       return false_value;
     }
-  } else if (const IntImm* op = cond.as<IntImm>()) {
+  } else if (const IntImmNode* op = cond.as<IntImmNode>()) {
     if (op->value != 0) {
       return true_value;
     } else {
       return false_value;
     }
   }
-  return ir::Call::make(
+  return ir::CallNode::make(
       true_value.dtype(),
       ir::intrinsic::tvm_if_then_else,
       {cond, true_value, false_value},
-      ir::Call::PureIntrinsic);
+      ir::CallNode::PureIntrinsic);
 }
 
 Expr likely(Expr cond) {
   if (is_const(cond)) return cond;
-  return ir::Call::make(cond.dtype(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(cond.dtype(),
+                            ir::CallNode::likely,
+                            { cond },
+                            ir::CallNode::PureIntrinsic);
 }
 
 Expr operator>(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::GT>(a, b);
+  Expr ret = arith::TryConstFold<ir::GTNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::GT::make(a, b);
+  return ir::GTNode::make(a, b);
 }
 
 Expr operator>=(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::GE>(a, b);
+  Expr ret = arith::TryConstFold<ir::GENode>(a, b);
   if (ret.defined()) return ret;
-  return ir::GE::make(a, b);
+  return ir::GENode::make(a, b);
 }
 
 Expr operator<(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::LT>(a, b);
+  Expr ret = arith::TryConstFold<ir::LTNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::LT::make(a, b);
+  return ir::LTNode::make(a, b);
 }
 
 Expr operator<=(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::LE>(a, b);
+  Expr ret = arith::TryConstFold<ir::LENode>(a, b);
   if (ret.defined()) return ret;
-  return ir::LE::make(a, b);
+  return ir::LENode::make(a, b);
 }
 
 Expr operator==(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::EQ>(a, b);
+  Expr ret = arith::TryConstFold<ir::EQNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::EQ::make(a, b);
+  return ir::EQNode::make(a, b);
 }
 
 Expr operator!=(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
-  Expr ret = arith::TryConstFold<ir::NE>(a, b);
+  Expr ret = arith::TryConstFold<ir::NENode>(a, b);
   if (ret.defined()) return ret;
-  return ir::NE::make(a, b);
+  return ir::NENode::make(a, b);
 }
 
 Expr operator&&(Expr a, Expr b) {
   CHECK(a.dtype().is_bool());
   CHECK(b.dtype().is_bool());
-  Expr ret = arith::TryConstFold<ir::And>(a, b);
+  Expr ret = arith::TryConstFold<ir::AndNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::And::make(a, b);
+  return ir::AndNode::make(a, b);
 }
 
 Expr operator||(Expr a, Expr b) {
   CHECK(a.dtype().is_bool());
   CHECK(b.dtype().is_bool());
-  Expr ret = arith::TryConstFold<ir::Or>(a, b);
+  Expr ret = arith::TryConstFold<ir::OrNode>(a, b);
   if (ret.defined()) return ret;
-  return ir::Or::make(a, b);
+  return ir::OrNode::make(a, b);
 }
 
 Expr operator!(Expr a) {
   CHECK(a.dtype().is_bool());
-  Expr ret = arith::TryConstFold<ir::Not>(a);
+  Expr ret = arith::TryConstFold<ir::NotNode>(a);
   if (ret.defined()) return ret;
-  return ir::Not::make(a);
+  return ir::NotNode::make(a);
 }
 
 Expr operator>>(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
+      if (pa && pb) return IntImmNode::make(rtype, (pa->value >> pb->value));
       if (pb) {
         if (pb->value == 0) return a;
       }
     });
-  return ir::Call::make(a.dtype(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    a.dtype(), ir::CallNode::shift_right, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
 Expr operator<<(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
+      if (pa && pb) return IntImmNode::make(rtype, (pa->value << pb->value));
       if (pb) {
         if (pb->value == 0) return a;
       }
     });
-  return ir::Call::make(a.dtype(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    a.dtype(), ir::CallNode::shift_left, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
 Expr operator&(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
+      if (pa && pb) return IntImmNode::make(rtype, (pa->value & pb->value));
     });
-  return ir::Call::make(a.dtype(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    a.dtype(), ir::CallNode::bitwise_and, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
 Expr operator|(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
+      if (pa && pb) return IntImmNode::make(rtype, (pa->value | pb->value));
     });
-  return ir::Call::make(a.dtype(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    a.dtype(), ir::CallNode::bitwise_or, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
 Expr operator^(Expr a, Expr b) {
   BinaryOpMatchTypes(a, b);
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
-      if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
+      if (pa && pb) return IntImmNode::make(rtype, (pa->value ^ pb->value));
     });
-  return ir::Call::make(a.dtype(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    a.dtype(), ir::CallNode::bitwise_xor, { a, b }, ir::CallNode::PureIntrinsic);
 }
 
 Expr operator~(Expr a) {
   CHECK(a.dtype().is_int() || a.dtype().is_uint());
-  return ir::Call::make(a.dtype(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    a.dtype(), ir::CallNode::bitwise_not, { a }, ir::CallNode::PureIntrinsic);
 }
 
 Expr pow(Expr x, Expr y) {
   BinaryOpMatchTypes(x, y);
   CHECK(x.dtype().is_float()) << "power only applies to float";
-  return ir::Call::make(x.dtype(), "pow", { x, y }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(
+    x.dtype(), "pow", { x, y }, ir::CallNode::PureIntrinsic);
 }
 
 Expr abs(Expr x) {
   if (x.dtype().is_int()) {
-    using ir::IntImm;
-    const IntImm* px = x.as<IntImm>();
+    using ir::IntImmNode;
+    const IntImmNode* px = x.as<IntImmNode>();
     if (px) {
-      return ir::IntImm::make(x.dtype(), std::abs(px->value));
+      return ir::IntImmNode::make(x.dtype(), std::abs(px->value));
     }
-    return ir::Select::make(x >= make_zero(x.dtype()), x, -x);
+    return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x);
   } else if (x.dtype().is_float()) {
-    using ir::FloatImm;
-    const FloatImm* fx = x.as<FloatImm>();
+    using ir::FloatImmNode;
+    const FloatImmNode* fx = x.as<FloatImmNode>();
     if (fx) {
-      return ir::FloatImm::make(x.dtype(), std::fabs(fx->value));
+      return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value));
     }
-    return ir::Call::make(x.dtype(), "fabs", {x}, ir::Call::PureIntrinsic);
+    return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic);
   } else if (x.dtype().is_uint()) {
     return x;
   } else {
@@ -507,17 +518,17 @@ Expr isnan(Expr x) {
   if (x.dtype().is_int() || x.dtype().is_uint()) {
     return make_const(t, false);
   } else if (x.dtype().is_float()) {
-    using ir::FloatImm;
-    const FloatImm* fx = x.as<FloatImm>();
+    using ir::FloatImmNode;
+    const FloatImmNode* fx = x.as<FloatImmNode>();
     if (fx) {
       return make_const(t, std::isnan(fx->value));
     }
     if (x.dtype().bits() == 16) {
-      return ir::Call::make(t, ir::Call::isnan,
+      return ir::CallNode::make(t, ir::CallNode::isnan,
                                {cast(DataType::Float(32, t.lanes()), std::move(x))},
-                               ir::Call::PureIntrinsic);
+                               ir::CallNode::PureIntrinsic);
     } else {
-      return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
+      return ir::CallNode::make(t, ir::CallNode::isnan, {x}, ir::CallNode::PureIntrinsic);
     }
   } else {
     LOG(FATAL) << "Data type " << x.dtype()
@@ -528,102 +539,102 @@ Expr isnan(Expr x) {
 
 Expr sum(Expr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::Add::make(x, y);
+  Expr result = ir::AddNode::make(x, y);
   Expr identity_element = make_zero(source.dtype());
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
-  return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
 Expr all(Expr source, Array<IterVar> rdom) {
   CHECK(source.dtype().is_bool());
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::And::make(x, y);
+  Expr result = ir::AndNode::make(x, y);
   Expr identity_element = make_const(source.dtype(), true);
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
-  return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
 Expr any(Expr source, Array<IterVar> rdom) {
   CHECK(source.dtype().is_bool());
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::Or::make(x, y);
+  Expr result = ir::OrNode::make(x, y);
   Expr identity_element = make_const(source.dtype(), false);
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
-  return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
 Expr max(Expr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::Max::make(x, y);
+  Expr result = ir::MaxNode::make(x, y);
   Expr identity_element = min_value(source.dtype());
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
-  return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
 Expr min(Expr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::Min::make(x, y);
+  Expr result = ir::MinNode::make(x, y);
   Expr identity_element = max_value(source.dtype());
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
-  return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
 Expr prod(Expr source, Array<IterVar> rdom) {
   Var x("x", source.dtype()), y("y", source.dtype());
-  Expr result = ir::Mul::make(x, y);
+  Expr result = ir::MulNode::make(x, y);
   Expr identity_element = make_const(source.dtype(), 1);
   ir::CommReducer combiner =
     ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
-  return ir::Reduce::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return ir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
 }
 
 Expr fmod(Expr x, Expr y) {
   BinaryOpMatchTypes(x, y);
   CHECK(x.dtype().is_float()) << "fmod only applies to float";
-  return ir::Call::make(x.dtype(), "fmod", { x, y }, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(x.dtype(), "fmod", { x, y }, ir::CallNode::PureIntrinsic);
 }
 
 Expr floor(Expr x) {
-  using ir::FloatImm;
-  const FloatImm* fx = x.as<FloatImm>();
-  if (fx) return FloatImm::make(x.dtype(), std::floor(fx->value));
-  return ir::Call::make(x.dtype(), "floor", {x}, ir::Call::PureIntrinsic);
+  using ir::FloatImmNode;
+  const FloatImmNode* fx = x.as<FloatImmNode>();
+  if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value));
+  return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
 }
 
 Expr ceil(Expr x) {
-  using ir::FloatImm;
-  const FloatImm* fx = x.as<FloatImm>();
-  if (fx) return FloatImm::make(x.dtype(), std::ceil(fx->value));
-  return ir::Call::make(x.dtype(), "ceil", {x}, ir::Call::PureIntrinsic);
+  using ir::FloatImmNode;
+  const FloatImmNode* fx = x.as<FloatImmNode>();
+  if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value));
+  return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
 }
 
 Expr round(Expr x) {
-  using ir::FloatImm;
-  const FloatImm* fx = x.as<FloatImm>();
-  if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value));
-  return ir::Call::make(x.dtype(), "round", {x}, ir::Call::PureIntrinsic);
+  using ir::FloatImmNode;
+  const FloatImmNode* fx = x.as<FloatImmNode>();
+  if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
+  return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
 }
 
 Expr nearbyint(Expr x) {
-  using ir::FloatImm;
-  const FloatImm* fx = x.as<FloatImm>();
-  if (fx) return FloatImm::make(x.dtype(), std::nearbyint(fx->value));
-  return ir::Call::make(x.dtype(), "nearbyint", {x}, ir::Call::PureIntrinsic);
+  using ir::FloatImmNode;
+  const FloatImmNode* fx = x.as<FloatImmNode>();
+  if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
+  return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
 }
 
 Expr trunc(Expr x) {
-  using ir::FloatImm;
-  const FloatImm* fx = x.as<FloatImm>();
+  using ir::FloatImmNode;
+  const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) {
-    return FloatImm::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
+    return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
                                      std::floor(fx->value)));
   }
-  return ir::Call::make(x.dtype(), "trunc", {x}, ir::Call::PureIntrinsic);
+  return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic);
 }
 
 }  // namespace tvm
index de047f3..6b777cc 100644 (file)
@@ -31,79 +31,79 @@ namespace tvm {
 namespace ir {
 
 // constructors
-Expr UIntImm::make(DataType t, uint64_t value) {
+Expr UIntImmNode::make(DataType t, uint64_t value) {
   CHECK(t.is_uint() && t.lanes() == 1)
       << "ValueError: UIntImm can only take scalar";
-  ObjectPtr<UIntImm> node = make_object<UIntImm>();
+  ObjectPtr<UIntImmNode> node = make_object<UIntImmNode>();
   node->dtype = t;
   node->value = value;
   return Expr(node);
 }
 
-Expr FloatImm::make(DataType t, double value) {
+Expr FloatImmNode::make(DataType t, double value) {
   CHECK_EQ(t.lanes(), 1)
       << "ValueError: FloatImm can only take scalar";
-  ObjectPtr<FloatImm> node = make_object<FloatImm>();
+  ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
   node->dtype = t;
   node->value = value;
   return Expr(node);
 }
 
-Expr StringImm::make(std::string value) {
-  ObjectPtr<StringImm> node = make_object<StringImm>();
+Expr StringImmNode::make(std::string value) {
+  ObjectPtr<StringImmNode> node = make_object<StringImmNode>();
   node->dtype = DataType::Handle();
   node->value = std::move(value);
   return Expr(node);
 }
 
-Expr Cast::make(DataType t, Expr value) {
+Expr CastNode::make(DataType t, Expr value) {
   CHECK(value.defined());
   CHECK_EQ(t.lanes(), value.dtype().lanes());
-  ObjectPtr<Cast> node = make_object<Cast>();
+  ObjectPtr<CastNode> node = make_object<CastNode>();
   node->dtype = t;
   node->value = std::move(value);
   return Expr(node);
 }
 
-Expr And::make(Expr a, Expr b) {
+Expr AndNode::make(Expr a, Expr b) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(b.defined()) << "ValueError: b is undefined";
   CHECK(a.dtype().is_bool());
   CHECK(b.dtype().is_bool());
   CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
 
-  ObjectPtr<And> node = make_object<And>();
+  ObjectPtr<AndNode> node = make_object<AndNode>();
   node->dtype = DataType::Bool(a.dtype().lanes());
   node->a = std::move(a);
   node->b = std::move(b);
   return Expr(node);
 }
 
-Expr Or::make(Expr a, Expr b) {
+Expr OrNode::make(Expr a, Expr b) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(b.defined()) << "ValueError: b is undefined";
   CHECK(a.dtype().is_bool());
   CHECK(b.dtype().is_bool());
   CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
 
-  ObjectPtr<Or> node = make_object<Or>();
+  ObjectPtr<OrNode> node = make_object<OrNode>();
   node->dtype = DataType::Bool(a.dtype().lanes());
   node->a = std::move(a);
   node->b = std::move(b);
   return Expr(node);
 }
 
-Expr Not::make(Expr a) {
+Expr NotNode::make(Expr a) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(a.dtype().is_bool());
 
-  ObjectPtr<Not> node = make_object<Not>();
+  ObjectPtr<NotNode> node = make_object<NotNode>();
   node->dtype = DataType::Bool(a.dtype().lanes());
   node->a = std::move(a);
   return Expr(node);
 }
 
-Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
+Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) {
   CHECK(condition.defined()) << "ValueError: condition is undefined";
   CHECK(true_value.defined()) << "ValueError: true_value is undefined";
   CHECK(false_value.defined()) << "ValueError: true_value is undefined";
@@ -111,7 +111,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
   CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes());
   CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
 
-  ObjectPtr<Select> node = make_object<Select>();
+  ObjectPtr<SelectNode> node = make_object<SelectNode>();
   node->dtype = true_value.dtype();
   node->condition = std::move(condition);
   node->true_value = std::move(true_value);
@@ -119,14 +119,14 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
   return Expr(node);
 }
 
-Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
+Expr LoadNode::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
   CHECK(buffer_var.defined());
   CHECK(predicate.defined());
   CHECK(index.defined());
   CHECK_EQ(dtype.lanes(), index.dtype().lanes());
   CHECK_EQ(dtype.lanes(), predicate.dtype().lanes());
 
-  ObjectPtr<Load> node = make_object<Load>();
+  ObjectPtr<LoadNode> node = make_object<LoadNode>();
   node->dtype = dtype;
   node->buffer_var = std::move(buffer_var);
   node->index = std::move(index);
@@ -135,7 +135,7 @@ Expr Load::make(DataType dtype, Var buffer_var, Expr index, Expr predicate) {
   return Expr(node);
 }
 
-Expr Ramp::make(Expr base, Expr stride, int lanes) {
+Expr RampNode::make(Expr base, Expr stride, int lanes) {
   CHECK(base.defined());
   CHECK(stride.defined());
   CHECK(base.dtype().is_scalar());
@@ -143,7 +143,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) {
   CHECK_GT(lanes, 1);
   CHECK_EQ(stride.dtype(), base.dtype());
 
-  ObjectPtr<Ramp> node = make_object<Ramp>();
+  ObjectPtr<RampNode> node = make_object<RampNode>();
   node->dtype = base.dtype().with_lanes(lanes);
   node->base = base;
   node->stride = stride;
@@ -151,24 +151,24 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) {
   return Expr(node);
 }
 
-Expr Broadcast::make(Expr value, int lanes) {
+Expr BroadcastNode::make(Expr value, int lanes) {
   CHECK(value.defined());
   CHECK(value.dtype().is_scalar());
   CHECK_GT(lanes, 1);
 
-  ObjectPtr<Broadcast> node = make_object<Broadcast>();
+  ObjectPtr<BroadcastNode> node = make_object<BroadcastNode>();
   node->dtype = value.dtype().with_lanes(lanes);
   node->value = std::move(value);
   node->lanes = lanes;
   return Expr(node);
 }
 
-Expr Let::make(Var var, Expr value, Expr body) {
+Expr LetNode::make(Var var, Expr value, Expr body) {
   CHECK(value.defined());
   CHECK(body.defined());
   CHECK_EQ(value.dtype(), var.dtype());
 
-  ObjectPtr<Let> node = make_object<Let>();
+  ObjectPtr<LetNode> node = make_object<LetNode>();
   node->dtype = body.dtype();
   node->var = std::move(var);
   node->value = std::move(value);
@@ -176,23 +176,23 @@ Expr Let::make(Var var, Expr value, Expr body) {
   return Expr(node);
 }
 
-const char* Call::vectorizable_intrinsics[] = {
+const char* CallNode::vectorizable_intrinsics[] = {
     "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
-    "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
-    ir::Call::likely, ir::Call::popcount
+    "log", "sin", "cos", "pow", ir::CallNode::shift_left, ir::CallNode::shift_right,
+    ir::CallNode::likely, ir::CallNode::popcount
 };
 
-bool Call::is_vectorizable() const {
-  size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
+bool CallNode::is_vectorizable() const {
+  size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*);
   for (size_t i = 0; i < cnt; ++i) {
-    if (name == Call::vectorizable_intrinsics[i]) {
+    if (name == CallNode::vectorizable_intrinsics[i]) {
       return true;
     }
   }
   return false;
 }
 
-Expr Call::make(DataType dtype,
+Expr CallNode::make(DataType dtype,
                 std::string name,
                 Array<Expr> args,
                 CallType call_type,
@@ -208,7 +208,7 @@ Expr Call::make(DataType dtype,
     }
   }
 
-  ObjectPtr<Call> node = make_object<Call>();
+  ObjectPtr<CallNode> node = make_object<CallNode>();
   node->dtype = dtype;
   node->name = std::move(name);
   node->args = std::move(args);
@@ -218,7 +218,7 @@ Expr Call::make(DataType dtype,
   return Expr(node);
 }
 
-Expr Shuffle::make(Array<Expr> vectors,
+Expr ShuffleNode::make(Array<Expr> vectors,
                    Array<Expr> indices) {
   CHECK_NE(vectors.size(), 0U);
   CHECK_NE(indices.size(), 0U);
@@ -232,14 +232,14 @@ Expr Shuffle::make(Array<Expr> vectors,
   }
   CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
 
-  ObjectPtr<Shuffle> node = make_object<Shuffle>();
+  ObjectPtr<ShuffleNode> node = make_object<ShuffleNode>();
   node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
   node->vectors = std::move(vectors);
   node->indices = std::move(indices);
   return Expr(node);
 }
 
-Expr Shuffle::make_concat(Array<Expr> vectors) {
+Expr ShuffleNode::make_concat(Array<Expr> vectors) {
   CHECK_NE(vectors.size(), 0);
   if (vectors.size() == 1) {
     return vectors[0];
@@ -248,13 +248,13 @@ Expr Shuffle::make_concat(Array<Expr> vectors) {
   int index = 0;
   for (const Expr& e : vectors) {
     for (int i = 0; i < e.dtype().lanes(); ++i) {
-      indices.push_back(IntImm::make(DataType::Int(32), index++));
+      indices.push_back(IntImmNode::make(DataType::Int(32), index++));
     }
   }
   return make(vectors, indices);
 }
 
-Expr Shuffle::make_extract_element(Expr vector, int index) {
+Expr ShuffleNode::make_extract_element(Expr vector, int index) {
   return make({vector}, {Integer(index)});
 }
 
@@ -284,7 +284,7 @@ Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
     });
 }
 
-Expr Reduce::make(CommReducer combiner, Array<Expr> source,
+Expr ReduceNode::make(CommReducer combiner, Array<Expr> source,
                   Array<IterVar> axis, Expr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK_EQ(axis[i]->iter_type, kCommReduce)
@@ -293,7 +293,7 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
   if (!condition.defined()) {
     condition = const_true();
   }
-  auto n = make_object<Reduce>();
+  auto n = make_object<ReduceNode>();
   CHECK(source.defined());
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK(axis[i].defined());
@@ -307,28 +307,28 @@ Expr Reduce::make(CommReducer combiner, Array<Expr> source,
   return Expr(n);
 }
 
-Expr Any::make() {
-  auto n = make_object<Any>();
+Expr AnyNode::make() {
+  auto n = make_object<AnyNode>();
   return Expr(n);
 }
 
-Stmt LetStmt::make(Var var, Expr value, Stmt body) {
+Stmt LetStmtNode::make(Var var, Expr value, Stmt body) {
   CHECK(value.defined());
   CHECK(body.defined());
   CHECK_EQ(value.dtype(), var.dtype());
 
-  ObjectPtr<LetStmt> node = make_object<LetStmt>();
+  ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
   node->var = std::move(var);
   node->value = std::move(value);
   node->body = std::move(body);
   return Stmt(node);
 }
 
-Stmt AttrStmt::make(ObjectRef node,
+Stmt AttrStmtNode::make(ObjectRef node,
                     std::string attr_key,
                     Expr value,
                     Stmt body) {
-  auto n = make_object<AttrStmt>();
+  auto n = make_object<AttrStmtNode>();
   n->node = node;
   n->attr_key = std::move(attr_key);
   n->value = std::move(value);
@@ -336,31 +336,31 @@ Stmt AttrStmt::make(ObjectRef node,
   return Stmt(n);
 }
 
-Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
+Stmt AssertStmtNode::make(Expr condition, Expr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) ||
-        message.as<StringImm>())
+        message.as<StringImmNode>())
       << "TypeError: AssertStmt message must be an int or string:"
       << message << "\n";
 
-  ObjectPtr<AssertStmt> node = make_object<AssertStmt>();
+  ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
   node->condition = std::move(condition);
   node->message = std::move(message);
   node->body = std::move(body);
   return Stmt(node);
 }
 
-Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
+Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   CHECK(body.defined());
 
-  ObjectPtr<ProducerConsumer> node = make_object<ProducerConsumer>();
+  ObjectPtr<ProducerConsumerNode> node = make_object<ProducerConsumerNode>();
   node->func = std::move(func);
   node->is_producer = is_producer;
   node->body = std::move(body);
   return Stmt(node);
 }
 
-Stmt For::make(Var loop_var,
+Stmt ForNode::make(Var loop_var,
                Expr min,
                Expr extent,
                ForType for_type,
@@ -373,7 +373,7 @@ Stmt For::make(Var loop_var,
   CHECK(loop_var.dtype().is_scalar());
   CHECK(body.defined());
 
-  ObjectPtr<For> node = make_object<For>();
+  ObjectPtr<ForNode> node = make_object<ForNode>();
   node->loop_var = std::move(loop_var);
   node->min = std::move(min);
   node->extent = std::move(extent);
@@ -383,14 +383,14 @@ Stmt For::make(Var loop_var,
   return Stmt(node);
 }
 
-Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
+Stmt StoreNode::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
   CHECK(value.defined());
   CHECK(index.defined());
   CHECK(predicate.defined());
   CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
   CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
 
-  ObjectPtr<Store> node = make_object<Store>();
+  ObjectPtr<StoreNode> node = make_object<StoreNode>();
   node->buffer_var = std::move(buffer_var);
   node->value = std::move(value);
   node->index = std::move(index);
@@ -398,7 +398,7 @@ Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
   return Stmt(node);
 }
 
-Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
+Stmt ProvideNode::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
   CHECK(value_index >=0 && value_index < func->num_outputs())
       << "value index output function return value bound";
   CHECK(value.defined()) << "Provide of undefined value\n";
@@ -407,7 +407,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> ar
     CHECK(args[i].defined()) << "Provide to undefined location\n";
   }
 
-  ObjectPtr<Provide> node = make_object<Provide>();
+  ObjectPtr<ProvideNode> node = make_object<ProvideNode>();
   node->func = std::move(func);
   node->value_index = value_index;
   node->value = std::move(value);
@@ -415,7 +415,7 @@ Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> ar
   return Stmt(node);
 }
 
-Stmt Allocate::make(Var buffer_var,
+Stmt AllocateNode::make(Var buffer_var,
                     DataType dtype,
                     Array<Expr> extents,
                     Expr condition,
@@ -430,7 +430,7 @@ Stmt Allocate::make(Var buffer_var,
     CHECK(condition.defined());
     CHECK(condition.dtype().is_bool());
 
-    ObjectPtr<Allocate> node = make_object<Allocate>();
+    ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
     node->buffer_var = std::move(buffer_var);
     node->dtype = dtype;
     node->extents = std::move(extents);
@@ -441,10 +441,10 @@ Stmt Allocate::make(Var buffer_var,
     return Stmt(node);
 }
 
-int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
+int32_t AllocateNode::constant_allocation_size(const Array<Expr>& extents) {
   int64_t result = 1;
   for (size_t i = 0; i < extents.size(); ++i) {
-    if (const IntImm *int_size = extents[i].as<IntImm>()) {
+    if (const IntImmNode *int_size = extents[i].as<IntImmNode>()) {
       result *= int_size->value;
       if (result > std::numeric_limits<int32_t>::max()) {
         return 0;
@@ -456,13 +456,13 @@ int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
   return static_cast<int32_t>(result);
 }
 
-Stmt Free::make(Var buffer_var) {
-  ObjectPtr<Free> node = make_object<Free>();
+Stmt FreeNode::make(Var buffer_var) {
+  ObjectPtr<FreeNode> node = make_object<FreeNode>();
   node->buffer_var = buffer_var;
   return Stmt(node);
 }
 
-Stmt Realize::make(FunctionRef func,
+Stmt RealizeNode::make(FunctionRef func,
                    int value_index,
                    DataType dtype,
                    Region bounds,
@@ -478,7 +478,7 @@ Stmt Realize::make(FunctionRef func,
   CHECK(condition.defined());
   CHECK(condition.dtype().is_bool());
 
-  ObjectPtr<Realize> node = make_object<Realize>();
+  ObjectPtr<RealizeNode> node = make_object<RealizeNode>();
   node->func = std::move(func);
   node->value_index = value_index;
   node->dtype = dtype;
@@ -488,7 +488,7 @@ Stmt Realize::make(FunctionRef func,
   return Stmt(node);
 }
 
-Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
+Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
     CHECK(bounds[i]->extent.defined());
@@ -496,7 +496,7 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo
     CHECK(bounds[i]->extent.dtype().is_scalar());
   }
 
-  ObjectPtr<Prefetch> node = make_object<Prefetch>();
+  ObjectPtr<PrefetchNode> node = make_object<PrefetchNode>();
   node->func = std::move(func);
   node->value_index = value_index;
   node->dtype = dtype;
@@ -510,36 +510,36 @@ SeqStmt::SeqStmt(Array<Stmt> seq) {
   data_ = std::move(node);
 }
 
-Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
+Stmt IfThenElseNode::make(Expr condition, Stmt then_case, Stmt else_case) {
   CHECK(condition.defined());
   CHECK(then_case.defined());
   // else_case may be null.
 
-  ObjectPtr<IfThenElse> node = make_object<IfThenElse>();
+  ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
   node->condition = std::move(condition);
   node->then_case = std::move(then_case);
   node->else_case = std::move(else_case);
   return Stmt(node);
 }
 
-Stmt Evaluate::make(Expr value) {
+Stmt EvaluateNode::make(Expr value) {
   CHECK(value.defined());
 
-  ObjectPtr<Evaluate> node = make_object<Evaluate>();
+  ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
   node->value = std::move(value);
   return Stmt(node);
 }
 
 // Printers
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<UIntImm>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const UIntImm*>(node.get());
+.set_dispatch<UIntImmNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const UIntImmNode*>(node.get());
     p->stream << "(" << op->dtype << ")" << op->value;
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<FloatImm>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const FloatImm*>(node.get());
+.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const FloatImmNode*>(node.get());
     auto& stream = p->stream;
     switch (op->dtype.bits()) {
       case 64:
@@ -557,8 +557,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<StringImm>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const StringImm*>(node.get());
+.set_dispatch<StringImmNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const StringImmNode*>(node.get());
     auto& stream = p->stream;
     stream << '"';
     for (size_t i = 0; i < op->value.size(); ++i) {
@@ -593,116 +593,116 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Cast>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Cast*>(node.get());
+.set_dispatch<CastNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const CastNode*>(node.get());
     p->stream << op->dtype << '(';
     p->Print(op->value);
     p->stream << ')';
   })
-.set_dispatch<Variable>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Variable*>(node.get());
+.set_dispatch<VarNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const VarNode*>(node.get());
     // omit the type
     // stream << op->name << "." << op->type;
     p->stream << op->name_hint;
   })
-.set_dispatch<Add>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Add*>(node.get());
+.set_dispatch<AddNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const AddNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " + ";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Sub>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Sub*>(node.get());
+.set_dispatch<SubNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const SubNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " - ";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Mul>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Mul*>(node.get());
+.set_dispatch<MulNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const MulNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << "*";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Div>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Div*>(node.get());
+.set_dispatch<DivNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const DivNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << "/";
     p->Print(op->b);
     p->stream << ')';
   })
-.set_dispatch<Mod>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Mod*>(node.get());
+.set_dispatch<ModNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ModNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " % ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<Min>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Min*>(node.get());
+.set_dispatch<MinNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const MinNode*>(node.get());
     p->stream << "min(";
     p->Print(op->a);
     p->stream << ", ";
     p->Print(op->b);
     p->stream << ")";
 })
-.set_dispatch<Max>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Max*>(node.get());
+.set_dispatch<MaxNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const MaxNode*>(node.get());
     p->stream << "max(";
     p->Print(op->a);
     p->stream << ", ";
     p->Print(op->b);
     p->stream << ")";
 })
-.set_dispatch<EQ>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const EQ*>(node.get());
+.set_dispatch<EQNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const EQNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " == ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<NE>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const NE*>(node.get());
+.set_dispatch<NENode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const NENode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " != ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<LT>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const LT*>(node.get());
+.set_dispatch<LTNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const LTNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " < ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<LE>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const LE*>(node.get());
+.set_dispatch<LENode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const LENode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " <= ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<GT>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const GT*>(node.get());
+.set_dispatch<GTNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const GTNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " > ";
     p->Print(op->b);
     p->stream << ')';
 })
-.set_dispatch<GE>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const GE*>(node.get());
+.set_dispatch<GENode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const GENode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " >= ";
@@ -711,20 +711,20 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<FloorDiv>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const FloorDiv*>(node.get());
+.set_dispatch<FloorDivNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const FloorDivNode*>(node.get());
   p->stream << "floordiv(" << op->a << ", " << op->b << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<FloorMod>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const FloorMod*>(node.get());
+.set_dispatch<FloorModNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const FloorModNode*>(node.get());
   p->stream << "floormod(" << op->a << ", " << op->b << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<And>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const And*>(node.get());
+.set_dispatch<AndNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const AndNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " && ";
@@ -733,8 +733,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Or>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Or*>(node.get());
+.set_dispatch<OrNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const OrNode*>(node.get());
     p->stream << '(';
     p->Print(op->a);
     p->stream << " || ";
@@ -743,15 +743,15 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Not>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Not*>(node.get());
+.set_dispatch<NotNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const NotNode*>(node.get());
     p->stream << '!';
     p->Print(op->a);
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Select>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Select*>(node.get());
+.set_dispatch<SelectNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const SelectNode*>(node.get());
     p->stream << "select(";
     p->Print(op->condition);
     p->stream << ", ";
@@ -762,8 +762,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Load>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Load*>(node.get());
+.set_dispatch<LoadNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const LoadNode*>(node.get());
     p->stream << op->buffer_var << "[";
     p->Print(op->index);
     p->stream << "]";
@@ -774,8 +774,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Ramp>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Ramp*>(node.get());
+.set_dispatch<RampNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const RampNode*>(node.get());
     p->stream << "ramp(";
     p->Print(op->base);
     p->stream << ", ";
@@ -784,16 +784,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Broadcast>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Broadcast*>(node.get());
+.set_dispatch<BroadcastNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const BroadcastNode*>(node.get());
     p->stream << "x" << op->lanes << "(";
     p->Print(op->value);
     p->stream << ")";
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Call>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Call*>(node.get());
+.set_dispatch<CallNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const CallNode*>(node.get());
     p->stream << op->name << "(";
     for (size_t i = 0; i < op->args.size(); ++i) {
       p->Print(op->args[i]);
@@ -805,8 +805,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Let>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Let*>(node.get());
+.set_dispatch<LetNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const LetNode*>(node.get());
     p->stream << "(let " << op->var << " = ";
     p->Print(op->value);
     p->stream << " in ";
@@ -815,8 +815,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<LetStmt>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const LetStmt*>(node.get());
+.set_dispatch<LetStmtNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const LetStmtNode*>(node.get());
     p->PrintIndent();
     p->stream << "let " << op->var << " = ";
     p->Print(op->value);
@@ -825,8 +825,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<AttrStmt>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const AttrStmt*>(node.get());
+.set_dispatch<AttrStmtNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const AttrStmtNode*>(node.get());
     p->PrintIndent();
     p->stream << "// attr [";
     p->Print(op->node);
@@ -838,8 +838,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<AssertStmt>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const AssertStmt*>(node.get());
+.set_dispatch<AssertStmtNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const AssertStmtNode*>(node.get());
     p->PrintIndent();
     p->stream << "assert(";
     p->Print(op->condition);
@@ -850,8 +850,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<ProducerConsumer>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const ProducerConsumer*>(node.get());
+.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ProducerConsumerNode*>(node.get());
     if (op->is_producer) {
       p->PrintIndent();
       p->stream << "produce " << op->func->func_name() << " {\n";
@@ -884,8 +884,8 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
 }
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<For>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const For*>(node.get());
+.set_dispatch<ForNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ForNode*>(node.get());
     p->PrintIndent();
     p->stream << op->for_type << " (" << op->loop_var << ", ";
     p->Print(op->min);
@@ -902,8 +902,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Store>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Store*>(node.get());
+.set_dispatch<StoreNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const StoreNode*>(node.get());
     p->PrintIndent();
     p->stream << op->buffer_var << "[";
     p->Print(op->index);
@@ -917,8 +917,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Provide>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Provide*>(node.get());
+.set_dispatch<ProvideNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ProvideNode*>(node.get());
     p->PrintIndent();
     p->stream << op->func->func_name() << "(";
     for (size_t i = 0; i < op->args.size(); ++i) {
@@ -935,8 +935,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Allocate>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Allocate*>(node.get());
+.set_dispatch<AllocateNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const AllocateNode*>(node.get());
     p->PrintIndent();
     p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
     for (size_t i = 0; i < op->extents.size(); ++i) {
@@ -953,16 +953,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Free>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Free*>(node.get());
+.set_dispatch<FreeNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const FreeNode*>(node.get());
     p->PrintIndent();
     p->stream << "free " << op->buffer_var;
     p->stream << '\n';
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Realize>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Realize*>(node.get());
+.set_dispatch<RealizeNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const RealizeNode*>(node.get());
     p->PrintIndent();
     p->stream << "realize " << op->func->func_name() << "(";
     for (size_t i = 0; i < op->bounds.size(); ++i) {
@@ -992,8 +992,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Prefetch>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Prefetch*>(node.get());
+.set_dispatch<PrefetchNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const PrefetchNode*>(node.get());
     p->PrintIndent();
     p->stream << "prefetch " << op->func->func_name() << "(";
     for (size_t i = 0; i < op->bounds.size(); ++i) {
@@ -1019,8 +1019,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<IfThenElse>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const IfThenElse*>(node.get());
+.set_dispatch<IfThenElseNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const IfThenElseNode*>(node.get());
     p->PrintIndent();
     while (true) {
       p->stream << "if (" << op->condition << ") {\n";
@@ -1032,7 +1032,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
         break;
       }
 
-      if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
+      if (const IfThenElseNode *nested_if = op->else_case.as<IfThenElseNode>()) {
         p->PrintIndent();
         p->stream << "} else ";
         op = nested_if;
@@ -1050,8 +1050,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
 });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Evaluate>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Evaluate*>(node.get());
+.set_dispatch<EvaluateNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const EvaluateNode*>(node.get());
     p->PrintIndent();
     p->Print(op->value);
     p->stream << "\n";
@@ -1068,8 +1068,8 @@ void PrintList(const Array<T> &exprs, NodePrinter* p) {
 }
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Shuffle>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Shuffle*>(node.get());
+.set_dispatch<ShuffleNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ShuffleNode*>(node.get());
     p->stream << "shuffle(";
     PrintList(op->vectors, p);
     p->stream << ", ";
@@ -1121,8 +1121,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Reduce>([](const ObjectRef& node, NodePrinter* p) {
-    auto* op = static_cast<const Reduce*>(node.get());
+.set_dispatch<ReduceNode>([](const ObjectRef& node, NodePrinter* p) {
+    auto* op = static_cast<const ReduceNode*>(node.get());
     p->stream << "reduce(combiner="
               << op->combiner;
     p->stream << ", source=" << op->source;
@@ -1143,58 +1143,58 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   });
 
 TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
-.set_dispatch<Any>([](const ObjectRef& node, NodePrinter* p) {
+.set_dispatch<AnyNode>([](const ObjectRef& node, NodePrinter* p) {
     p->stream << "?";
 });
 
 TVM_REGISTER_NODE_TYPE(CommReducerNode);
-TVM_REGISTER_NODE_TYPE(Reduce);
-TVM_REGISTER_NODE_TYPE(Any);
-TVM_REGISTER_NODE_TYPE(AttrStmt);
-TVM_REGISTER_NODE_TYPE(FloatImm);
-TVM_REGISTER_NODE_TYPE(IntImm);
-TVM_REGISTER_NODE_TYPE(UIntImm);
-TVM_REGISTER_NODE_TYPE(StringImm);
-TVM_REGISTER_NODE_TYPE(Cast);
-TVM_REGISTER_NODE_TYPE(Variable);
-TVM_REGISTER_NODE_TYPE(Add);
-TVM_REGISTER_NODE_TYPE(Sub);
-TVM_REGISTER_NODE_TYPE(Mul);
-TVM_REGISTER_NODE_TYPE(Div);
-TVM_REGISTER_NODE_TYPE(Mod);
-TVM_REGISTER_NODE_TYPE(FloorDiv);
-TVM_REGISTER_NODE_TYPE(FloorMod);
-TVM_REGISTER_NODE_TYPE(Min);
-TVM_REGISTER_NODE_TYPE(Max);
-TVM_REGISTER_NODE_TYPE(EQ);
-TVM_REGISTER_NODE_TYPE(NE);
-TVM_REGISTER_NODE_TYPE(LT);
-TVM_REGISTER_NODE_TYPE(LE);
-TVM_REGISTER_NODE_TYPE(GT);
-TVM_REGISTER_NODE_TYPE(GE);
-TVM_REGISTER_NODE_TYPE(And);
-TVM_REGISTER_NODE_TYPE(Or);
-TVM_REGISTER_NODE_TYPE(Not);
-TVM_REGISTER_NODE_TYPE(Select);
-TVM_REGISTER_NODE_TYPE(Load);
-TVM_REGISTER_NODE_TYPE(Ramp);
-TVM_REGISTER_NODE_TYPE(Broadcast);
-TVM_REGISTER_NODE_TYPE(Shuffle);
-TVM_REGISTER_NODE_TYPE(Prefetch);
-TVM_REGISTER_NODE_TYPE(Call);
-TVM_REGISTER_NODE_TYPE(Let);
-TVM_REGISTER_NODE_TYPE(LetStmt);
-TVM_REGISTER_NODE_TYPE(AssertStmt);
-TVM_REGISTER_NODE_TYPE(ProducerConsumer);
-TVM_REGISTER_NODE_TYPE(For);
-TVM_REGISTER_NODE_TYPE(Store);
-TVM_REGISTER_NODE_TYPE(Provide);
-TVM_REGISTER_NODE_TYPE(Allocate);
-TVM_REGISTER_NODE_TYPE(Free);
-TVM_REGISTER_NODE_TYPE(Realize);
+TVM_REGISTER_NODE_TYPE(ReduceNode);
+TVM_REGISTER_NODE_TYPE(AnyNode);
+TVM_REGISTER_NODE_TYPE(AttrStmtNode);
+TVM_REGISTER_NODE_TYPE(FloatImmNode);
+TVM_REGISTER_NODE_TYPE(IntImmNode);
+TVM_REGISTER_NODE_TYPE(UIntImmNode);
+TVM_REGISTER_NODE_TYPE(StringImmNode);
+TVM_REGISTER_NODE_TYPE(CastNode);
+TVM_REGISTER_NODE_TYPE(VarNode);
+TVM_REGISTER_NODE_TYPE(AddNode);
+TVM_REGISTER_NODE_TYPE(SubNode);
+TVM_REGISTER_NODE_TYPE(MulNode);
+TVM_REGISTER_NODE_TYPE(DivNode);
+TVM_REGISTER_NODE_TYPE(ModNode);
+TVM_REGISTER_NODE_TYPE(FloorDivNode);
+TVM_REGISTER_NODE_TYPE(FloorModNode);
+TVM_REGISTER_NODE_TYPE(MinNode);
+TVM_REGISTER_NODE_TYPE(MaxNode);
+TVM_REGISTER_NODE_TYPE(EQNode);
+TVM_REGISTER_NODE_TYPE(NENode);
+TVM_REGISTER_NODE_TYPE(LTNode);
+TVM_REGISTER_NODE_TYPE(LENode);
+TVM_REGISTER_NODE_TYPE(GTNode);
+TVM_REGISTER_NODE_TYPE(GENode);
+TVM_REGISTER_NODE_TYPE(AndNode);
+TVM_REGISTER_NODE_TYPE(OrNode);
+TVM_REGISTER_NODE_TYPE(NotNode);
+TVM_REGISTER_NODE_TYPE(SelectNode);
+TVM_REGISTER_NODE_TYPE(LoadNode);
+TVM_REGISTER_NODE_TYPE(RampNode);
+TVM_REGISTER_NODE_TYPE(BroadcastNode);
+TVM_REGISTER_NODE_TYPE(ShuffleNode);
+TVM_REGISTER_NODE_TYPE(PrefetchNode);
+TVM_REGISTER_NODE_TYPE(CallNode);
+TVM_REGISTER_NODE_TYPE(LetNode);
+TVM_REGISTER_NODE_TYPE(LetStmtNode);
+TVM_REGISTER_NODE_TYPE(AssertStmtNode);
+TVM_REGISTER_NODE_TYPE(ProducerConsumerNode);
+TVM_REGISTER_NODE_TYPE(ForNode);
+TVM_REGISTER_NODE_TYPE(StoreNode);
+TVM_REGISTER_NODE_TYPE(ProvideNode);
+TVM_REGISTER_NODE_TYPE(AllocateNode);
+TVM_REGISTER_NODE_TYPE(FreeNode);
+TVM_REGISTER_NODE_TYPE(RealizeNode);
 TVM_REGISTER_NODE_TYPE(SeqStmtNode);
-TVM_REGISTER_NODE_TYPE(IfThenElse);
-TVM_REGISTER_NODE_TYPE(Evaluate);
+TVM_REGISTER_NODE_TYPE(IfThenElseNode);
+TVM_REGISTER_NODE_TYPE(EvaluateNode);
 
 }  // namespace ir
 }  // namespace tvm
index d0e81b9..f797700 100644 (file)
@@ -34,14 +34,14 @@ Expr Tensor::operator()(Array<Var> indices) const {
 }
 
 Expr Tensor::operator()(Array<Expr> indices) const {
-  using ir::Call;
+  using ir::CallNode;
   if (ndim() != 0) {
     CHECK_EQ(ndim(), indices.size())
         << "Tensor dimension mismatch in read"
         << "ndim = " << ndim() << ", indices.size=" << indices.size();
   }
-  auto n = Call::make(
-      (*this)->dtype, (*this)->op->name, indices, Call::Halide,
+  auto n = CallNode::make(
+      (*this)->dtype, (*this)->op->name, indices, CallNode::Halide,
       (*this)->op, (*this)->value_index);
   return n;
 }
index 6146284..0ad68b1 100644 (file)
@@ -50,7 +50,7 @@ TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 /// Verify if ComputeOp is valid with respect to Reduce operations.
 static void VerifyComputeOp(const ComputeOpNode *op);
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
   return (a->combiner.same_as(b->combiner)) &&
          (a->source.same_as(b->source)) &&
          (a->axis.same_as(b->axis)) &&
@@ -148,8 +148,8 @@ Operation ComputeOpNode::make(std::string name,
   n->attrs = std::move(attrs);
   n->axis = std::move(axis);
   n->body = std::move(body);
-  if (n->body[0]->IsInstance<ir::Reduce>()) {
-    const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
+  if (n->body[0]->IsInstance<ir::ReduceNode>()) {
+    const ir::ReduceNode* reduce = n->body[0].as<ir::ReduceNode>();
     n->reduce_axis = reduce->axis;
   }
   VerifyComputeOp(n.get());
@@ -162,7 +162,7 @@ Array<Tensor> ComputeOpNode::InputTensors() const {
   std::unordered_set<Tensor> visited;
   for (auto& e : body) {
     ir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) {
-        const ir::Call *call = n.as<ir::Call>();
+        const ir::CallNode *call = n.as<ir::CallNode>();
         if (call != nullptr && call->func.defined()) {
           Tensor t = Downcast<Operation>(call->func).output(call->value_index);
           if (!visited.count(t)) {
@@ -181,14 +181,14 @@ Operation ComputeOpNode::ReplaceInputs(
   CHECK_EQ(self.operator->(), this);
   VerifyComputeOp(this);
   Array<Expr> arr;
-  if (this->body[0]->IsInstance<ir::Reduce>()) {
+  if (this->body[0]->IsInstance<ir::ReduceNode>()) {
     // Specially handle reduce so the replaced op
     // still share all the components
     Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
     if (!new_reduce.same_as(this->body[0])) {
-      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
+      const ir::ReduceNode* r = new_reduce.as<ir::ReduceNode>();
       for (size_t k = 0; k < this->body.size(); ++k) {
-        auto n = make_object<ir::Reduce>(*r);
+        auto n = make_object<ir::ReduceNode>(*r);
         n->value_index = static_cast<int>(k);
         n->dtype = r->source[k].dtype();
         arr.push_back(Expr(n));
@@ -212,11 +212,11 @@ Operation ComputeOpNode::ReplaceInputs(
 void ComputeOpNode::PropBoundToInputs(
     const Operation& self,
     arith::Analyzer* analyzer,
-    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   CHECK_EQ(self.operator->(), this);
   auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) {
-    auto *call = n.as<ir::Call>();
+    auto *call = n.as<ir::CallNode>();
     if (call != nullptr && call->func.defined()) {
       Tensor t = Downcast<Operation>(call->func).output(call->value_index);
       if (t->op.defined() && out_dom_map->count(t)) {
@@ -282,7 +282,7 @@ Stmt BaseComputeOpNode::BuildRealize(
   Stmt realize = body;
   for (int i = this->num_outputs(); i > 0; --i) {
     Tensor t = stage->op.output(i-1);
-    realize = ir::Realize::make(t->op, t->value_index,
+    realize = ir::RealizeNode::make(t->op, t->value_index,
       t->dtype, bounds, const_true(), realize);
     // alignment requirement, only useful for compute
     for (size_t i = 0; i < num_schedulable_dims(); ++i) {
@@ -293,9 +293,11 @@ Stmt BaseComputeOpNode::BuildRealize(
           Array<Expr> tuple = {static_cast<int>(i),
                                attr->dim_align_factor,
                                attr->dim_align_offset};
-          realize = ir::AttrStmt::make(
+          realize = ir::AttrStmtNode::make(
               t, ir::attr::buffer_dim_align,
-              Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+              CallNode::make(DataType::Handle(),
+                             ir::intrinsic::tvm_tuple,
+                             tuple, CallNode::Intrinsic),
               realize);
         }
       }
@@ -320,7 +322,7 @@ void MakeReduction(const ComputeOpNode* op,
   std::vector<Stmt> inits, provides;
 
   size_t size = op->body.size();
-  const Reduce* reduce = op->body[0].as<Reduce>();
+  const ReduceNode* reduce = op->body[0].as<ReduceNode>();
   CHECK(reduce);
   const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
   CHECK(combiner);
@@ -332,15 +334,15 @@ void MakeReduction(const ComputeOpNode* op,
   Array<Expr> update_value = (*combiner)(lhs, reduce->source);
   for (size_t i = 0; i < size; ++i) {
     Tensor t = tensors[i];
-    inits.emplace_back(Provide::make(
+    inits.emplace_back(ProvideNode::make(
           t->op, t->value_index, init_value[i], args));
-    provides.emplace_back(Provide::make(
+    provides.emplace_back(ProvideNode::make(
           t->op, t->value_index, update_value[i], args));
   }
   *init = SeqStmt::Flatten(inits);
   *provide = SeqStmt::Flatten(provides);
   if (!is_one(reduce->condition)) {
-    *provide = IfThenElse::make(reduce->condition, *provide);
+    *provide = IfThenElseNode::make(reduce->condition, *provide);
   }
 }
 
@@ -351,7 +353,7 @@ Stmt MakeProvide(const ComputeOpNode* op,
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
-  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
+  return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args);
 }
 
 Stmt MakeComputeStmt(const ComputeOpNode* self,
@@ -543,7 +545,7 @@ class ComputeVerifier final : protected ir::ExprVisitor {
   /// Special member functions
   //@{
   explicit ComputeVerifier(const ComputeOpNode* compute)
-      : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
+      : compute_(compute), reduce_(compute->body[0].as<ir::ReduceNode>()) {}
   virtual ~ComputeVerifier() = default;
   ComputeVerifier(const ComputeVerifier&) = delete;
   ComputeVerifier(ComputeVerifier&&) = delete;
@@ -555,7 +557,7 @@ class ComputeVerifier final : protected ir::ExprVisitor {
   void Run() {
     for (const Expr e : compute_->body) {
       // Check for consistency of top level reductions
-      const ir::Reduce* reduce = e.as<ir::Reduce>();
+      const ir::ReduceNode* reduce = e.as<ir::ReduceNode>();
       CHECK((reduce && reduce_) || (!reduce && !reduce_))
           << "All ComputeOp should be consistent "
           << "with being Reduce operation or not.";
@@ -580,7 +582,7 @@ class ComputeVerifier final : protected ir::ExprVisitor {
     --level_;
   }
 
-  void VisitExpr_(const ir::Reduce* op) final {
+  void VisitExpr_(const ir::ReduceNode* op) final {
     // Check for non top level reductions
     CHECK(0 == level_)
         << "Reductions are only allowed at the top level of compute. "
@@ -590,7 +592,7 @@ class ComputeVerifier final : protected ir::ExprVisitor {
 
  private:
   const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
-  const ir::Reduce* reduce_{nullptr};      ///< Top level Reduce operation
+  const ir::ReduceNode* reduce_{nullptr};      ///< Top level Reduce operation
   int level_{0};                           ///< Level of op being processed
 };
 }  // namespace
@@ -607,7 +609,7 @@ Stmt TransformUpdate(const Stage& stage,
                      Stmt body,
                      Stmt update) {
   Array<Expr> conds;
-  std::unordered_set<const Variable*> banned;
+  std::unordered_set<const VarNode*> banned;
   for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
     IterVar iv = stage->leaf_iter_vars[i];
     auto iit = stage->iter_var_attrs.find(iv);
@@ -632,7 +634,7 @@ Stmt TransformUpdate(const Stage& stage,
     }
   }
 
-  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
+  return IfThenElseNode::make(arith::ComputeReduce<ir::OrNode>(conds, const_true(1)),
                           update, body);
 }
 }  // namespace tvm
index ab56fc9..89d0ca7 100644 (file)
@@ -46,9 +46,9 @@ Stmt MakeCrossThreadReduction(
 
   size_t size = self->body.size();
   CHECK_GT(size, 0);
-  std::vector<const Reduce*> reduces(size);
+  std::vector<const ReduceNode*> reduces(size);
   for (size_t i = 0; i < size; ++i) {
-    const Reduce* reduce = self->body[i].as<Reduce>();
+    const ReduceNode* reduce = self->body[i].as<ReduceNode>();
     CHECK(reduce);
     reduces[i] = reduce;
   }
@@ -84,11 +84,11 @@ Stmt MakeCrossThreadReduction(
     thread_head_check.emplace_back(stage->store_predicate);
   }
 
-  Stmt reduce_body = Evaluate::make(Call::make(
+  Stmt reduce_body = EvaluateNode::make(CallNode::make(
       DataType::Handle(),
       ir::intrinsic::tvm_thread_allreduce,
-      freduce_args, Call::Intrinsic));
-  reduce_body = AttrStmt::make(
+      freduce_args, CallNode::Intrinsic));
+  reduce_body = AttrStmtNode::make(
       reduces[0]->combiner,
       attr::reduce_scope,
       make_zero(DataType::Handle()),
@@ -96,19 +96,19 @@ Stmt MakeCrossThreadReduction(
   std::vector<Stmt> assigns(size);
   for (size_t idx = 0; idx < size; ++idx) {
     DataType t = reduces[idx]->dtype;
-    assigns[idx] = Provide::make(
+    assigns[idx] = ProvideNode::make(
       stage->op, idx,
-      Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
+      LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args);
   }
   Stmt assign_body = SeqStmt::Flatten(assigns);
   assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
   assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
   Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
   for (size_t idx = size; idx != 0; --idx) {
-    body = Allocate::make(
+    body = AllocateNode::make(
       res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
-    body = AttrStmt::make(
-      res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body);
+    body = AttrStmtNode::make(
+      res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
   }
   body = op::Substitute(body, value_map);
   return MergeNest(nest, body);
index c6102ed..ee958da 100644 (file)
@@ -113,7 +113,7 @@ Operation ExternOpNode::ReplaceInputs(
 void ExternOpNode::PropBoundToInputs(
     const Operation& self,
     arith::Analyzer* analyzer,
-    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   for (Tensor t : this->inputs) {
     auto it = out_dom_map->find(t);
@@ -147,7 +147,7 @@ Stmt ExternOpNode::BuildRealize(
           Range::make_by_min_extent(
               make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = ir::Realize::make(
+    realize_body = ir::RealizeNode::make(
         t->op, t->value_index, t->dtype,
         bounds, const_true(), realize_body);
   }
@@ -159,7 +159,7 @@ Stmt ExternOpNode::BuildProvide(
     const std::unordered_map<IterVar, Range>& dom_map,
     bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+  Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
   auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
     Array<ObjectRef> bind_spec;
     Array<Expr> tuple;
@@ -169,9 +169,9 @@ Stmt ExternOpNode::BuildProvide(
       tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
       tuple.push_back(buffer->shape[k]);
     }
-    ret = AttrStmt::make(
+    ret = AttrStmtNode::make(
         bind_spec, attr::buffer_bind_scope,
-        Call::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
+        CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
   };
   for (size_t i = output_placeholders.size(); i != 0; --i) {
     f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
index b4f29f5..5364c38 100644 (file)
@@ -92,7 +92,7 @@ Array<Tensor> HybridOpNode::InputTensors() const {
   std::unordered_set<Tensor> visited;
   Array<Tensor> curr_inputs;
   ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) {
-      const ir::Call *call = n.as<ir::Call>();
+      const ir::CallNode *call = n.as<ir::CallNode>();
       if (call != nullptr && call->func.defined()) {
         Tensor t = Downcast<Operation>(call->func).output(call->value_index);
         if (orig_inputs.count(t) && !visited.count(t)) {
@@ -128,7 +128,7 @@ Operation HybridOpNode::ReplaceInputs(
 void HybridOpNode::PropBoundToInputs(
     const Operation &self,
     arith::Analyzer* analyzer,
-    const std::unordered_map<const Variable*, IntSet> &dom_map,
+    const std::unordered_map<const VarNode*, IntSet> &dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   auto curr_inputs = InputTensors();
   for (Tensor t : curr_inputs) {
@@ -168,7 +168,7 @@ Stmt HybridOpNode::BuildRealize(
           Range::make_by_min_extent(
               make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = ir::Realize::make(
+    realize_body = ir::RealizeNode::make(
         t->op, t->value_index, t->dtype,
         bounds, const_true(), realize_body);
   }
@@ -180,7 +180,7 @@ Stmt HybridOpNode::BuildProvide(
     const std::unordered_map<IterVar, Range> &dom_map,
     bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret = AttrStmt::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
+  Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
   std::unordered_map<Tensor, Tensor> rmap;
   for (int i = 0; i < this->num_outputs(); ++i) {
     rmap[outputs[i]] = stage->op.output(i);
@@ -223,7 +223,7 @@ Stmt ApplyLoopShapes(const Stage &stage,
                  const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
   class LoopSpliter : public StmtExprMutator {
     Expr factor;
-    const Variable *parent;
+    const VarNode *parent;
     IterVar inner, outer;
 
    public:
@@ -247,16 +247,16 @@ Stmt ApplyLoopShapes(const Stage &stage,
       outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
     }
 
-    Stmt VisitStmt_(const For *op) final {
+    Stmt VisitStmt_(const ForNode *op) final {
       if (op->loop_var.get() == parent) {
-        std::unordered_map<const Variable *, Expr> rmap;
+        std::unordered_map<const VarNode *, Expr> rmap;
         rmap[op->loop_var.get()] = inner + outer * factor;
         Stmt ret = ir::Substitute(op->body, rmap);
         Expr cond = likely(outer * factor < (op->extent - inner));
-        ret = IfThenElse::make(cond, ret);
-        ret = For::make(inner->var, Expr(0), inner->dom->extent,
+        ret = IfThenElseNode::make(cond, ret);
+        ret = ForNode::make(inner->var, Expr(0), inner->dom->extent,
                         IterVarTypeToForType(inner->iter_type), op->device_api, ret);
-        ret = For::make(outer->var, Expr(0), outer->dom->extent,
+        ret = ForNode::make(outer->var, Expr(0), outer->dom->extent,
                         IterVarTypeToForType(outer->iter_type), op->device_api, ret);
         splitted = true;
         return ret;
@@ -267,8 +267,8 @@ Stmt ApplyLoopShapes(const Stage &stage,
 
   class LoopFuser : public StmtExprMutator {
     const IterVar &parent;
-    const Variable *inner;
-    const Variable *outer;
+    const VarNode *inner;
+    const VarNode *outer;
     bool under_outer;
     Expr extent;
 
@@ -280,10 +280,10 @@ Stmt ApplyLoopShapes(const Stage &stage,
         extent(0), fused(false) {}
 
     // TODO(@were): Handle imperfect loops
-    Stmt VisitStmt_(const For* op) final {
+    Stmt VisitStmt_(const ForNode* op) final {
       if (op->loop_var.get() == inner) {
         CHECK(under_outer);
-        std::unordered_map<const Variable *, Expr> rmap;
+        std::unordered_map<const VarNode *, Expr> rmap;
         rmap[op->loop_var.get()] = indexmod(parent, op->extent);
         extent = op->extent;
         fused = true;
@@ -291,15 +291,15 @@ Stmt ApplyLoopShapes(const Stage &stage,
       } else if (op->loop_var.get() == outer) {
         under_outer = true;
         Stmt body = this->VisitStmt(op->body);
-        std::unordered_map<const Variable *, Expr> rmap;
+        std::unordered_map<const VarNode *, Expr> rmap;
         rmap[op->loop_var.get()] = indexdiv(parent, extent);
         body = ir::Substitute(body, rmap);
         under_outer = false;
-        return For::make(parent->var, Expr(0), extent * op->extent,
+        return ForNode::make(parent->var, Expr(0), extent * op->extent,
                          op->for_type, op->device_api, body);
       } else if (under_outer) {
         Stmt body = this->VisitStmt(op->body);
-        std::unordered_map<const Variable *, Expr> rmap;
+        std::unordered_map<const VarNode *, Expr> rmap;
         rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent);
         body = ir::Substitute(body, rmap);
         extent = extent * op->extent;
@@ -327,13 +327,13 @@ Stmt ApplyLoopShapes(const Stage &stage,
 Stmt ApplyLoopAnnotations(const Stage &stage,
                           const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
   class LoopAnnotator : public StmtMutator {
-    const Variable *var;
+    const VarNode *var;
     const IterVarAttr &attr;
 
    public:
-    LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
+    LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
 
-    Stmt VisitStmt_(const For *op) final {
+    Stmt VisitStmt_(const ForNode *op) final {
       if (op->loop_var.get() == var) {
         if (attr->bind_thread.defined()) {
           const auto &iter_var = attr->bind_thread;
@@ -342,12 +342,12 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
             CHECK(Equal(iter_var->dom->extent, op->extent))
               << "Thread extent and loop extent mismatch!\n";
           }
-          std::unordered_map<const Variable *, Expr> rmap;
+          std::unordered_map<const VarNode *, Expr> rmap;
           rmap[op->loop_var.get()] = iter_var;
           Stmt body = ir::Substitute(op->body, rmap);
-          return AttrStmt::make(iter_var, "thread_extent", op->extent, body);
+          return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
         } else {
-          return For::make(op->loop_var, op->min, op->extent,
+          return ForNode::make(op->loop_var, op->min, op->extent,
                            IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
         }
       }
@@ -360,7 +360,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
     int found = 0;
 
     const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
-    const Variable *var = actual->var.get();
+    const VarNode *var = actual->var.get();
     ForType expected = IterVarTypeToForType(iter_var->iter_type);
     IterVarAttr attr;
     if (stage->iter_var_attrs.count(iter_var)) {
@@ -370,7 +370,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
 
     PostOrderVisit(stmt,
     [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) {
-      if (const For *op = node.as<For>()) {
+      if (const ForNode *op = node.as<ForNode>()) {
         if (op->loop_var.get() == var) {
           ++found;
           need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
@@ -389,15 +389,15 @@ Stmt ApplyLoopAnnotations(const Stage &stage,
 Stmt ApplyLoopOrder(const Stage &stage,
                     const std::unordered_map<IterVar, Range> &dom_map,
                     const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
-  std::vector<const Variable*> current_order;
+  std::vector<const VarNode*> current_order;
   PostOrderVisit(stmt, [&current_order](const ObjectRef& node) {
-    if (const For *op = node.as<For>())
+    if (const ForNode *op = node.as<ForNode>())
       current_order.push_back(op->loop_var.get());
   });
   std::reverse(current_order.begin(), current_order.end());
   auto &required_ord = stage->leaf_iter_vars;
   CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
-  std::unordered_map<const Variable *, IterVar> reorder;
+  std::unordered_map<const VarNode *, IterVar> reorder;
   bool need_reorder = false;
   for (size_t i = 0; i < current_order.size(); ++i) {
     auto &current = current_order[i];
@@ -413,15 +413,15 @@ Stmt ApplyLoopOrder(const Stage &stage,
   class LoopReorder : public StmtMutator {
     const Stage &stage;
     const std::unordered_map<IterVar, Range> &dom_map;
-    const std::unordered_map<const Variable *, IterVar> &reorder;
+    const std::unordered_map<const VarNode *, IterVar> &reorder;
 
    public:
     LoopReorder(const Stage &stage,
                 const std::unordered_map<IterVar, Range> &dom_map,
-                const std::unordered_map<const Variable*, IterVar> &reorder)
+                const std::unordered_map<const VarNode*, IterVar> &reorder)
       : stage(stage), dom_map(dom_map), reorder(reorder) {}
 
-    Stmt VisitStmt_(const For* op) final {
+    Stmt VisitStmt_(const ForNode* op) final {
       // Reorder from in to out
       Stmt body_ = this->VisitStmt(op->body);
       CHECK(reorder.count(op->loop_var.get()));
@@ -434,7 +434,7 @@ Stmt ApplyLoopOrder(const Stage &stage,
         for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
       }
       const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
-      return For::make(target->var, range->min, range->extent,
+      return ForNode::make(target->var, range->min, range->extent,
                        for_type, DeviceAPI::None, body);
     }
   };
@@ -467,7 +467,7 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
   // TODO(@were): Write a comprehensive pass to analyze iter var types
   std::vector<IterVar> res_;
   PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
-    if (const For *op = node.as<For>()) {
+    if (const ForNode *op = node.as<ForNode>()) {
       Var loop_var(op->loop_var);
       Range dom = Range::make_by_min_extent(op->min, op->extent);
       res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
@@ -483,11 +483,11 @@ class ProviderReplacer : public ir::StmtMutator {
   explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
       : vmap_(vmap) {}
 
-  Stmt VisitStmt_(const ir::Provide* op) final {
+  Stmt VisitStmt_(const ir::ProvideNode* op) final {
     Tensor t = Downcast<Operation>(op->func).output(op->value_index);
     auto it = vmap_.find(t);
     if (it != vmap_.end()) {
-      Stmt ret = ir::Provide::make(
+      Stmt ret = ir::ProvideNode::make(
         it->second->op, it->second->value_index, op->value, op->args);
       found = true;
       return this->VisitStmt(ret);
index 4a6d0d2..31d736d 100644 (file)
@@ -45,7 +45,7 @@ MakeLoopNest(const Stage& stage,
              std::unordered_map<IterVar, Expr>* p_value_map,
              bool debug_keep_trivial_loop) {
   auto leaf_iter_vars = stage->leaf_iter_vars;
-  Stmt no_op = Evaluate::make(0);
+  Stmt no_op = EvaluateNode::make(0);
   // create the loop nest
   std::vector<std::vector<Stmt> > nest;
   nest.resize(leaf_iter_vars.size() + 1);
@@ -95,33 +95,33 @@ MakeLoopNest(const Stage& stage,
         }
         CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
         for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
-          const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
+          const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
           Expr pvalue = it_attr->pragma_values[k];
           if (!pvalue.defined()) {
             pvalue = make_const(DataType::Int(32), 1);
           }
           nest[i + 1].emplace_back(
-              AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
+              AttrStmtNode::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
         }
       }
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         nest[i + 1].emplace_back(
-            LetStmt::make(var, dom->min, no_op));
+            LetStmtNode::make(var, dom->min, no_op));
         value_map[iv] = dom->min;
       } else if (is_zero(dom->min)) {
         nest[i + 1].emplace_back(
-            For::make(var, 0, dom->extent,
+            ForNode::make(var, 0, dom->extent,
                       for_type, DeviceAPI::None, no_op));
         value_map[iv] = var;
       } else {
         Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
         nest[i + 1].emplace_back(
-            For::make(idx, 0, dom->extent,
+            ForNode::make(idx, 0, dom->extent,
                       for_type, DeviceAPI::None, no_op));
         Expr new_value = dom->min + idx;
         value_map[iv] = new_value;
         nest[i + 1].emplace_back(
-            LetStmt::make(var, new_value, no_op));
+            LetStmtNode::make(var, new_value, no_op));
       }
       if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
         CHECK(!is_one(dom->extent))
@@ -130,7 +130,7 @@ MakeLoopNest(const Stage& stage,
                  it_attr->prefetch_offset.size());
         for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
           nest[i + 1].emplace_back(
-              AttrStmt::make(it_attr->prefetch_data[j],
+              AttrStmtNode::make(it_attr->prefetch_data[j],
                              ir::attr::prefetch_scope,
                              it_attr->prefetch_offset[j], no_op));
         }
@@ -143,7 +143,7 @@ MakeLoopNest(const Stage& stage,
       CHECK(is_positive_const(dom->extent));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
+          AttrStmtNode::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
       value_map[iv] = var;
     } else if (bind_iv->thread_tag == "pipeline") {
       // pipeline marker.
@@ -151,14 +151,14 @@ MakeLoopNest(const Stage& stage,
       CHECK(is_one(dom->extent));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
+          AttrStmtNode::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
       value_map[iv] = dom->min;
     } else {
       // Always restrict threaded IterVar to starts from 0.
       CHECK(is_zero(dom->min));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
+          AttrStmtNode::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else {
@@ -168,7 +168,7 @@ MakeLoopNest(const Stage& stage,
     // annotate the extent of the IterVar
     if (!new_loop_var) {
       nest[i + 1].emplace_back(
-          AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
+          AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
     }
   }
   // message passing to get offset of root iter vars.
@@ -177,10 +177,10 @@ MakeLoopNest(const Stage& stage,
 }
 
 std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
-  Stmt no_op = Evaluate::make(0);
+  Stmt no_op = EvaluateNode::make(0);
   std::vector<Stmt> nest;
   for (const Expr& cond : predicates) {
-    nest.emplace_back(IfThenElse::make(cond, no_op));
+    nest.emplace_back(IfThenElseNode::make(cond, no_op));
   }
   return nest;
 }
@@ -191,12 +191,12 @@ class TensorReplacer : public ir::StmtExprMutator {
   explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
       : vmap_(vmap) {}
 
-  Expr VisitExpr_(const ir::Call* op) final {
-    if (op->call_type == ir::Call::Halide) {
+  Expr VisitExpr_(const ir::CallNode* op) final {
+    if (op->call_type == ir::CallNode::Halide) {
       Tensor t = Downcast<Operation>(op->func).output(op->value_index);
       auto it = vmap_.find(t);
       if (it != vmap_.end()) {
-        Expr ret = ir::Call::make(
+        Expr ret = ir::CallNode::make(
             op->dtype, it->second->op->name, op->args,
             op->call_type, it->second->op, it->second->value_index);
         found = true;
@@ -229,7 +229,7 @@ Expr ReplaceTensor(Expr expr,
 
 Stmt Substitute(Stmt s,
                 const std::unordered_map<IterVar, Expr>& value_map) {
-  std::unordered_map<const Variable*, Expr> init;
+  std::unordered_map<const VarNode*, Expr> init;
   for (const auto& kv : value_map) {
     init[kv.first->var.get()] = kv.second;
   }
index 6414d5c..2ec10ca 100644 (file)
@@ -79,7 +79,7 @@ Operation PlaceholderOpNode::ReplaceInputs(
 void PlaceholderOpNode::PropBoundToInputs(
     const Operation& self,
     arith::Analyzer* analyzer,
-    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
 }
 
index ef2d1ef..c4c0960 100644 (file)
@@ -177,7 +177,7 @@ Operation ScanOpNode::ReplaceInputs(
 void ScanOpNode::PropBoundToInputs(
     const Operation& self,
     arith::Analyzer* analyzer,
-    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   CHECK_EQ(self.operator->(), this);
   for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) {
@@ -241,7 +241,7 @@ void ScanOpNode::GatherBound(
       IterVar sp_ax = this->spatial_axis_[sp_idx];
       CHECK(!out_dom_map->count(sp_ax));
       CHECK(fix_pt.count(sp_ax));
-      if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
+      if (fix_pt[sp_ax].as<ir::IntImmNode>()->value) {
         // fix point, we can slice it.
         (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
       } else {
@@ -271,7 +271,7 @@ Stmt ScanOpNode::BuildRealize(
       IterVar sp_ax = this->spatial_axis_[sp_idx];
       bounds.push_back(dom_map.at(sp_ax));
     }
-    ret = ir::Realize::make(t->op, t->value_index, t->dtype,
+    ret = ir::RealizeNode::make(t->op, t->value_index, t->dtype,
                             bounds, const_true(), ret);
   }
   return ret;
@@ -282,12 +282,12 @@ Stmt ScanOpNode::BuildProvide(
     const std::unordered_map<IterVar, Range>& dom_map,
     bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt provide = AttrStmt::make(
+  Stmt provide = AttrStmtNode::make(
       stage->op, attr::scan_update_scope, this->scan_axis->var,
-      Evaluate::make(0));
-  Stmt init = AttrStmt::make(
+      EvaluateNode::make(0));
+  Stmt init = AttrStmtNode::make(
       stage->op, attr::scan_init_scope, 0,
-      Evaluate::make(0));
+      EvaluateNode::make(0));
   size_t begin_scan = 0;
   for (size_t  i = 0; i < stage->leaf_iter_vars.size(); ++i) {
     if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
index a6252df..e0656ea 100644 (file)
@@ -109,7 +109,7 @@ Operation TensorComputeOpNode::ReplaceInputs(
 void TensorComputeOpNode::PropBoundToInputs(
     const Operation& self,
     arith::Analyzer* analyzer,
-    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    const std::unordered_map<const VarNode*, IntSet>& dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   for (size_t i = 0; i < this->inputs.size(); ++i) {
     Tensor t = this->inputs[i];
@@ -135,7 +135,7 @@ Stmt TensorComputeOpNode::BuildProvide(
   CHECK_EQ(stage->op.operator->(), this);
 
   // Start bind data.
-  Stmt nop = Evaluate::make(0);
+  Stmt nop = EvaluateNode::make(0);
   std::vector<Stmt> input_bind_nest, output_bind_nest;
   Array<Tensor> inputs = this->InputTensors();
 
@@ -152,9 +152,11 @@ Stmt TensorComputeOpNode::BuildProvide(
       tuple.push_back(region[i]->min);
       tuple.push_back(region[i]->extent);
     }
-    input_bind_nest.emplace_back(AttrStmt::make(
+    input_bind_nest.emplace_back(AttrStmtNode::make(
         bind_spec, ir::attr::buffer_bind_scope,
-        Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
   }
 
   // output binding
@@ -176,13 +178,15 @@ Stmt TensorComputeOpNode::BuildProvide(
       }
     }
 
-    output_bind_nest.emplace_back(AttrStmt::make(
+    output_bind_nest.emplace_back(AttrStmtNode::make(
         bind_spec, ir::attr::buffer_bind_scope,
-        Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
   }
 
   // Check variable remap
-  std::unordered_map<const Variable*, Expr> vmap;
+  std::unordered_map<const VarNode*, Expr> vmap;
   ir::ArgBinder binder(&vmap);
 
   // Map the expressions passed in the call to the TensorIntrin, to the placeholder
index 0df8e88..601c444 100644 (file)
@@ -85,7 +85,7 @@ size_t InferTensorizeRegion(
   schedule::PassUpDomain(stage, dom_map, &up_state);
   // Get domains if inputs
   std::unordered_map<Tensor, TensorDom> in_dom;
-  std::unordered_map<const Variable*, IntSet> temp_dmap;
+  std::unordered_map<const VarNode*, IntSet> temp_dmap;
   arith::Analyzer analyzer;
   Array<Tensor> inputs = self->InputTensors();
   for (Tensor t : inputs) {
@@ -119,18 +119,18 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
                              const ComputeLoopNest& n,
                              size_t tloc) {
   // Veirfication step.
-  std::unordered_set<const Variable*> banned;
+  std::unordered_set<const VarNode*> banned;
   CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1);
   CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 ||
         n.init_nest.size() == 0);
   auto f_push_banned = [&banned](const Stmt& s) {
-    if (const For* op = s.as<For>()) {
+    if (const ForNode* op = s.as<ForNode>()) {
         banned.insert(op->loop_var.get());
-    } else if (const AttrStmt* op = s.as<AttrStmt>()) {
+    } else if (const AttrStmtNode* op = s.as<AttrStmtNode>()) {
       if (const IterVarNode* iv = op->node.as<IterVarNode>()) {
         banned.insert(iv->var.get());
       }
-    } else if (const LetStmt* op = s.as<LetStmt>()) {
+    } else if (const LetStmtNode* op = s.as<LetStmtNode>()) {
       banned.insert(op->var.get());
     }
   };
@@ -161,10 +161,10 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
 // Remap the tensor placeholder, index and inline things.
 class TensorIntrinMatcher final : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
-    if (op->call_type == Call::Halide) {
+    op = expr.as<CallNode>();
+    if (op->call_type == CallNode::Halide) {
       Tensor t = Downcast<Operation>(op->func).output(op->value_index);
       auto it = in_remap_.find(t);
       if (it != in_remap_.end()) {
@@ -174,7 +174,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
         for (size_t i = e.start; i < e.region.size(); ++i) {
           args.push_back(op->args[i] - e.region[i]->min);
         }
-        return Call::make(
+        return CallNode::make(
             op->dtype, e.tensor->op->name, args,
             op->call_type, e.tensor->op, e.tensor->value_index);
       }
@@ -182,7 +182,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
     return expr;
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = var_remap_.find(op);
     if (it != var_remap_.end()) {
       return it->second;
@@ -191,9 +191,9 @@ class TensorIntrinMatcher final : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Reduce* op) final {
+  Expr VisitExpr_(const ReduceNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Reduce>();
+    op = expr.as<ReduceNode>();
     Array<IterVar> axis;
     for (size_t i = 0; i < op->axis.size(); ++i) {
       auto it = axis_remap_.find(op->axis[i]);
@@ -201,7 +201,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
         axis.push_back(it->second);
       }
     }
-    return Reduce::make(
+    return ReduceNode::make(
         op->combiner, op->source, axis, op->condition, op->value_index);
   }
 
@@ -301,7 +301,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
   // input data remap
   std::unordered_map<Tensor, InputEntry> in_remap_;
   // variable remap.
-  std::unordered_map<const Variable*, Expr> var_remap_;
+  std::unordered_map<const VarNode*, Expr> var_remap_;
   // IterVar remap.
   std::unordered_map<IterVar, IterVar> axis_remap_;
 };
@@ -372,7 +372,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
   VerifyTensorizeLoopNest(self, stage, n, tloc);
   VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
   // Start bind data.
-  Stmt nop = Evaluate::make(0);
+  Stmt nop = EvaluateNode::make(0);
   std::vector<Stmt> input_bind_nest, output_bind_nest;
   Array<Tensor> inputs = self->InputTensors();
   CHECK_EQ(inputs.size(), intrin->inputs.size())
@@ -390,9 +390,11 @@ Stmt MakeTensorize(const ComputeOpNode* self,
       tuple.push_back(r->min);
       tuple.push_back(r->extent);
     }
-    input_bind_nest.emplace_back(AttrStmt::make(
+    input_bind_nest.emplace_back(AttrStmtNode::make(
         bind_spec, ir::attr::buffer_bind_scope,
-        Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
   }
   // output binding
   const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
@@ -410,12 +412,14 @@ Stmt MakeTensorize(const ComputeOpNode* self,
     Tensor tensor = stage->op.output(i - intrin->inputs.size());
     Buffer buffer = intrin->buffers[i];
     Array<ObjectRef> bind_spec{buffer, tensor};
-    output_bind_nest.emplace_back(AttrStmt::make(
+    output_bind_nest.emplace_back(AttrStmtNode::make(
         bind_spec, ir::attr::buffer_bind_scope,
-        Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+        CallNode::make(DataType::Handle(),
+                       ir::intrinsic::tvm_tuple,
+                       tuple, CallNode::Intrinsic), nop));
   }
   // Check variable remap
-  std::unordered_map<const Variable*, Expr> vmap;
+  std::unordered_map<const VarNode*, Expr> vmap;
   ir::ArgBinder binder(&vmap);
   CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size())
       << "Tensorization fail: reduction axis size do not match";
index a0ddcd9..340f3a8 100644 (file)
@@ -42,7 +42,7 @@ void BinderAddAssert(Expr cond,
   if (!is_one(scond)) {
     std::ostringstream os;
     os << "Argument " << arg_name << " has an unsatisfied constraint";
-    asserts->emplace_back(AssertStmt::make(scond, os.str(), Evaluate::make(0)));
+    asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
   }
 }
 
@@ -51,14 +51,14 @@ bool ArgBinder::Bind_(const Expr& arg,
                       const std::string& arg_name,
                       bool with_lets) {
   CHECK_EQ(arg.dtype(), value.dtype());
-  if (const Variable* v = arg.as<Variable>()) {
+  if (const VarNode* v = arg.as<VarNode>()) {
     auto it = def_map_->find(v);
     if (it == def_map_->end()) {
       Var v_arg = Downcast<Var>(arg);
       defs_.emplace_back(v_arg);
       if (with_lets) {
         (*def_map_)[v] = arg;
-        init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
+        init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0)));
       } else {
         (*def_map_)[v] = value;
       }
@@ -164,7 +164,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
                              const std::string& arg_name) {
   const DataType tvm_shape_type = DataType::ShapeIndex();
   const DataType tvm_ndim_type = DataType::Int(32);
-  const Stmt nop = Evaluate::make(0);
+  const Stmt nop = EvaluateNode::make(0);
   // dimension checks
   Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
   Expr a_ndim = make_const(tvm_ndim_type,
@@ -173,51 +173,51 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   ndim_err_msg << arg_name
                << ".ndim is expected to equal "
                << buffer->shape.size();
-  asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
   // type checks
   DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
   type_err_msg << arg_name << ".dtype is expected to be " << dtype;
   Expr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) ==
-               UIntImm::make(DataType::UInt(8), dtype.code()) &&
+               UIntImmNode::make(DataType::UInt(8), dtype.code()) &&
                TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) ==
-               UIntImm::make(DataType::UInt(8), dtype.bits()) &&
+               UIntImmNode::make(DataType::UInt(8), dtype.bits()) &&
                TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
-               UIntImm::make(DataType::UInt(16), dtype.lanes()));
-  asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str(), nop));
+               UIntImmNode::make(DataType::UInt(16), dtype.lanes()));
+  asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
             arg_name + ".data", true)) {
     Var vptr(buffer->data);
     def_handle_dtype_.Set(vptr, ir::TypeAnnotation(buffer->dtype));
     // mark alignment of external bufs
-    init_nest_.emplace_back(AttrStmt::make(
+    init_nest_.emplace_back(AttrStmtNode::make(
         vptr, ir::attr::storage_alignment,
-        IntImm::make(DataType::Int(32), buffer->data_alignment), nop));
+        IntImmNode::make(DataType::Int(32), buffer->data_alignment), nop));
   }
 
   Var v_shape(arg_name + ".shape", DataType::Handle());
   def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
-  init_nest_.emplace_back(LetStmt::make(
+  init_nest_.emplace_back(LetStmtNode::make(
       v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
   for (size_t k = 0; k < buffer->shape.size(); ++k) {
     std::ostringstream field_name;
     field_name << v_shape->name_hint << '[' << k << ']';
     Bind_(buffer->shape[k],
           cast(buffer->shape[k].dtype(),
-               Load::make(tvm_shape_type, v_shape,
-                          IntImm::make(DataType::Int(32), k), const_true(1))),
+               LoadNode::make(tvm_shape_type, v_shape,
+                          IntImmNode::make(DataType::Int(32), k), const_true(1))),
           field_name.str(), true);
   }
   // strides field
   Var v_strides(arg_name + ".strides", DataType::Handle());
   def_handle_dtype_.Set(v_strides, ir::TypeAnnotation(tvm_shape_type));
-  init_nest_.emplace_back(LetStmt::make(
+  init_nest_.emplace_back(LetStmtNode::make(
       v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides),
       nop));
-  Expr is_null = Call::make(
+  Expr is_null = CallNode::make(
     DataType::Bool(1), intrinsic::tvm_handle_is_null,
-    {v_strides}, Call::PureIntrinsic);
+    {v_strides}, CallNode::PureIntrinsic);
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
@@ -227,8 +227,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       size_t k = i - 1;
       Expr svalue = cast(
           stype,
-          Load::make(tvm_shape_type, v_strides,
-                     IntImm::make(DataType::Int(32), k), const_true(1)));
+          LoadNode::make(tvm_shape_type, v_strides,
+                     IntImmNode::make(DataType::Int(32), k), const_true(1)));
       conds.push_back(expect_stride == svalue);
       expect_stride = expect_stride * buffer->shape[k];
     }
@@ -237,10 +237,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
                    << " expected to be compact array";
     if (conds.size() != 0) {
       Stmt check =
-          AssertStmt::make(arith::ComputeReduce<ir::And>(conds, Expr()),
-                           stride_err_msg.str(), Evaluate::make(0));
-      check = IfThenElse::make(Not::make(is_null), check, Stmt());
-      asserts_.emplace_back(SeqStmt({check, Evaluate::make(0)}));
+          AssertStmtNode::make(arith::ComputeReduce<ir::AndNode>(conds, Expr()),
+                           stride_err_msg.str(), EvaluateNode::make(0));
+      check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
+      asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
     }
   } else if (buffer->buffer_type == kAutoBroadcast) {
     DataType stype = buffer->DefaultIndexType();
@@ -250,8 +250,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       std::ostringstream field_name;
       field_name << v_strides->name_hint << '[' << k << ']';
       Expr value = cast(buffer->shape[k].dtype(),
-                        Load::make(tvm_shape_type, v_strides,
-                                   IntImm::make(DataType::Int(32), k), const_true(1)));
+                        LoadNode::make(tvm_shape_type, v_strides,
+                                   IntImmNode::make(DataType::Int(32), k), const_true(1)));
       value = tvm::if_then_else(is_null, stride, value);
       value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
       Bind_(buffer->strides[k], value, field_name.str(), true);
@@ -260,15 +260,17 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   } else {
     std::ostringstream stride_null_err_msg;
     stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
-    asserts_.emplace_back(AssertStmt::make(Not::make(is_null), stride_null_err_msg.str(), nop));
+    asserts_.emplace_back(
+        AssertStmtNode::make(
+            NotNode::make(is_null), stride_null_err_msg.str(), nop));
 
     for (size_t k = 0; k < buffer->strides.size(); ++k) {
       std::ostringstream field_name;
       field_name << v_strides->name_hint << '[' << k << ']';
       Bind_(buffer->strides[k],
             cast(buffer->shape[k].dtype(),
-                 Load::make(tvm_shape_type, v_strides,
-                            IntImm::make(DataType::Int(32), k), const_true(1))),
+                 LoadNode::make(tvm_shape_type, v_strides,
+                            IntImmNode::make(DataType::Int(32), k), const_true(1))),
             field_name.str(), true);
     }
   }
index 71f0dbe..55d8c22 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -62,7 +62,7 @@ class ArgBinder {
    *   ArgBinder will update this def_map when adding new definitions.
    */
   explicit ArgBinder(
-      std::unordered_map<const Variable*, Expr>* def_map)
+      std::unordered_map<const VarNode*, Expr>* def_map)
       : def_map_(def_map) {
   }
   /*!
@@ -144,7 +144,7 @@ class ArgBinder {
              const std::string& arg_name,
              bool with_lets);
   /*! \brief The definition map, can be uses to substitute */
-  std::unordered_map<const Variable*, Expr>* def_map_;
+  std::unordered_map<const VarNode*, Expr>* def_map_;
   /*! \brief defs generated in the current binder */
   std::vector<Var> defs_;
   /*! \brief Initialize nest */
index d3898a2..84939fc 100644 (file)
@@ -36,25 +36,25 @@ class BoundCollector : public StmtVisitor {
  public:
   BoundCollector() {}
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == ir::attr::buffer_bound) {
-      if (const Variable *key = op->node.as<Variable>()) {
+      if (const VarNode *key = op->node.as<VarNode>()) {
         mem_to_shape[key] = op->value;
       }
     }
     StmtVisitor::VisitStmt_(op);
   }
   // Hashtable which maps buffer_var to shape.
-  std::unordered_map<const Variable *, Expr> mem_to_shape;
+  std::unordered_map<const VarNode *, Expr> mem_to_shape;
 };
 
 class BoundChecker : public StmtExprMutator {
  public:
   explicit BoundChecker(
-      const std::unordered_map<const Variable *, Expr> &mem_to_shape)
+      const std::unordered_map<const VarNode *, Expr> &mem_to_shape)
       : mem_to_shape_(mem_to_shape) {}
 
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     // If the shape was updated we should update the hashtable.
     if (UpdateIsNeeded(op->buffer_var)) {
       Update(op->buffer_var, op->extents, op->dtype);
@@ -62,14 +62,14 @@ class BoundChecker : public StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) {
       unsafe_rewritten_ = true;
     }
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     store_scope_bound_collector_.clear();
     process_store_ = true;
     unsafe_rewritten_ = false;
@@ -81,20 +81,20 @@ class BoundChecker : public StmtExprMutator {
     // The collector should has at least one item.
     if (store_scope_bound_collector_.size()) {
       Expr condition = MakeCondition();
-      if (!condition.as<StringImm>()) {
-        Stmt nop = Evaluate::make(1);
+      if (!condition.as<StringImmNode>()) {
+        Stmt nop = EvaluateNode::make(1);
         Stmt then_case =
-            Store::make(op->buffer_var, op->value, op->index, op->predicate);
+            StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
         Stmt else_case =
-            AssertStmt::make(condition, StringImm::make(error_message_), nop);
-        Stmt body = IfThenElse::make(condition, then_case, else_case);
+            AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop);
+        Stmt body = IfThenElseNode::make(condition, then_case, else_case);
         return body;
       }
     }
     return GetRef<Stmt>(op);
   }
 
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     if (CanInstrument(op->index, op->buffer_var)) {
       Collect(op->index, op->buffer_var);
     }
@@ -122,12 +122,12 @@ class BoundChecker : public StmtExprMutator {
     }
 
     // Scalarize the shape.
-    Expr shape = Mul::make(make_const(DataType::UInt(64), type.lanes()),
-                           Cast::make(DataType::UInt(64), new_shape[0]));
+    Expr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()),
+                           CastNode::make(DataType::UInt(64), new_shape[0]));
     for (size_t i = 1; i < new_shape.size(); ++i) {
       // Cast to unsigned to avoid integer overlow at frist.
-      shape = Mul::make(shape, Mul::make(make_const(DataType::UInt(64), type.lanes()),
-                                         Cast::make(DataType::UInt(64), new_shape[i])));
+      shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()),
+                                         CastNode::make(DataType::UInt(64), new_shape[i])));
     }
     mem_to_shape_[buffer_var.get()] = shape;
   }
@@ -137,7 +137,7 @@ class BoundChecker : public StmtExprMutator {
       return false;
     }
 
-    if (const Ramp *ramp_index = index.as<Ramp>()) {
+    if (const RampNode *ramp_index = index.as<RampNode>()) {
       return ramp_index->base.defined() &&
              ramp_index->base.dtype().is_scalar() &&
              ramp_index->stride.defined() &&
@@ -163,12 +163,12 @@ class BoundChecker : public StmtExprMutator {
       Expr index = buffer_to_mem.first;
       Expr upper_bound = buffer_to_mem.second;
 
-      if (const Ramp *ramp_index = index.as<Ramp>()) {
+      if (const RampNode *ramp_index = index.as<RampNode>()) {
         // In case index is base + stride * i.
         // Non inclusive range.
-        index = Add::make(
+        index = AddNode::make(
             ramp_index->base,
-            Mul::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
+            MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(),
                                                      ramp_index->lanes - 1)));
       }
 
@@ -177,16 +177,16 @@ class BoundChecker : public StmtExprMutator {
       upper_bound = ir::Simplify(upper_bound);
 
       // Cast to the same type - signed, to be able to check lower bound.
-      index = Cast::make(DataType::Int(64), index);
-      upper_bound = Cast::make(DataType::Int(64), upper_bound);
+      index = CastNode::make(DataType::Int(64), index);
+      upper_bound = CastNode::make(DataType::Int(64), upper_bound);
 
       // Looks like a lower bound should always be zero after normalization.
       Expr lower_bound = make_zero(DataType::Int(64));
 
       Expr current_condition =
-          And::make(GE::make(index, lower_bound), LT::make(index, upper_bound));
+          AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound));
       condition =
-          !i ? current_condition : And::make(condition, current_condition);
+          !i ? current_condition : AndNode::make(condition, current_condition);
     }
     return condition;
   }
@@ -200,7 +200,7 @@ class BoundChecker : public StmtExprMutator {
   // Error message.
   const char *const error_message_ = "OUT OF THE BOUNDS";
   // Hashtable which maps buffer_var to shape.
-  std::unordered_map<const Variable *, Expr> mem_to_shape_;
+  std::unordered_map<const VarNode *, Expr> mem_to_shape_;
 };
 
 Stmt InstrumentBoundCheckers(Stmt stmt) {
index 5f35d2c..62ceede 100644 (file)
@@ -40,7 +40,7 @@ class ContextCallCombiner final : public StmtExprMutator {
     }
   };
 
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_thread_context)) {
       CHECK_EQ(op->args.size(), 1U);
       Expr ctx = op->args[0];
@@ -50,7 +50,7 @@ class ContextCallCombiner final : public StmtExprMutator {
       } else {
         CHECK(ctx.dtype().is_handle());
         std::string name;
-        if (const Call* call = ctx.as<Call>()) {
+        if (const CallNode* call = ctx.as<CallNode>()) {
           name = call->name + "_cache";
         } else {
           name = "ctx_cache_";
@@ -64,7 +64,7 @@ class ContextCallCombiner final : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::coproc_uop_scope) {
       // Map of comparison expression to variable
@@ -78,7 +78,7 @@ class ContextCallCombiner final : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Parallel) {
       // Map of comparison expression to variable
       std::map<Expr, Var, CompareExpr> temp;
@@ -99,7 +99,7 @@ class ContextCallCombiner final : public StmtExprMutator {
   static Stmt BuildContext(const std::map<Expr, Var, CompareExpr>& cmap,
                            Stmt body) {
     for (const auto& kv : cmap) {
-      body = LetStmt::make(kv.second, kv.first, body);
+      body = LetStmtNode::make(kv.second, kv.first, body);
     }
     return body;
   }
index 33af959..a7afd46 100644 (file)
@@ -34,7 +34,7 @@ namespace ir {
 // Visitor to find touched set by co-processor scope.
 class CoProcTouchedBuffer : public StmtExprVisitor {
  public:
-  void VisitExpr_(const Load* op) final {
+  void VisitExpr_(const LoadNode* op) final {
     if (in_scope_) {
       touched_[op->buffer_var.get()].coproc = true;
     } else {
@@ -42,7 +42,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor {
     }
     StmtExprVisitor::VisitExpr_(op);
   }
-  void VisitStmt_(const Store* op) final {
+  void VisitStmt_(const StoreNode* op) final {
     if (in_scope_) {
       touched_[op->buffer_var.get()].coproc = true;
     } else {
@@ -50,9 +50,9 @@ class CoProcTouchedBuffer : public StmtExprVisitor {
     }
     StmtExprVisitor::VisitStmt_(op);
   }
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      const Variable* buffer = op->args[1].as<Variable>();
+      const VarNode* buffer = op->args[1].as<VarNode>();
       if (in_scope_) {
         touched_[buffer].coproc = true;
       } else {
@@ -61,7 +61,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor {
     }
     StmtExprVisitor::VisitExpr_(op);
   }
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::coproc_scope && !in_scope_) {
       in_scope_ = true;
       IterVar iv = Downcast<IterVar>(op->node);
@@ -78,7 +78,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor {
     bool normal{false};
     bool coproc{false};
   };
-  std::unordered_map<const Variable*, TouchEntry> touched_;
+  std::unordered_map<const VarNode*, TouchEntry> touched_;
   std::unordered_set<IterVar> coproc_;
 
  private:
@@ -89,7 +89,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor {
 class CoProcSyncPlanner : public StorageAccessVisitor {
  public:
   explicit CoProcSyncPlanner(
-      const std::unordered_set<const Variable*>& touched,
+      const std::unordered_set<const VarNode*>& touched,
       const std::string& coproc_name)
       : touched_(touched), coproc_name_(coproc_name) {
   }
@@ -106,21 +106,21 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
   std::unordered_map<const Object*, std::vector<Stmt> > sync_;
 
  protected:
-  bool Enabled(const Variable* buf,
+  bool Enabled(const VarNode* buf,
                const StorageScope& scope) const final {
     return touched_.count(buf);
   }
 
   // Plan the sync
   std::vector<AccessEntry> Summarize(
-      std::vector<StmtEntry> seq, const For* loop) final {
+      std::vector<StmtEntry> seq, const ForNode* loop) final {
     return PlanSync(seq, loop, false);
   }
 
  private:
   // Plan write synchronization if write is not coherent
   std::vector<AccessEntry> PlanSync(
-      std::vector<StmtEntry> seq, const For* loop,
+      std::vector<StmtEntry> seq, const ForNode* loop,
       bool force_sync_at_end) {
     // detect write barriers
     // access by the co-processor.
@@ -196,13 +196,13 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
   }
 
   std::vector<Stmt> GetSync(std::string sync_name) {
-    return {Evaluate::make(Call::make(
+    return {EvaluateNode::make(CallNode::make(
         DataType::Int(32),
         sync_name,
-        {}, Call::Intrinsic))};
+        {}, CallNode::Intrinsic))};
   }
 
-  const std::unordered_set<const Variable*>& touched_;
+  const std::unordered_set<const VarNode*>& touched_;
   std::string coproc_name_;
 };
 
@@ -210,7 +210,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
 class CoProcBarrierDetector : public StorageAccessVisitor {
  public:
   explicit CoProcBarrierDetector(
-      const std::unordered_set<const Variable*>& touched,
+      const std::unordered_set<const VarNode*>& touched,
       const std::string& coproc_name)
       : touched_(touched) {
     read_barrier_name_ = coproc_name + ".coproc_read_barrier";
@@ -232,14 +232,14 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
   std::unordered_map<const Object*, std::vector<Stmt> > barrier_after_;
 
  protected:
-  bool Enabled(const Variable* buf,
+  bool Enabled(const VarNode* buf,
                const StorageScope& scope) const final {
     return touched_.count(buf);
   }
 
   // Plan the sync
   std::vector<AccessEntry> Summarize(
-      std::vector<StmtEntry> seq, const For* loop) final {
+      std::vector<StmtEntry> seq, const ForNode* loop) final {
     if (read_barrier_) {
       return PlanReadBarrier(seq, loop);
     } else {
@@ -250,9 +250,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
  private:
   // Plan write barrier at Read after write point.
   std::vector<AccessEntry> PlanWriteBarrier(
-      std::vector<StmtEntry> seq, const For* loop) {
+      std::vector<StmtEntry> seq, const ForNode* loop) {
     std::vector<AccessEntry> read_seq;
-    std::unordered_map<const Variable*, std::vector<AccessEntry> > write_set;
+    std::unordered_map<const VarNode*, std::vector<AccessEntry> > write_set;
 
     auto fupdate = [&](size_t i, const AccessEntry& acc) {
       auto it  = write_set.find(acc.buffer.get());
@@ -290,9 +290,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
   }
 
   std::vector<AccessEntry> PlanReadBarrier(
-      std::vector<StmtEntry> seq, const For* loop) {
+      std::vector<StmtEntry> seq, const ForNode* loop) {
     std::vector<AccessEntry> write_seq;
-    std::unordered_map<const Variable*, std::vector<AccessEntry> > read_set;
+    std::unordered_map<const VarNode*, std::vector<AccessEntry> > read_set;
 
     auto fupdate = [&](size_t i, const AccessEntry& acc) {
       auto it  = read_set.find(acc.buffer.get());
@@ -343,15 +343,15 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
         << "Cannot deduce write range of " << wvec[0].buffer;
     Expr min = r->min;
     Expr extent = r->extent;
-    return Evaluate::make(Call::make(
+    return EvaluateNode::make(CallNode::make(
         DataType::Int(32), func,
-        {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, Call::Intrinsic));
+        {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic));
   }
   // Write barrier name
   bool read_barrier_{false};
   std::string read_barrier_name_;
   std::string write_barrier_name_;
-  const std::unordered_set<const Variable*>& touched_;
+  const std::unordered_set<const VarNode*>& touched_;
 };
 
 
@@ -373,10 +373,10 @@ class CoProcInstDepDetector : public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::coproc_scope &&
         op->node.same_as(coproc_axis_)) {
-      const IntImm* ctx_id = op->value.as<IntImm>();
+      const IntImmNode* ctx_id = op->value.as<IntImmNode>();
       CHECK(ctx_id != nullptr);
       curr_state_.clear();
       curr_state_.node = op->body.get();
@@ -388,7 +388,7 @@ class CoProcInstDepDetector : public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     SyncState temp_first, temp_last;
     std::swap(first_state_, temp_first);
     std::swap(last_state_, temp_last);
@@ -411,7 +411,7 @@ class CoProcInstDepDetector : public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const IfThenElse* op) final {
+  void VisitStmt_(const IfThenElseNode* op) final {
     SyncState temp_first, temp_last, curr_state;
     std::swap(first_state_, temp_first);
     std::swap(last_state_, temp_last);
@@ -586,16 +586,16 @@ class CoProcInstDepDetector : public StmtVisitor {
   }
 
   Stmt MakePush(int from, int to) {
-    return Evaluate::make(Call::make(
+    return EvaluateNode::make(CallNode::make(
         DataType::Int(32), sync_push_name_,
         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
-        Call::Intrinsic));
+        CallNode::Intrinsic));
   }
   Stmt MakePop(int from, int to) {
-    return Evaluate::make(Call::make(
+    return EvaluateNode::make(CallNode::make(
         DataType::Int(32), sync_pop_name_,
         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
-        Call::Intrinsic));
+        CallNode::Intrinsic));
   }
   // sync states.
   SyncState first_state_, last_state_, curr_state_;
@@ -611,7 +611,7 @@ class CoProcSyncInserter : public StmtMutator {
     CoProcTouchedBuffer visitor;
     visitor(stmt);
     if (visitor.coproc_.size() == 0) return stmt;
-    std::unordered_set<const Variable*> touched;
+    std::unordered_set<const VarNode*> touched;
 
     for (const auto &kv : visitor.touched_) {
       if (kv.second.normal && kv.second.coproc) {
index 202f255..3578ce5 100644 (file)
@@ -27,7 +27,7 @@
 namespace tvm {
 namespace ir {
 Stmt DecorateDeviceScope(Stmt stmt) {
-  Stmt body = AttrStmt::make(make_zero(DataType::Int(32)),
+  Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)),
                              ir::attr::device_scope,
                              0,
                              stmt);
index 5748e9f..789877f 100644 (file)
@@ -125,12 +125,12 @@ class IfThenElseHoist {
 // in a For stmt.
 bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
   std::vector<const Object*> if_node_list;
-  const For* for_node = for_stmt.as<For>();
+  const ForNode* for_node = for_stmt.as<ForNode>();
   CHECK(for_node);
-  CHECK(if_stmt.as<IfThenElse>());
+  CHECK(if_stmt.as<IfThenElseNode>());
 
   PostOrderVisit(for_node->body, [&](const ObjectRef& node) {
-    if (node.as<IfThenElse>()) {
+    if (node.as<IfThenElseNode>()) {
       if_node_list.push_back(node.get());
     }
   });
@@ -142,12 +142,12 @@ bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
 // in the main VisitAndMutate function.
 Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
   const Object* top_for_node;
-  const For* parent_for_node = parent_for_stmt.as<For>();
+  const ForNode* parent_for_node = parent_for_stmt.as<ForNode>();
   CHECK(parent_for_node);
-  CHECK(new_if_stmt.as<IfThenElse>());
+  CHECK(new_if_stmt.as<IfThenElseNode>());
 
   PostOrderVisit(parent_for_node->body, [&](const ObjectRef& node) {
-    if (node.as<For>()) {
+    if (node.as<ForNode>()) {
       top_for_node = node.get();
     }
   });
@@ -169,13 +169,13 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
 std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
   Stmt then_for;
   Stmt else_for;
-  CHECK(if_stmt.as<IfThenElse>());
+  CHECK(if_stmt.as<IfThenElseNode>());
 
   PackedFunc replace_then_case = PackedFunc(
     [&](TVMArgs args, TVMRetValue *ret){
       const ObjectRef& node  = args[0];
       if (node == if_stmt) {
-        *ret = node.as<IfThenElse>()->then_case;
+        *ret = node.as<IfThenElseNode>()->then_case;
       }
     });
 
@@ -183,13 +183,13 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
     [&](TVMArgs args, TVMRetValue *ret){
       const ObjectRef& node  = args[0];
       if (node == if_stmt) {
-        *ret = node.as<IfThenElse>()->else_case;
+        *ret = node.as<IfThenElseNode>()->else_case;
       }
     });
 
   then_for = IRTransform(for_stmt, nullptr, replace_then_case,
                          {Expr("IfThenElse")});
-  if (if_stmt.as<IfThenElse>()->else_case) {
+  if (if_stmt.as<IfThenElseNode>()->else_case) {
     else_for = IRTransform(for_stmt, nullptr, replace_else_case,
                            {Expr("IfThenElse")});
   }
@@ -200,7 +200,7 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
 // Locate all For nodes and capture child IfThenElse nodes.
 void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
   PostOrderVisit(stmt, [&](const ObjectRef& node){
-    const For* for_node = node.as<For>();
+    const ForNode* for_node = node.as<ForNode>();
     if (!for_node) return;
 
     std::queue<Stmt> tracker;
@@ -210,16 +210,16 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
     while (!tracker.empty()) {
       Stmt head = tracker.front();
       tracker.pop();
-      if (head->IsInstance<For>()) {
+      if (head->IsInstance<ForNode>()) {
         for (const auto& if_stmt : for2if_map_.at(head.get())) {
           for2if_map_[for_stmt.get()].push_back(if_stmt);
         }
-      } else if (head->IsInstance<AttrStmt>()) {
-        const AttrStmt* attr_node = head.as<AttrStmt>();
+      } else if (head->IsInstance<AttrStmtNode>()) {
+        const AttrStmtNode* attr_node = head.as<AttrStmtNode>();
         tracker.push(attr_node->body);
-      } else if (head->IsInstance<IfThenElse>()) {
+      } else if (head->IsInstance<IfThenElseNode>()) {
         for2if_map_[for_stmt.get()].push_back(head);
-        const IfThenElse* if_node = head.as<IfThenElse>();
+        const IfThenElseNode* if_node = head.as<IfThenElseNode>();
         tracker.push(if_node->then_case);
         if (if_node->else_case) {
           tracker.push(if_node->else_case);
@@ -230,7 +230,7 @@ void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
           std::unordered_set<const Object*> new_var_set;
           cond_var_map_.insert({head.get(), new_var_set});
           PostOrderVisit(if_node->condition, [&](const ObjectRef& cond_node) {
-            if (cond_node.as<Variable>()) {
+            if (cond_node.as<VarNode>()) {
               cond_var_map_[head.get()].insert(cond_node.get());
             }
           });
@@ -252,7 +252,7 @@ void IfThenElseHoist::LocateTopFor() {
   // Create IfThenElse -> For map.
   for (const Stmt& for_stmt : ordered_for_list_) {
     std::vector<Stmt> if_list = for2if_map_[for_stmt.get()];
-    const For* for_node = for_stmt.as<For>();
+    const ForNode* for_node = for_stmt.as<ForNode>();
     CHECK(for_node);
     top_for_var_map_.insert({for_node->loop_var.get(), if_list});
     for (const Stmt& if_stmt : if_list) {
@@ -268,7 +268,7 @@ void IfThenElseHoist::LocateTopFor() {
     std::vector<Stmt> for_list = item.second;
     for (size_t i = 0; i < for_list.size(); ++i) {
       const Stmt& for_stmt = for_list.at(i);
-      const For* for_node = for_stmt.as<For>();
+      const ForNode* for_node = for_stmt.as<ForNode>();
       CHECK(for_node);
       std::vector<Stmt> new_for_list{for_stmt};
       for_tracking_map_.insert({for_stmt.get(), new_for_list});
@@ -282,13 +282,13 @@ void IfThenElseHoist::LocateTopFor() {
         top_for = for_stmt;
       }
     }
-    if (top_for.as<For>()) {
+    if (top_for.as<ForNode>()) {
       if_position_map.insert({if_stmt, top_for});
     }
   }
 
   for (const auto& item : if_position_map) {
-    top_for_var_set.insert(item.second.as<For>()->loop_var.get());
+    top_for_var_set.insert(item.second.as<ForNode>()->loop_var.get());
   }
 
   std::vector<const Object*> removed_for_var_list;
@@ -354,9 +354,9 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) {
       for_tracking_map_[for_stmt.get()].push_back(else_for);
     }
 
-    const IfThenElse* new_if_node = new_if.as<IfThenElse>();
+    const IfThenElseNode* new_if_node = new_if.as<IfThenElseNode>();
     CHECK(new_if_node);
-    new_if = IfThenElse::make(new_if_node->condition, then_for, else_for);
+    new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for);
     if (i < if2for_map_[if_stmt.get()].size() - 1) {
       const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
       const Stmt& actual_next_for =
@@ -375,7 +375,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
   PackedFunc replace_top_for = PackedFunc(
     [&](TVMArgs args, TVMRetValue *ret){
       const ObjectRef& current_for = args[0];
-      const For* for_node = current_for.as<For>();
+      const ForNode* for_node = current_for.as<ForNode>();
       if (!for_node) return;
 
       if (top_for_var_map_.count(for_node->loop_var.get())) {
@@ -385,27 +385,27 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
           new_if_list.emplace_back(HoistIf(if_stmt));
         }
 
-        const IfThenElse* next_if_node;
-        const IfThenElse* current_if_node =
-          new_if_list.back().as<IfThenElse>();
+        const IfThenElseNode* next_if_node;
+        const IfThenElseNode* current_if_node =
+          new_if_list.back().as<IfThenElseNode>();
         Stmt new_for = Stmt();
         for (size_t i = new_if_list.size() - 1; i > 0; --i) {
           CHECK(current_if_node);
           const Stmt current_if_stmt =
-            IfThenElse::make(current_if_node->condition,
+            IfThenElseNode::make(current_if_node->condition,
                              current_if_node->then_case,
                              current_if_node->else_case);
-          next_if_node = new_if_list[i - 1].as<IfThenElse>();
+          next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
           CHECK(next_if_node);
-          new_for = IfThenElse::make(next_if_node->condition, current_if_stmt,
+          new_for = IfThenElseNode::make(next_if_node->condition, current_if_stmt,
                                      next_if_node->else_case);
-          current_if_node = new_for.as<IfThenElse>();
+          current_if_node = new_for.as<IfThenElseNode>();
         }
 
         if (!new_for.get()) {
-          const IfThenElse* first_if_node = new_if_list[0].as<IfThenElse>();
+          const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
           CHECK(first_if_node);
-          new_for = IfThenElse::make(first_if_node->condition,
+          new_for = IfThenElseNode::make(first_if_node->condition,
                                      first_if_node->then_case,
                                      first_if_node->else_case);
         }
index 951e008..8f6c06d 100644 (file)
@@ -47,20 +47,20 @@ class FragmentGetter : public StmtExprVisitor {
       : m(_m), n(_n), k(_k), layout(_layout) {}
   };
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
 
     if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
         op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
       // Get shape and layout information from load and store intrinsic
       CHECK_EQ(op->args.size(), 8U);
-      const Variable* buffer_var = op->args[0].as<Variable>();
+      const VarNode* buffer_var = op->args[0].as<VarNode>();
       CHECK(buffer_var);
       // Get shape
-      const IntImm* m = op->args[1].as<IntImm>();
-      const IntImm* n = op->args[2].as<IntImm>();
-      const IntImm* k = op->args[3].as<IntImm>();
-      const StringImm* layout = op->args[7].as<StringImm>();
+      const IntImmNode* m = op->args[1].as<IntImmNode>();
+      const IntImmNode* n = op->args[2].as<IntImmNode>();
+      const IntImmNode* k = op->args[3].as<IntImmNode>();
+      const StringImmNode* layout = op->args[7].as<StringImmNode>();
       CHECK(m);
       CHECK(n);
       CHECK(k);
@@ -89,12 +89,12 @@ class FragmentGetter : public StmtExprVisitor {
     } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
       // Get shape information from fill intrinsic
       CHECK_EQ(op->args.size(), 6U);
-      const Variable* buffer_var = op->args[0].as<Variable>();
+      const VarNode* buffer_var = op->args[0].as<VarNode>();
       CHECK(buffer_var);
       // Get shape
-      const IntImm* m = op->args[1].as<IntImm>();
-      const IntImm* n = op->args[2].as<IntImm>();
-      const IntImm* k = op->args[3].as<IntImm>();
+      const IntImmNode* m = op->args[1].as<IntImmNode>();
+      const IntImmNode* n = op->args[2].as<IntImmNode>();
+      const IntImmNode* k = op->args[3].as<IntImmNode>();
       CHECK(m);
       CHECK(n);
       CHECK(k);
@@ -115,19 +115,19 @@ class FragmentGetter : public StmtExprVisitor {
   }
 
   // Get memory scope
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
-      const Variable* buffer = op->node.as<Variable>();
+      const VarNode* buffer = op->node.as<VarNode>();
       CHECK(buffer);
-      scopes[buffer] = op->value.as<StringImm>()->value;
+      scopes[buffer] = op->value.as<StringImmNode>()->value;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
 
   // Memory scope for allocations
-  std::unordered_map<const Variable*, std::string> scopes;
+  std::unordered_map<const VarNode*, std::string> scopes;
   // Fragment metadata for all fragments
-  std::unordered_map<const Variable*, FragmentInfo> fragments;
+  std::unordered_map<const VarNode*, FragmentInfo> fragments;
 };
 
 // Check shape of fragment making sure it is a valid shape for tvm_mma_sync
@@ -135,15 +135,15 @@ class FragmentChecker : public StmtExprVisitor {
  public:
   explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
     // Check shape when calling tvm_mma_sync
     if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
       CHECK_EQ(op->args.size(), 8U);
-      const Variable* buffer_var_d = op->args[0].as<Variable>();
-      const Variable* buffer_var_a = op->args[2].as<Variable>();
-      const Variable* buffer_var_b = op->args[4].as<Variable>();
-      const Variable* buffer_var_c = op->args[6].as<Variable>();
+      const VarNode* buffer_var_d = op->args[0].as<VarNode>();
+      const VarNode* buffer_var_a = op->args[2].as<VarNode>();
+      const VarNode* buffer_var_b = op->args[4].as<VarNode>();
+      const VarNode* buffer_var_c = op->args[6].as<VarNode>();
       CHECK(buffer_var_d);
       CHECK(buffer_var_a);
       CHECK(buffer_var_b);
@@ -158,7 +158,7 @@ class FragmentChecker : public StmtExprVisitor {
 
  private:
   // A tool for checking shapes of two fragments
-  bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
+  bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) {
     CHECK(fragment_getter.fragments.count(buffer1));
     CHECK(fragment_getter.fragments.count(buffer2));
     FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
@@ -174,9 +174,9 @@ class InferFragmenter : public StmtMutator {
  public:
   explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
 
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    const Variable* buffer = op->buffer_var.get();
+    const VarNode* buffer = op->buffer_var.get();
     if (fragment_getter.fragments.count(buffer)) {
       // Add attribute to fragments allocation
       FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
@@ -185,12 +185,12 @@ class InferFragmenter : public StmtMutator {
       std::string shape = std::to_string(info.m) + ", " +
                           std::to_string(info.n) + ", " +
                           std::to_string(info.k);
-      Expr shape_expr = StringImm::make(shape);
-      Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
+      Expr shape_expr = StringImmNode::make(shape);
+      Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
       if (info.layout != "") {
         // Add shape attribute to matrix_a and matrix_b
-        Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout,
-                                          StringImm::make(info.layout), shape_attr);
+        Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout,
+                                          StringImmNode::make(info.layout), shape_attr);
         return layout_attr;
       } else {
         return shape_attr;
index d1ba19b..0a19c69 100644 (file)
@@ -40,10 +40,10 @@ class CopyIntrinInjector : public StmtMutator {
         flower_copy_fromto_(flower_copy_fromto) {
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
-      storage_scope_[buf] = op->value.as<StringImm>()->value;
+      const VarNode* buf = op->node.as<VarNode>();
+      storage_scope_[buf] = op->value.as<StringImmNode>()->value;
     } else if (op->attr_key == pragma_key_) {
       Stmt ret;
       CHECK(MatchCopyPattern(op->body, &ret))
@@ -59,13 +59,13 @@ class CopyIntrinInjector : public StmtMutator {
     Stmt body = stmt;
 
     // strip the loops
-    std::vector<const For*> loops;
-    while (const For* op = body.as<For>()) {
+    std::vector<const ForNode*> loops;
+    while (const ForNode* op = body.as<ForNode>()) {
       if (!is_zero(op->min)) return false;
       loops.push_back(op);
       body = op->body;
     }
-    const Store* store = body.as<Store>();
+    const StoreNode* store = body.as<StoreNode>();
     if (store == nullptr) return false;
     // Expr sel_cond, sel_true_value, sel_false_value;
     // match select or if
@@ -74,23 +74,23 @@ class CopyIntrinInjector : public StmtMutator {
         if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
         select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
 
-    const Cast* cast = store->value.as<Cast>();
-    const Load* load = store->value.as<Load>();
+    const CastNode* cast = store->value.as<CastNode>();
+    const LoadNode* load = store->value.as<LoadNode>();
     if (0 == loops.size()) {
       CHECK(!has_cond);
     }
     // for now only support true condition matching
     if (has_cond) {
-      load = sel_true_value.Eval().as<Load>();
+      load = sel_true_value.Eval().as<LoadNode>();
     }
     // cast can be part of the pattern
     if (cast != nullptr) {
-      load = cast->value.as<Load>();
+      load = cast->value.as<LoadNode>();
     }
     if (load == nullptr) return false;
     if (load->dtype.lanes() != 1) return false;
     Array<Var> loop_vars;
-    for (const For* op : loops) {
+    for (const ForNode* op : loops) {
       loop_vars.push_back(op->loop_var);
     }
     Array<Expr> store_strides =
@@ -103,7 +103,7 @@ class CopyIntrinInjector : public StmtMutator {
     if (loop_var_size == 0) {
       dst_shape.push_back(make_const(DataType::Int(32), 1));
     } else {
-      for (const For* op : loops) {
+      for (const ForNode* op : loops) {
         dst_shape.push_back(op->extent);
       }
     }
@@ -124,7 +124,7 @@ class CopyIntrinInjector : public StmtMutator {
         DataType t = loop_vars[i].dtype();
         Expr svalue = src_shape[i];
         if (min_value.defined()) {
-          Expr pbefore = Simplify(Max::make(min_value, make_zero(t)));
+          Expr pbefore = Simplify(MaxNode::make(min_value, make_zero(t)));
           src_elem_offset = src_elem_offset + pbefore * load_strides[i];
           svalue = svalue - pbefore;
           pad_before.push_back(pbefore);
@@ -132,7 +132,7 @@ class CopyIntrinInjector : public StmtMutator {
           pad_before.push_back(make_zero(t));
         }
         if (max_value.defined()) {
-          Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1),
+          Expr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1),
                                            make_zero(t)));
           svalue = svalue - pafter;
           pad_after.push_back(pafter);
@@ -174,7 +174,7 @@ class CopyIntrinInjector : public StmtMutator {
     return true;
   }
   // Get storage scope
-  std::string GetStorageScope(const Variable* var) const {
+  std::string GetStorageScope(const VarNode* var) const {
     auto it = storage_scope_.find(var);
     if (it != storage_scope_.end()) {
       return it->second;
@@ -187,7 +187,7 @@ class CopyIntrinInjector : public StmtMutator {
   // function to lower copy intrinsics.
   const PackedFunc& flower_copy_fromto_;
   // Storage scope
-  std::unordered_map<const Variable*, std::string> storage_scope_;
+  std::unordered_map<const VarNode*, std::string> storage_scope_;
 };
 
 Stmt InjectCopyIntrin(Stmt stmt,
index 0158a94..4bd431e 100644 (file)
@@ -33,28 +33,28 @@ namespace ir {
 // Detect double buffer variables.
 class DoubleBufferDetector : public StmtExprVisitor {
  public:
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::double_buffer_scope) {
-      touched_.insert(op->node.as<Variable>());
+      touched_.insert(op->node.as<VarNode>());
       StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
     }
   }
 
-  void VisitExpr_(const Variable* op) final {
+  void VisitExpr_(const VarNode* op) final {
     if (touched_.count(op)) {
       touched_.erase(op);
     }
   }
   // The set of touched variable.
-  std::unordered_set<const Variable*> touched_;
+  std::unordered_set<const VarNode*> touched_;
 };
 
 
 class StripDoubleBufferWrite : public StmtMutator {
  public:
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::double_buffer_write) {
       return VisitStmt(op->body);
     } else {
@@ -72,18 +72,18 @@ class DoubleBufferInjector : public StmtExprMutator {
     DoubleBufferDetector detector;
     detector(stmt);
     if (detector.touched_.empty()) return stmt;
-    for (const Variable* v : detector.touched_) {
+    for (const VarNode* v : detector.touched_) {
       dbuffer_info_[v] = StorageEntry();
     }
     return ConvertSSA(operator()(std::move(stmt)));
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
+      const VarNode* buf = op->node.as<VarNode>();
       auto it = dbuffer_info_.find(buf);
       if (it != dbuffer_info_.end()) {
-        it->second.scope = op->value.as<StringImm>()->value;
+        it->second.scope = op->value.as<StringImmNode>()->value;
         return this->VisitStmt(op->body);
       } else {
         return StmtExprMutator::VisitStmt_(op);
@@ -95,38 +95,38 @@ class DoubleBufferInjector : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
-      it->second.stride = arith::ComputeReduce<Mul>(
+      it->second.stride = arith::ComputeReduce<MulNode>(
           op->extents, Expr()) * op->dtype.lanes();
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      op = stmt.as<Allocate>();
+      op = stmt.as<AllocateNode>();
       Array<Expr> new_extents{make_const(op->extents[0].dtype(), 2)};
       for (Expr e : op->extents) {
         new_extents.push_back(e);
       }
       CHECK(it->second.loop != nullptr);
       auto& alloc_nest = loop_allocs_[it->second.loop];
-      alloc_nest.emplace_back(AttrStmt::make(
+      alloc_nest.emplace_back(AttrStmtNode::make(
           op->buffer_var, attr::storage_scope,
-          StringImm::make(it->second.scope),
-          Evaluate::make(0)));
-      alloc_nest.emplace_back(Allocate::make(
+          StringImmNode::make(it->second.scope),
+          EvaluateNode::make(0)));
+      alloc_nest.emplace_back(AllocateNode::make(
           op->buffer_var, op->dtype, new_extents, op->condition,
-          Evaluate::make(0)));
+          EvaluateNode::make(0)));
       return op->body;
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     loop_nest_.push_back(op);
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     auto it = loop_pre_.find(op);
     if (it != loop_pre_.end()) {
-      const For* old_loop = stmt.as<For>();
+      const ForNode* old_loop = stmt.as<ForNode>();
       if (split_loop_ != 0) {
         // Explicitly unroll the loop
         CHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
@@ -139,13 +139,13 @@ class DoubleBufferInjector : public StmtExprMutator {
         Expr outer_ext = new_ext / factor;
         Expr tail_base = outer_ext * factor;
         Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype());
-        std::unordered_map<const Variable*, Expr> vmap;
+        std::unordered_map<const VarNode*, Expr> vmap;
         std::vector<Stmt> loop_seq;
         for (int32_t i = 0; i < split_loop_; ++i) {
           vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
           loop_seq.emplace_back(Substitute(old_loop->body, vmap));
         }
-        Stmt loop = For::make(
+        Stmt loop = ForNode::make(
             outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
             SeqStmt::Flatten(loop_seq));
         // tail
@@ -155,7 +155,7 @@ class DoubleBufferInjector : public StmtExprMutator {
           Expr idx = tail_base + make_const(tail_base.dtype(), i);
           vmap[old_loop->loop_var.get()] = idx;
           tail_seq.emplace_back(
-              IfThenElse::make(idx < old_loop->extent,
+              IfThenElseNode::make(idx < old_loop->extent,
                                Substitute(tail_body, vmap)));
         }
         stmt = SeqStmt::Flatten(loop, tail_seq);
@@ -170,15 +170,15 @@ class DoubleBufferInjector : public StmtExprMutator {
     return stmt;
   }
 
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Store>();
+    op = stmt.as<StoreNode>();
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
       const StorageEntry& e = it->second;
       CHECK(in_double_buffer_scope_);
       CHECK(e.stride.defined());
-      return Store::make(op->buffer_var,
+      return StoreNode::make(op->buffer_var,
                          op->value,
                          e.switch_write_var * e.stride + op->index,
                          op->predicate);
@@ -187,15 +187,15 @@ class DoubleBufferInjector : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Load>();
+    op = expr.as<LoadNode>();
     auto it = dbuffer_info_.find(op->buffer_var.get());
     if (it != dbuffer_info_.end()) {
       const StorageEntry& e = it->second;
       CHECK(e.stride.defined());
       CHECK(e.switch_read_var.defined());
-      return Load::make(op->dtype,
+      return LoadNode::make(op->dtype,
                         op->buffer_var,
                         e.switch_read_var * e.stride + op->index,
                         op->predicate);
@@ -204,13 +204,13 @@ class DoubleBufferInjector : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     CHECK(!dbuffer_info_.count(op));
     return GetRef<Expr>(op);
   }
 
  private:
-  Stmt MakeProducer(const AttrStmt* op) {
+  Stmt MakeProducer(const AttrStmtNode* op) {
     const VarExpr buffer = Downcast<VarExpr>(op->node);
     CHECK_NE(loop_nest_.size(), 0U)
         << "Double buffer scope must be inside a loop";
@@ -231,15 +231,15 @@ class DoubleBufferInjector : public StmtExprMutator {
     in_double_buffer_scope_ = true;
     Stmt body = this->VisitStmt(op->body);
     in_double_buffer_scope_ = false;
-    std::unordered_map<const Variable*, Expr> vmap;
+    std::unordered_map<const VarNode*, Expr> vmap;
     vmap[e.switch_write_var.get()] = zero;
     vmap[e.loop->loop_var.get()] = zero;
     loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
     vmap[e.loop->loop_var.get()] = loop_shift;
     vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
     body = Substitute(body, vmap);
-    body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body);
-    body = IfThenElse::make(loop_shift < e.loop->extent, body);
+    body = AttrStmtNode::make(buffer, attr::double_buffer_write, 1, body);
+    body = IfThenElseNode::make(loop_shift < e.loop->extent, body);
     return body;
   }
   // Storage entry for those who need double buffering.
@@ -247,7 +247,7 @@ class DoubleBufferInjector : public StmtExprMutator {
     // The size of the buffer
     Expr stride;
     // The loop we need
-    const For* loop{nullptr};
+    const ForNode* loop{nullptr};
     // The switch variable.
     VarExpr switch_write_var;
     // The switch variable for reading.
@@ -260,13 +260,13 @@ class DoubleBufferInjector : public StmtExprMutator {
   // Whether we are inside double buffer scope.
   bool in_double_buffer_scope_{false};
   // The current loop next
-  std::vector<const For*> loop_nest_;
+  std::vector<const ForNode*> loop_nest_;
   // The allocs to be appended before the loop
-  std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_;
+  std::unordered_map<const ForNode*, std::vector<Stmt> > loop_allocs_;
   // The stmt to be appended before the loop
-  std::unordered_map<const For*, std::vector<Stmt> > loop_pre_;
+  std::unordered_map<const ForNode*, std::vector<Stmt> > loop_pre_;
   // The allocation size of the buffer
-  std::unordered_map<const Variable*, StorageEntry> dbuffer_info_;
+  std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_;
 };
 
 
index 73725c2..c58a91d 100644 (file)
@@ -35,9 +35,9 @@ using arith::DomainTouched;
 
 class PrefetchInjector : public StmtMutator {
  public:
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt ret = StmtMutator::VisitStmt_(op);
-    op = ret.as<AttrStmt>();
+    op = ret.as<AttrStmtNode>();
     if (op && op->attr_key == attr::prefetch_scope) {
       Tensor ts = Downcast<Tensor>(op->node);
       CHECK_NE(loop_nest_.size(), 0U);
@@ -58,13 +58,13 @@ class PrefetchInjector : public StmtMutator {
 
       vectorized_.erase(iter_var);
 
-      Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region);
+      Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region);
       return SeqStmt({prefetch, op->body});
     }
     return ret;
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     auto &var = op->loop_var;
     loop_nest_.push_back(var);
     if (op->for_type == ForType::Vectorized) {
@@ -80,7 +80,7 @@ class PrefetchInjector : public StmtMutator {
 
  private:
   std::vector<VarExpr> loop_nest_;
-  std::unordered_map<const Variable *, IntSet> vectorized_;
+  std::unordered_map<const VarNode *, IntSet> vectorized_;
   static const Range none;
 };
 
index 0887a83..8eeee9d 100644 (file)
@@ -32,7 +32,7 @@ namespace ir {
 // If expression is touched by var.
 class ExprTouched final : public StmtExprVisitor {
  public:
-  explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
+  explicit ExprTouched(const std::unordered_set<const VarNode*> &touched,
                        bool check_write)
       : touched_var_(touched), check_write_(check_write) {}
 
@@ -46,18 +46,18 @@ class ExprTouched final : public StmtExprVisitor {
     if (expr_touched_ && !check_write_) return;
     StmtExprVisitor::VisitStmt(n);
   }
-  void VisitExpr_(const Load *op) final {
+  void VisitExpr_(const LoadNode *op) final {
     HandleUseVar(op->buffer_var.get());
     StmtExprVisitor::VisitExpr_(op);
   }
-  void VisitExpr_(const Variable *op) final {
+  void VisitExpr_(const VarNode *op) final {
     HandleUseVar(op);
   }
-  void VisitExpr_(const Call *op) final {
+  void VisitExpr_(const CallNode *op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       int rw_mask = 0;
       CHECK(arith::GetConstInt(op->args[4], &rw_mask));
-      const Variable* buffer_var = op->args[1].as<Variable>();
+      const VarNode* buffer_var = op->args[1].as<VarNode>();
       CHECK(buffer_var);
       // read
       if (rw_mask & 1) {
@@ -71,7 +71,7 @@ class ExprTouched final : public StmtExprVisitor {
       StmtExprVisitor::VisitExpr_(op);
     }
   }
-  void HandleUseVar(const Variable* var) {
+  void HandleUseVar(const VarNode* var) {
     auto it = touched_var_.find(var);
     if (it != touched_var_.end()) {
       expr_touched_ = true;
@@ -82,33 +82,33 @@ class ExprTouched final : public StmtExprVisitor {
       used_vars_.push_back(var);
     }
   }
-  void HandleWriteVar(const Variable* var) {
+  void HandleWriteVar(const VarNode* var) {
     write_vars_.push_back(var);
   }
   // the fields.
   bool expr_touched_{false};
-  std::vector<const Variable*> used_vars_;
-  std::vector<const Variable*> write_vars_;
-  const std::unordered_set<const Variable*>& touched_var_;
+  std::vector<const VarNode*> used_vars_;
+  std::vector<const VarNode*> write_vars_;
+  const std::unordered_set<const VarNode*>& touched_var_;
   bool check_write_;
 };
 
 // Analyze if the buffers are invariant to value of var
 class VarTouchedAnalysis : public StmtVisitor {
  public:
-  void VisitStmt_(const LetStmt* op) final {
+  void VisitStmt_(const LetStmtNode* op) final {
     ExprTouched tc(touched_var_, false);
     tc(op->value);
     Record(op->var.get(), tc);
     this->VisitStmt(op->body);
   }
-  void VisitStmt_(const Store* op) final {
+  void VisitStmt_(const StoreNode* op) final {
     ExprTouched tc(touched_var_, false);
     tc(op->value);
     tc(op->index);
     Record(op->buffer_var.get(), tc);
   }
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     ExprTouched tc(touched_var_, false);
     tc(op->min);
     tc(op->extent);
@@ -116,14 +116,14 @@ class VarTouchedAnalysis : public StmtVisitor {
     this->VisitStmt(op->body);
   }
   // external function call
-  void VisitStmt_(const Evaluate* op) final {
+  void VisitStmt_(const EvaluateNode* op) final {
     ExprTouched tc(touched_var_, true);
     tc(op->value);
-    for (const Variable* var : tc.write_vars_) {
+    for (const VarNode* var : tc.write_vars_) {
       Record(var, tc);
     }
   }
-  void VisitStmt_(const Allocate* op) final {
+  void VisitStmt_(const AllocateNode* op) final {
     ExprTouched tc(touched_var_, false);
     for (size_t i = 0; i < op->extents.size(); ++i) {
       tc(op->extents[i]);
@@ -135,13 +135,13 @@ class VarTouchedAnalysis : public StmtVisitor {
     Record(op->buffer_var.get(), tc);
     this->VisitStmt(op->body);
   }
-  void Record(const Variable* var,
+  void Record(const VarNode* var,
               const ExprTouched& tc) {
     if (touched_var_.count(var)) return;
     if (tc.expr_touched_) {
       touched_var_.insert(var);
     } else {
-      for (const Variable* r : tc.used_vars_) {
+      for (const VarNode* r : tc.used_vars_) {
         if (r != var) {
           affect_[r].push_back(var);
         }
@@ -149,18 +149,18 @@ class VarTouchedAnalysis : public StmtVisitor {
     }
   }
 
-  std::unordered_set<const Variable*>
+  std::unordered_set<const VarNode*>
   TouchedVar(const Stmt& stmt,
-             const Variable* var) {
+             const VarNode* var) {
     touched_var_.insert(var);
     this->VisitStmt(stmt);
     // do a DFS to push affect around dependency.
-    std::vector<const Variable*> pending(
+    std::vector<const VarNode*> pending(
         touched_var_.begin(), touched_var_.end());
     while (!pending.empty()) {
-      const Variable* v = pending.back();
+      const VarNode* v = pending.back();
       pending.pop_back();
-      for (const Variable* r : affect_[v]) {
+      for (const VarNode* r : affect_[v]) {
         if (!touched_var_.count(r)) {
           touched_var_.insert(r);
           pending.push_back(r);
@@ -172,10 +172,10 @@ class VarTouchedAnalysis : public StmtVisitor {
 
  private:
   // Whether variable is touched by the thread variable.
-  std::unordered_set<const Variable*> touched_var_;
+  std::unordered_set<const VarNode*> touched_var_;
   // x -> all the buffers x read from
-  std::unordered_map<const Variable*,
-                     std::vector<const Variable*> > affect_;
+  std::unordered_map<const VarNode*,
+                     std::vector<const VarNode*> > affect_;
 };
 
 
@@ -186,7 +186,7 @@ class VTInjector : public StmtExprMutator {
   // constructor
   VTInjector(Var var,
              int num_threads,
-             const std::unordered_set<const Variable*>& touched_var,
+             const std::unordered_set<const VarNode*>& touched_var,
              bool allow_share)
       : var_(var), num_threads_(num_threads),
         touched_var_(touched_var), allow_share_(allow_share) {
@@ -205,7 +205,7 @@ class VTInjector : public StmtExprMutator {
     return stmt;
   }
   // Variable
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     CHECK(!alloc_remap_.count(op))
         << "Buffer address may get rewritten in virtual thread";
     if (touched_var_.count(op)) {
@@ -217,15 +217,15 @@ class VTInjector : public StmtExprMutator {
     return index + var_ * alloc_extent;
   }
   // Load
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Load>();
+    op = expr.as<LoadNode>();
     if (touched_var_.count(op->buffer_var.get())) {
       visit_touched_var_ = true;
     }
     auto it = alloc_remap_.find(op->buffer_var.get());
     if (it != alloc_remap_.end()) {
-      return Load::make(op->dtype, op->buffer_var,
+      return LoadNode::make(op->dtype, op->buffer_var,
                         RewriteIndex(op->index, it->second),
                         op->predicate);
     } else {
@@ -233,11 +233,11 @@ class VTInjector : public StmtExprMutator {
     }
   }
   // Expression.
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       CHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
-      const Variable* buffer = op->args[1].as<Variable>();
+      const VarNode* buffer = op->args[1].as<VarNode>();
       auto it = alloc_remap_.find(buffer);
       if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
       visit_touched_var_ = true;
@@ -246,7 +246,7 @@ class VTInjector : public StmtExprMutator {
       Expr stride =
           it->second / make_const(offset.dtype(), dtype.lanes());
       offset = stride * var_ + offset;
-      return Call::make(
+      return CallNode::make(
           op->dtype, op->name,
           {op->args[0], op->args[1], offset, extent, op->args[4]},
           op->call_type);
@@ -256,21 +256,21 @@ class VTInjector : public StmtExprMutator {
       return StmtExprMutator::VisitExpr_(op);
     }
   }
-  Stmt VisitStmt_(const Evaluate* op) final {
+  Stmt VisitStmt_(const EvaluateNode* op) final {
     trigger_base_inject_ = !allow_share_;
     return StmtExprMutator::VisitStmt_(op);
   }
   // Store
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Store>();
+    op = stmt.as<StoreNode>();
     if (touched_var_.count(op->buffer_var.get())) {
       visit_touched_var_ = true;
     }
     trigger_base_inject_ = !allow_share_;
     auto it = alloc_remap_.find(op->buffer_var.get());
     if (it != alloc_remap_.end()) {
-      return Store::make(op->buffer_var,
+      return StoreNode::make(op->buffer_var,
                          op->value,
                          RewriteIndex(op->index, it->second),
                          op->predicate);
@@ -279,7 +279,7 @@ class VTInjector : public StmtExprMutator {
     }
   }
   // Attribute
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     Expr value = this->VisitExpr(op->value);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
@@ -293,12 +293,12 @@ class VTInjector : public StmtExprMutator {
           body.same_as(op->body)) {
         return GetRef<Stmt>(op);
       } else {
-        return AttrStmt::make(op->node, op->attr_key, value, body);
+        return AttrStmtNode::make(op->node, op->attr_key, value, body);
       }
     }
   }
   // LetStmt
-  Stmt VisitStmt_(const LetStmt* op) final {
+  Stmt VisitStmt_(const LetStmtNode* op) final {
     Expr value = this->VisitExpr(op->value);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
@@ -309,11 +309,11 @@ class VTInjector : public StmtExprMutator {
         body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return LetStmt::make(op->var, value, body);
+      return LetStmtNode::make(op->var, value, body);
     }
   }
   // For
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     CHECK(is_zero(op->min));
     Expr extent = this->VisitExpr(op->extent);
     if (visit_touched_var_ && !vt_loop_injected_) {
@@ -328,12 +328,12 @@ class VTInjector : public StmtExprMutator {
         body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return For::make(
+      return ForNode::make(
           op->loop_var, op->min, extent, op->for_type, op->device_api, body);
     }
   }
   // IfThenElse
-  Stmt VisitStmt_(const IfThenElse* op) final {
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
     Expr condition = this->VisitExpr(op->condition);
     if (visit_touched_var_ && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
@@ -353,7 +353,7 @@ class VTInjector : public StmtExprMutator {
         else_case.same_as(op->else_case)) {
       return GetRef<Stmt>(op);
     } else {
-      return IfThenElse::make(condition, then_case, else_case);
+      return IfThenElseNode::make(condition, then_case, else_case);
     }
   }
 
@@ -370,7 +370,7 @@ class VTInjector : public StmtExprMutator {
     return StmtMutator::VisitSeqStmt_(op, false, fmutate);
   }
   // Allocate
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     if (op->new_expr.defined() && !vt_loop_injected_) {
       return InjectVTLoop(GetRef<Stmt>(op), true);
     }
@@ -395,7 +395,7 @@ class VTInjector : public StmtExprMutator {
     // always rewrite if not allow sharing.
     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
       // place v on highest dimension.
-      Expr stride = arith::ComputeReduce<Mul>(
+      Expr stride = arith::ComputeReduce<MulNode>(
           op->extents, Expr()) * op->dtype.lanes();
       Array<Expr> other;
       other.push_back(make_const(op->extents[0].dtype(), num_threads_));
@@ -417,7 +417,7 @@ class VTInjector : public StmtExprMutator {
         condition.same_as(op->condition)) {
       return GetRef<Stmt>(op);
     } else {
-      return Allocate::make(
+      return AllocateNode::make(
           op->buffer_var, op->dtype,
           extents, condition, body,
           op->new_expr, op->free_function);
@@ -450,7 +450,7 @@ class VTInjector : public StmtExprMutator {
       Var idx(var_->name_hint + ".s", var_->dtype);
       Map<Var, Expr> values{{var_, idx}};
       stmt = Substitute(stmt, values);
-      return For::make(idx, make_zero(idx.dtype()),
+      return ForNode::make(idx, make_zero(idx.dtype()),
                        make_const(idx.dtype(), num_threads_),
                        ForType::Serial, DeviceAPI::None, stmt);
     }
@@ -470,23 +470,23 @@ class VTInjector : public StmtExprMutator {
   // the counter of loops in after mutation.
   int max_loop_depth_{0};
   // The variables that get touched.
-  const std::unordered_set<const Variable*>& touched_var_;
+  const std::unordered_set<const VarNode*>& touched_var_;
   // Whether allow shareding.
   bool allow_share_;
   // The allocations that get touched -> extent
-  std::unordered_map<const Variable*, Expr> alloc_remap_;
+  std::unordered_map<const VarNode*, Expr> alloc_remap_;
 };
 
 
 class VirtualThreadInjector : public StmtMutator {
  public:
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<AttrStmt>();
+    op = stmt.as<AttrStmtNode>();
     if (op->attr_key == attr::virtual_thread) {
       IterVar iv = Downcast<IterVar>(op->node);
       bool allow_share = iv->thread_tag == "vthread";
-      int nthread = static_cast<int>(op->value.as<IntImm>()->value);
+      int nthread = static_cast<int>(op->value.as<IntImmNode>()->value);
       VarTouchedAnalysis vs;
       auto touched = vs.TouchedVar(op->body, iv->var.get());
       VTInjector injecter(iv->var, nthread, touched, allow_share);
@@ -496,7 +496,7 @@ class VirtualThreadInjector : public StmtMutator {
     }
   }
 
-  Stmt VisitStmt_(const Provide* op) final {
+  Stmt VisitStmt_(const ProvideNode* op) final {
     LOG(FATAL) << "Need to call StorageFlatten first";
     return GetRef<Stmt>(op);
   }
index 50e5c18..4a087dd 100644 (file)
@@ -35,9 +35,9 @@ class IRInline final : public StmtExprMutator {
   IRInline(FunctionRef f, Array<Var> args, Expr body)
       : f_(f), args_(args), body_(body) {}
 
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
+    op = expr.as<CallNode>();
 
     if (op->func == f_) {
       CHECK_EQ(op->value_index, 0);
@@ -50,7 +50,7 @@ class IRInline final : public StmtExprMutator {
       }
       if (has_side_effect) {
         for (size_t i = 0; i < args_.size(); ++i) {
-          expr = Let::make(args_[i], op->args[i], expr);
+          expr = LetNode::make(args_[i], op->args[i], expr);
         }
       } else {
         Map<Var, Expr> vmap;
@@ -58,7 +58,7 @@ class IRInline final : public StmtExprMutator {
           vmap.Set(args_[i], op->args[i]);
         }
         expr = Substitute(
-            Evaluate::make(expr), vmap).as<Evaluate>()->value;
+            EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
       }
       return expr;
     } else {
index bbee9ee..a1218f2 100644 (file)
@@ -74,8 +74,8 @@ class IRDeepCompare :
     StmtComparator::VisitStmt(n, other);
   }
   // Stmt
-  void VisitStmt_(const LetStmt* op, const Stmt& other) final {
-    const LetStmt* rhs = other.as<LetStmt>();
+  void VisitStmt_(const LetStmtNode* op, const Stmt& other) final {
+    const LetStmtNode* rhs = other.as<LetStmtNode>();
     if (CompareExpr(op->value, rhs->value) != 0) return;
     if (tie_def_) {
       vmap_[op->var.get()] = rhs->var.get();
@@ -85,23 +85,23 @@ class IRDeepCompare :
     if (CompareStmt(op->body, rhs->body) != 0) return;
   }
 
-  void VisitStmt_(const AttrStmt* op, const Stmt& other) final {
-    const AttrStmt* rhs = other.as<AttrStmt>();
+  void VisitStmt_(const AttrStmtNode* op, const Stmt& other) final {
+    const AttrStmtNode* rhs = other.as<AttrStmtNode>();
     if (CompareString(op->attr_key, rhs->attr_key) != 0) return;
     if (CompareNodeRef(op->node, rhs->node) != 0) return;
     if (CompareExpr(op->value, rhs->value) != 0) return;
     if (CompareStmt(op->body, rhs->body) != 0) return;
   }
 
-  void VisitStmt_(const IfThenElse* op, const Stmt& other) final {
-    const IfThenElse* rhs = other.as<IfThenElse>();
+  void VisitStmt_(const IfThenElseNode* op, const Stmt& other) final {
+    const IfThenElseNode* rhs = other.as<IfThenElseNode>();
     if (CompareExpr(op->condition, rhs->condition) != 0) return;
     if (CompareStmt(op->then_case, rhs->then_case) != 0) return;
     if (CompareStmt(op->else_case, rhs->else_case) != 0) return;
   }
 
-  void VisitStmt_(const For* op, const Stmt& other) final {
-    const For* rhs = other.as<For>();
+  void VisitStmt_(const ForNode* op, const Stmt& other) final {
+    const ForNode* rhs = other.as<ForNode>();
     if (CompareExpr(op->min, rhs->min) != 0) return;
     if (CompareExpr(op->extent, rhs->extent) != 0) return;
     if (tie_def_) {
@@ -112,8 +112,8 @@ class IRDeepCompare :
     if (CompareStmt(op->body, rhs->body) != 0) return;
   }
 
-  void VisitStmt_(const Allocate* op, const Stmt& other) final {
-    const Allocate* rhs = other.as<Allocate>();
+  void VisitStmt_(const AllocateNode* op, const Stmt& other) final {
+    const AllocateNode* rhs = other.as<AllocateNode>();
     if (tie_def_) {
       vmap_[op->buffer_var.get()] = rhs->buffer_var.get();
     } else {
@@ -127,43 +127,43 @@ class IRDeepCompare :
     if (CompareString(op->free_function, rhs->free_function) != 0) return;
   }
 
-  void VisitStmt_(const Store* op, const Stmt& other) final {
-    const Store* rhs = other.as<Store>();
+  void VisitStmt_(const StoreNode* op, const Stmt& other) final {
+    const StoreNode* rhs = other.as<StoreNode>();
     if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
     if (CompareExpr(op->value, rhs->value) != 0) return;
     if (CompareExpr(op->index, rhs->index) != 0) return;
     if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
   }
 
-  void VisitStmt_(const Free* op, const Stmt& other) final {
-    const Free* rhs = other.as<Free>();
+  void VisitStmt_(const FreeNode* op, const Stmt& other) final {
+    const FreeNode* rhs = other.as<FreeNode>();
     if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
   }
 
-  void VisitStmt_(const AssertStmt* op, const Stmt& other) final {
-    const AssertStmt* rhs = other.as<AssertStmt>();
+  void VisitStmt_(const AssertStmtNode* op, const Stmt& other) final {
+    const AssertStmtNode* rhs = other.as<AssertStmtNode>();
     if (CompareExpr(op->condition, rhs->condition) != 0) return;
     if (CompareExpr(op->message, rhs->message) != 0) return;
     if (CompareStmt(op->body, rhs->body) != 0) return;
   }
 
-  void VisitStmt_(const ProducerConsumer* op, const Stmt& other) final {
-    const ProducerConsumer* rhs = other.as<ProducerConsumer>();
+  void VisitStmt_(const ProducerConsumerNode* op, const Stmt& other) final {
+    const ProducerConsumerNode* rhs = other.as<ProducerConsumerNode>();
     if (CompareNodeRef(op->func, rhs->func) != 0) return;
     if (CompareValue(op->is_producer, rhs->is_producer) != 0) return;
     if (CompareStmt(op->body, rhs->body) != 0) return;
   }
 
-  void VisitStmt_(const Provide* op, const Stmt& other) final {
-    const Provide* rhs = other.as<Provide>();
+  void VisitStmt_(const ProvideNode* op, const Stmt& other) final {
+    const ProvideNode* rhs = other.as<ProvideNode>();
     if (CompareNodeRef(op->func, rhs->func) != 0) return;
     if (CompareValue(op->value_index, rhs->value_index) != 0) return;
     if (CompareExpr(op->value, rhs->value) != 0) return;
     if (CompareArray(op->args, rhs->args) != 0) return;
   }
 
-  void VisitStmt_(const Realize* op, const Stmt& other) final {
-    const Realize* rhs = other.as<Realize>();
+  void VisitStmt_(const RealizeNode* op, const Stmt& other) final {
+    const RealizeNode* rhs = other.as<RealizeNode>();
     if (CompareNodeRef(op->func, rhs->func) != 0) return;
     if (CompareValue(op->value_index, rhs->value_index) != 0) return;
     if (CompareType(op->dtype, rhs->dtype) != 0) return;
@@ -171,8 +171,8 @@ class IRDeepCompare :
     if (CompareStmt(op->body, rhs->body) != 0) return;
   }
 
-  void VisitStmt_(const Prefetch* op, const Stmt& other) final {
-    const Prefetch* rhs = other.as<Prefetch>();
+  void VisitStmt_(const PrefetchNode* op, const Stmt& other) final {
+    const PrefetchNode* rhs = other.as<PrefetchNode>();
     if (CompareNodeRef(op->func, rhs->func) != 0) return;
     if (CompareValue(op->value_index, rhs->value_index) != 0) return;
     if (CompareType(op->dtype, rhs->dtype) != 0) return;
@@ -187,14 +187,14 @@ class IRDeepCompare :
     }
   }
 
-  void VisitStmt_(const Evaluate* op, const Stmt& other) final {
-    const Evaluate* rhs = other.as<Evaluate>();
+  void VisitStmt_(const EvaluateNode* op, const Stmt& other) final {
+    const EvaluateNode* rhs = other.as<EvaluateNode>();
     CompareExpr(op->value, rhs->value);
   }
 
   // Exprs
-  void VisitExpr_(const Variable* op, const Expr& other) final {
-    const Variable* rhs = other.as<Variable>();
+  void VisitExpr_(const VarNode* op, const Expr& other) final {
+    const VarNode* rhs = other.as<VarNode>();
     auto it = vmap_.find(op);
     if (it != vmap_.end()) op = it->second;
     if (op < rhs) {
@@ -203,15 +203,15 @@ class IRDeepCompare :
       order_ = +1;
     }
   }
-  void VisitExpr_(const Load* op, const Expr& other) final {
-    const Load* rhs = other.as<Load>();
+  void VisitExpr_(const LoadNode* op, const Expr& other) final {
+    const LoadNode* rhs = other.as<LoadNode>();
     if (CompareExpr(op->buffer_var, rhs->buffer_var) != 0) return;
     if (CompareExpr(op->index, rhs->index) != 0) return;
     if (CompareExpr(op->predicate, rhs->predicate) != 0) return;
   }
 
-  void VisitExpr_(const Let* op, const Expr& other) final {
-    const Let* rhs = other.as<Let>();
+  void VisitExpr_(const LetNode* op, const Expr& other) final {
+    const LetNode* rhs = other.as<LetNode>();
     if (tie_def_) {
       vmap_[op->var.get()] = rhs->var.get();
     } else {
@@ -221,8 +221,8 @@ class IRDeepCompare :
     if (CompareExpr(op->body, rhs->body) != 0) return;
   }
 
-  void VisitExpr_(const Call* op, const Expr& other) final {
-    const Call* rhs = other.as<Call>();
+  void VisitExpr_(const CallNode* op, const Expr& other) final {
+    const CallNode* rhs = other.as<CallNode>();
     if (CompareString(op->name, rhs->name)) return;
     if (CompareArray(op->args, rhs->args)) return;
     if (CompareValue(op->call_type, rhs->call_type) != 0) return;
@@ -230,8 +230,8 @@ class IRDeepCompare :
     if (CompareValue(op->value_index, rhs->value_index) != 0) return;
   }
 
-  void VisitExpr_(const Reduce *op, const Expr& other) final {
-    const Reduce* rhs = other.as<Reduce>();
+  void VisitExpr_(const ReduceNode *op, const Expr& other) final {
+    const ReduceNode* rhs = other.as<ReduceNode>();
     if (CompareCommReducer(op->combiner, rhs->combiner) != 0) return;
     if (CompareValue(op->axis.size(), rhs->axis.size()) != 0) return;
     if (CompareValue(op->value_index, rhs->value_index) != 0) return;
@@ -248,73 +248,73 @@ class IRDeepCompare :
     if (CompareArray(op->source, rhs->source) != 0) return;
   }
 
-  void VisitExpr_(const IntImm *op, const Expr& other) final {
-    CompareValue(op->value, other.as<IntImm>()->value);
+  void VisitExpr_(const IntImmNode *op, const Expr& other) final {
+    CompareValue(op->value, other.as<IntImmNode>()->value);
   }
 
-  void VisitExpr_(const UIntImm *op, const Expr& other) final {
-    CompareValue(op->value, other.as<UIntImm>()->value);
+  void VisitExpr_(const UIntImmNode *op, const Expr& other) final {
+    CompareValue(op->value, other.as<UIntImmNode>()->value);
   }
 
-  void VisitExpr_(const FloatImm *op, const Expr& other) final {
-    CompareValue(op->value, other.as<FloatImm>()->value);
+  void VisitExpr_(const FloatImmNode *op, const Expr& other) final {
+    CompareValue(op->value, other.as<FloatImmNode>()->value);
   }
 
-  void VisitExpr_(const StringImm *op, const Expr& other) final {
-    CompareString(op->value, other.as<StringImm>()->value);
+  void VisitExpr_(const StringImmNode *op, const Expr& other) final {
+    CompareString(op->value, other.as<StringImmNode>()->value);
   }
 
-  void VisitExpr_(const Cast *op, const Expr& other) final {
-    CompareExpr(op->value, other.as<Cast>()->value);
+  void VisitExpr_(const CastNode *op, const Expr& other) final {
+    CompareExpr(op->value, other.as<CastNode>()->value);
   }
 
-  void VisitExpr_(const Not *op, const Expr& other) final {
-    CompareExpr(op->a, other.as<Not>()->a);
+  void VisitExpr_(const NotNode *op, const Expr& other) final {
+    CompareExpr(op->a, other.as<NotNode>()->a);
   }
 
-  void VisitExpr_(const Select *op, const Expr& other) final {
-    const Select* rhs = other.as<Select>();
+  void VisitExpr_(const SelectNode *op, const Expr& other) final {
+    const SelectNode* rhs = other.as<SelectNode>();
     if (CompareExpr(op->condition, rhs->condition) != 0) return;
     if (CompareExpr(op->true_value, rhs->true_value) != 0) return;
     if (CompareExpr(op->false_value, rhs->false_value) != 0) return;
   }
 
-  void VisitExpr_(const Ramp *op, const Expr& other) final {
-    const Ramp* rhs = other.as<Ramp>();
+  void VisitExpr_(const RampNode *op, const Expr& other) final {
+    const RampNode* rhs = other.as<RampNode>();
     if (CompareExpr(op->base, rhs->base) != 0) return;
     if (CompareExpr(op->stride, rhs->stride) != 0) return;
     if (CompareValue(op->lanes, rhs->lanes) != 0) return;
   }
 
-  void VisitExpr_(const Broadcast *op, const Expr& other) final {
-    const Broadcast* rhs = other.as<Broadcast>();
+  void VisitExpr_(const BroadcastNode *op, const Expr& other) final {
+    const BroadcastNode* rhs = other.as<BroadcastNode>();
     if (CompareExpr(op->value, rhs->value) != 0) return;
     if (CompareValue(op->lanes, rhs->lanes) != 0) return;
   }
 
-  void VisitExpr_(const Shuffle *op, const Expr& other) final {
-    const Shuffle* rhs = other.as<Shuffle>();
+  void VisitExpr_(const ShuffleNode *op, const Expr& other) final {
+    const ShuffleNode* rhs = other.as<ShuffleNode>();
     if (CompareArray(op->vectors, rhs->vectors) != 0) return;
     if (CompareArray(op->indices, rhs->indices) != 0) return;
   }
 
-  DEFINE_BIOP_EXPR_CMP_(Add)
-  DEFINE_BIOP_EXPR_CMP_(Sub)
-  DEFINE_BIOP_EXPR_CMP_(Mul)
-  DEFINE_BIOP_EXPR_CMP_(Div)
-  DEFINE_BIOP_EXPR_CMP_(Mod)
-  DEFINE_BIOP_EXPR_CMP_(FloorDiv)
-  DEFINE_BIOP_EXPR_CMP_(FloorMod)
-  DEFINE_BIOP_EXPR_CMP_(Min)
-  DEFINE_BIOP_EXPR_CMP_(Max)
-  DEFINE_BIOP_EXPR_CMP_(EQ)
-  DEFINE_BIOP_EXPR_CMP_(NE)
-  DEFINE_BIOP_EXPR_CMP_(LT)
-  DEFINE_BIOP_EXPR_CMP_(LE)
-  DEFINE_BIOP_EXPR_CMP_(GT)
-  DEFINE_BIOP_EXPR_CMP_(GE)
-  DEFINE_BIOP_EXPR_CMP_(And)
-  DEFINE_BIOP_EXPR_CMP_(Or)
+  DEFINE_BIOP_EXPR_CMP_(AddNode)
+  DEFINE_BIOP_EXPR_CMP_(SubNode)
+  DEFINE_BIOP_EXPR_CMP_(MulNode)
+  DEFINE_BIOP_EXPR_CMP_(DivNode)
+  DEFINE_BIOP_EXPR_CMP_(ModNode)
+  DEFINE_BIOP_EXPR_CMP_(FloorDivNode)
+  DEFINE_BIOP_EXPR_CMP_(FloorModNode)
+  DEFINE_BIOP_EXPR_CMP_(MinNode)
+  DEFINE_BIOP_EXPR_CMP_(MaxNode)
+  DEFINE_BIOP_EXPR_CMP_(EQNode)
+  DEFINE_BIOP_EXPR_CMP_(NENode)
+  DEFINE_BIOP_EXPR_CMP_(LTNode)
+  DEFINE_BIOP_EXPR_CMP_(LENode)
+  DEFINE_BIOP_EXPR_CMP_(GTNode)
+  DEFINE_BIOP_EXPR_CMP_(GENode)
+  DEFINE_BIOP_EXPR_CMP_(AndNode)
+  DEFINE_BIOP_EXPR_CMP_(OrNode)
 
  private:
   int CompareExpr(const Expr& lhs, const Expr& rhs) {
@@ -430,7 +430,7 @@ class IRDeepCompare :
   // Only equality/non-equality information is valid.
   bool tie_def_{false};
   // varaible remap if any
-  std::unordered_map<const Variable*, const Variable*> vmap_;
+  std::unordered_map<const VarNode*, const VarNode*> vmap_;
 };
 
 
index dddf90e..b7a7362 100644 (file)
@@ -123,7 +123,7 @@ Stmt IRTransform(Stmt ir_node,
                  const Array<Expr>& only_enable) {
   std::unordered_set<uint32_t> only_type_index;
   for (Expr s : only_enable) {
-    only_type_index.insert(Object::TypeKey2Index(s.as<StringImm>()->value.c_str()));
+    only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
   }
   IRTransformer transform(f_preorder, f_postorder, only_type_index);
   return transform(std::move(ir_node));
@@ -137,23 +137,23 @@ inline void VisitArray(const Array<T>& arr, F fvisit) {
   }
 }
 
-void StmtVisitor::VisitStmt_(const LetStmt* op) {
+void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
   this->VisitExpr(op->value);
   this->VisitStmt(op->body);
 }
 
-void StmtVisitor::VisitStmt_(const AttrStmt* op) {
+void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
   this->VisitExpr(op->value);
   this->VisitStmt(op->body);
 }
 
-void StmtVisitor::VisitStmt_(const For* op) {
+void StmtVisitor::VisitStmt_(const ForNode* op) {
   this->VisitExpr(op->min);
   this->VisitExpr(op->extent);
   this->VisitStmt(op->body);
 }
 
-void StmtVisitor::VisitStmt_(const Allocate* op) {
+void StmtVisitor::VisitStmt_(const AllocateNode* op) {
   VisitArray(op->extents, [this](const Expr& e) { this->VisitExpr(e); });
   this->VisitStmt(op->body);
   this->VisitExpr(op->condition);
@@ -162,13 +162,13 @@ void StmtVisitor::VisitStmt_(const Allocate* op) {
   }
 }
 
-void StmtVisitor::VisitStmt_(const Store* op) {
+void StmtVisitor::VisitStmt_(const StoreNode* op) {
   this->VisitExpr(op->value);
   this->VisitExpr(op->index);
   this->VisitExpr(op->predicate);
 }
 
-void StmtVisitor::VisitStmt_(const IfThenElse* op) {
+void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
   this->VisitExpr(op->condition);
   this->VisitStmt(op->then_case);
   if (op->else_case.defined()) {
@@ -176,24 +176,24 @@ void StmtVisitor::VisitStmt_(const IfThenElse* op) {
   }
 }
 
-void StmtVisitor::VisitStmt_(const Free* op) {}
+void StmtVisitor::VisitStmt_(const FreeNode* op) {}
 
-void StmtVisitor::VisitStmt_(const AssertStmt* op) {
+void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
   this->VisitExpr(op->condition);
   this->VisitExpr(op->message);
   this->VisitStmt(op->body);
 }
 
-void StmtVisitor::VisitStmt_(const ProducerConsumer* op) {
+void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) {
   this->VisitStmt(op->body);
 }
 
-void StmtVisitor::VisitStmt_(const Provide* op) {
+void StmtVisitor::VisitStmt_(const ProvideNode* op) {
   VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
   this->VisitExpr(op->value);
 }
 
-void StmtVisitor::VisitStmt_(const Realize* op) {
+void StmtVisitor::VisitStmt_(const RealizeNode* op) {
   VisitArray(op->bounds, [this](const Range& r) {
       this->VisitExpr(r->min);
       this->VisitExpr(r->extent);
@@ -202,7 +202,7 @@ void StmtVisitor::VisitStmt_(const Realize* op) {
   this->VisitExpr(op->condition);
 }
 
-void StmtVisitor::VisitStmt_(const Prefetch* op) {
+void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
   VisitArray(op->bounds, [this](const Range& r) {
       this->VisitExpr(r->min);
       this->VisitExpr(r->extent);
@@ -215,23 +215,23 @@ void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
     });
 }
 
-void StmtVisitor::VisitStmt_(const Evaluate* op) {
+void StmtVisitor::VisitStmt_(const EvaluateNode* op) {
   this->VisitExpr(op->value);
 }
 
-void ExprVisitor::VisitExpr_(const Variable* op) {}
+void ExprVisitor::VisitExpr_(const VarNode* op) {}
 
-void ExprVisitor::VisitExpr_(const Load* op) {
+void ExprVisitor::VisitExpr_(const LoadNode* op) {
   this->VisitExpr(op->index);
   this->VisitExpr(op->predicate);
 }
 
-void ExprVisitor::VisitExpr_(const Let* op) {
+void ExprVisitor::VisitExpr_(const LetNode* op) {
   this->VisitExpr(op->value);
   this->VisitExpr(op->body);
 }
 
-void ExprVisitor::VisitExpr_(const Call* op) {
+void ExprVisitor::VisitExpr_(const CallNode* op) {
   VisitArray(op->args, [this](const Expr& e) { this->VisitExpr(e); });
 }
 
@@ -241,30 +241,30 @@ void ExprVisitor::VisitExpr_(const Call* op) {
     this->VisitExpr(op->b);                               \
   }
 
-DEFINE_BINOP_VISIT_(Add);
-DEFINE_BINOP_VISIT_(Sub);
-DEFINE_BINOP_VISIT_(Mul);
-DEFINE_BINOP_VISIT_(Div);
-DEFINE_BINOP_VISIT_(Mod);
-DEFINE_BINOP_VISIT_(FloorDiv);
-DEFINE_BINOP_VISIT_(FloorMod);
-DEFINE_BINOP_VISIT_(Min);
-DEFINE_BINOP_VISIT_(Max);
-DEFINE_BINOP_VISIT_(EQ);
-DEFINE_BINOP_VISIT_(NE);
-DEFINE_BINOP_VISIT_(LT);
-DEFINE_BINOP_VISIT_(LE);
-DEFINE_BINOP_VISIT_(GT);
-DEFINE_BINOP_VISIT_(GE);
-DEFINE_BINOP_VISIT_(And);
-DEFINE_BINOP_VISIT_(Or);
-
-void ExprVisitor::VisitExpr_(const IntImm* op) {}
-void ExprVisitor::VisitExpr_(const UIntImm* op) {}
-void ExprVisitor::VisitExpr_(const FloatImm* op) {}
-void ExprVisitor::VisitExpr_(const StringImm* op) {}
-
-void ExprVisitor::VisitExpr_(const Reduce* op) {
+DEFINE_BINOP_VISIT_(AddNode);
+DEFINE_BINOP_VISIT_(SubNode);
+DEFINE_BINOP_VISIT_(MulNode);
+DEFINE_BINOP_VISIT_(DivNode);
+DEFINE_BINOP_VISIT_(ModNode);
+DEFINE_BINOP_VISIT_(FloorDivNode);
+DEFINE_BINOP_VISIT_(FloorModNode);
+DEFINE_BINOP_VISIT_(MinNode);
+DEFINE_BINOP_VISIT_(MaxNode);
+DEFINE_BINOP_VISIT_(EQNode);
+DEFINE_BINOP_VISIT_(NENode);
+DEFINE_BINOP_VISIT_(LTNode);
+DEFINE_BINOP_VISIT_(LENode);
+DEFINE_BINOP_VISIT_(GTNode);
+DEFINE_BINOP_VISIT_(GENode);
+DEFINE_BINOP_VISIT_(AndNode);
+DEFINE_BINOP_VISIT_(OrNode);
+
+void ExprVisitor::VisitExpr_(const IntImmNode* op) {}
+void ExprVisitor::VisitExpr_(const UIntImmNode* op) {}
+void ExprVisitor::VisitExpr_(const FloatImmNode* op) {}
+void ExprVisitor::VisitExpr_(const StringImmNode* op) {}
+
+void ExprVisitor::VisitExpr_(const ReduceNode* op) {
   VisitArray(op->axis, [this](const IterVar& r) {
       this->VisitExpr(r->dom->min);
       this->VisitExpr(r->dom->extent);
@@ -273,31 +273,31 @@ void ExprVisitor::VisitExpr_(const Reduce* op) {
   this->VisitExpr(op->condition);
 }
 
-void ExprVisitor::VisitExpr_(const Cast* op) {
+void ExprVisitor::VisitExpr_(const CastNode* op) {
   this->VisitExpr(op->value);
 }
 
-void ExprVisitor::VisitExpr_(const Not* op) {
+void ExprVisitor::VisitExpr_(const NotNode* op) {
   this->VisitExpr(op->a);
 }
 
-void ExprVisitor::VisitExpr_(const Select* op) {
+void ExprVisitor::VisitExpr_(const SelectNode* op) {
   this->VisitExpr(op->condition);
   this->VisitExpr(op->true_value);
   this->VisitExpr(op->false_value);
 }
 
-void ExprVisitor::VisitExpr_(const Ramp* op) {
+void ExprVisitor::VisitExpr_(const RampNode* op) {
   this->VisitExpr(op->base);
   this->VisitExpr(op->stride);
 }
 
-void ExprVisitor::VisitExpr_(const Shuffle* op) {
+void ExprVisitor::VisitExpr_(const ShuffleNode* op) {
   VisitArray(op->indices, [this](const Expr& e) { this->VisitExpr(e); });
   VisitArray(op->vectors, [this](const Expr& e) { this->VisitExpr(e); });
 }
 
-void ExprVisitor::VisitExpr_(const Broadcast* op) {
+void ExprVisitor::VisitExpr_(const BroadcastNode* op) {
   this->VisitExpr(op->value);
 }
 
@@ -344,7 +344,7 @@ class StmtMutator::Internal {
   }
 };
 
-Stmt StmtMutator::VisitStmt_(const AttrStmt* op) {
+Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
   Expr value = this->VisitExpr(op->value);
   Stmt body = this->VisitStmt(op->body);
   if (value.same_as(op->value) &&
@@ -358,7 +358,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmt* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const LetStmt* op) {
+Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
   Expr value = this->VisitExpr(op->value);
   Stmt body = this->VisitStmt(op->body);
   if (value.same_as(op->value) &&
@@ -372,7 +372,7 @@ Stmt StmtMutator::VisitStmt_(const LetStmt* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const For* op) {
+Stmt StmtMutator::VisitStmt_(const ForNode* op) {
   Expr min = this->VisitExpr(op->min);
   Expr extent = this->VisitExpr(op->extent);
   Stmt body = this->VisitStmt(op->body);
@@ -389,7 +389,7 @@ Stmt StmtMutator::VisitStmt_(const For* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Allocate* op) {
+Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
   Array<Expr> extents = Internal::Mutate(this, op->extents);
   Stmt body = this->VisitStmt(op->body);
   Expr condition = this->VisitExpr(op->condition);
@@ -412,7 +412,7 @@ Stmt StmtMutator::VisitStmt_(const Allocate* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const IfThenElse* op) {
+Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
   Expr condition = this->VisitExpr(op->condition);
   Stmt then_case = this->VisitStmt(op->then_case);
   Stmt else_case;
@@ -432,7 +432,7 @@ Stmt StmtMutator::VisitStmt_(const IfThenElse* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Store* op) {
+Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
   Expr value = this->VisitExpr(op->value);
   Expr index = this->VisitExpr(op->index);
   Expr predicate = this->VisitExpr(op->predicate);
@@ -449,7 +449,7 @@ Stmt StmtMutator::VisitStmt_(const Store* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Provide* op) {
+Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
   Array<Expr> args = Internal::Mutate(this, op->args);
   Expr value = this->VisitExpr(op->value);
   if (args.same_as(op->args) &&
@@ -463,7 +463,7 @@ Stmt StmtMutator::VisitStmt_(const Provide* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Realize* op) {
+Stmt StmtMutator::VisitStmt_(const RealizeNode* op) {
   Region bounds = Internal::Mutate(this, op->bounds);
   Stmt body = this->VisitStmt(op->body);
   Expr condition = this->VisitExpr(op->condition);
@@ -480,7 +480,7 @@ Stmt StmtMutator::VisitStmt_(const Realize* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Prefetch* op) {
+Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) {
   Region bounds = Internal::Mutate(this, op->bounds);
   if (bounds.same_as(op->bounds)) {
     return GetRef<Stmt>(op);
@@ -548,7 +548,7 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op,
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const AssertStmt* op) {
+Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
   Expr condition = this->VisitExpr(op->condition);
   Expr message = this->VisitExpr(op->message);
   Stmt body = this->VisitStmt(op->body);
@@ -566,7 +566,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmt* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) {
+Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) {
   Stmt body = this->VisitStmt(op->body);
   if (body.same_as(op->body)) {
     return GetRef<Stmt>(op);
@@ -577,7 +577,7 @@ Stmt StmtMutator::VisitStmt_(const ProducerConsumer* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Evaluate* op) {
+Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
   Expr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
     return GetRef<Stmt>(op);
@@ -588,44 +588,44 @@ Stmt StmtMutator::VisitStmt_(const Evaluate* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const Free* op) {
+Stmt StmtMutator::VisitStmt_(const FreeNode* op) {
   return GetRef<Stmt>(op);
 }
 
 
-Expr ExprMutator::VisitExpr_(const Variable* op) {
+Expr ExprMutator::VisitExpr_(const VarNode* op) {
   return GetRef<Expr>(op);
 }
 
-Expr ExprMutator::VisitExpr_(const Load* op) {
+Expr ExprMutator::VisitExpr_(const LoadNode* op) {
   Expr index = this->VisitExpr(op->index);
   Expr predicate = this->VisitExpr(op->predicate);
   if (index.same_as(op->index) && predicate.same_as(op->predicate)) {
     return GetRef<Expr>(op);
   } else {
-    return Load::make(op->dtype, op->buffer_var, index, predicate);
+    return LoadNode::make(op->dtype, op->buffer_var, index, predicate);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Let* op) {
+Expr ExprMutator::VisitExpr_(const LetNode* op) {
   Expr value = this->VisitExpr(op->value);
   Expr body = this->VisitExpr(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
     return GetRef<Expr>(op);
   } else {
-    return Let::make(op->var, value, body);
+    return LetNode::make(op->var, value, body);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Call* op) {
+Expr ExprMutator::VisitExpr_(const CallNode* op) {
   auto fmutate = [this](const Expr& e) { return this->VisitExpr(e); };
   Array<Expr> args = MutateArray(op->args, fmutate);
 
   if (args.same_as(op->args)) {
     return GetRef<Expr>(op);
   } else {
-    return Call::make(op->dtype,
+    return CallNode::make(op->dtype,
                       op->name,
                       args,
                       op->call_type,
@@ -639,10 +639,10 @@ Expr ExprMutator::VisitExpr_(const Call* op) {
     return GetRef<Expr>(op);                                      \
   }
 
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
-DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImmNode)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode)
 
 #define DEFINE_BIOP_EXPR_MUTATE_(OP)                                    \
   Expr ExprMutator::VisitExpr_(const OP* op) {                          \
@@ -656,25 +656,25 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
     }                                                                   \
   }
 
-DEFINE_BIOP_EXPR_MUTATE_(Add);
-DEFINE_BIOP_EXPR_MUTATE_(Sub);
-DEFINE_BIOP_EXPR_MUTATE_(Mul);
-DEFINE_BIOP_EXPR_MUTATE_(Div);
-DEFINE_BIOP_EXPR_MUTATE_(Mod);
-DEFINE_BIOP_EXPR_MUTATE_(FloorDiv);
-DEFINE_BIOP_EXPR_MUTATE_(FloorMod);
-DEFINE_BIOP_EXPR_MUTATE_(Min);
-DEFINE_BIOP_EXPR_MUTATE_(Max);
-DEFINE_BIOP_EXPR_MUTATE_(EQ);
-DEFINE_BIOP_EXPR_MUTATE_(NE);
-DEFINE_BIOP_EXPR_MUTATE_(LT);
-DEFINE_BIOP_EXPR_MUTATE_(LE);
-DEFINE_BIOP_EXPR_MUTATE_(GT);
-DEFINE_BIOP_EXPR_MUTATE_(GE);
-DEFINE_BIOP_EXPR_MUTATE_(And);
-DEFINE_BIOP_EXPR_MUTATE_(Or);
-
-Expr ExprMutator::VisitExpr_(const Reduce* op) {
+DEFINE_BIOP_EXPR_MUTATE_(AddNode);
+DEFINE_BIOP_EXPR_MUTATE_(SubNode);
+DEFINE_BIOP_EXPR_MUTATE_(MulNode);
+DEFINE_BIOP_EXPR_MUTATE_(DivNode);
+DEFINE_BIOP_EXPR_MUTATE_(ModNode);
+DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode);
+DEFINE_BIOP_EXPR_MUTATE_(FloorModNode);
+DEFINE_BIOP_EXPR_MUTATE_(MinNode);
+DEFINE_BIOP_EXPR_MUTATE_(MaxNode);
+DEFINE_BIOP_EXPR_MUTATE_(EQNode);
+DEFINE_BIOP_EXPR_MUTATE_(NENode);
+DEFINE_BIOP_EXPR_MUTATE_(LTNode);
+DEFINE_BIOP_EXPR_MUTATE_(LENode);
+DEFINE_BIOP_EXPR_MUTATE_(GTNode);
+DEFINE_BIOP_EXPR_MUTATE_(GENode);
+DEFINE_BIOP_EXPR_MUTATE_(AndNode);
+DEFINE_BIOP_EXPR_MUTATE_(OrNode);
+
+Expr ExprMutator::VisitExpr_(const ReduceNode* op) {
   auto fitervar =  [this](const IterVar& v) {
     Range r = v->dom;
     Expr min = this->VisitExpr(r->min);
@@ -700,30 +700,30 @@ Expr ExprMutator::VisitExpr_(const Reduce* op) {
       condition.same_as(op->condition)) {
     return GetRef<Expr>(op);
   } else {
-    return Reduce::make(
+    return ReduceNode::make(
       op->combiner, source, axis, condition, op->value_index);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Cast* op) {
+Expr ExprMutator::VisitExpr_(const CastNode* op) {
   Expr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
     return GetRef<Expr>(op);
   } else {
-    return Cast::make(op->dtype, value);
+    return CastNode::make(op->dtype, value);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Not* op) {
+Expr ExprMutator::VisitExpr_(const NotNode* op) {
   Expr a = this->VisitExpr(op->a);
   if (a.same_as(op->a)) {
     return GetRef<Expr>(op);
   } else {
-    return Not::make(a);
+    return NotNode::make(a);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Select* op) {
+Expr ExprMutator::VisitExpr_(const SelectNode* op) {
   Expr condition = this->VisitExpr(op->condition);
   Expr true_value = this->VisitExpr(op->true_value);
   Expr false_value = this->VisitExpr(op->false_value);
@@ -732,37 +732,37 @@ Expr ExprMutator::VisitExpr_(const Select* op) {
       false_value.same_as(op->false_value)) {
     return GetRef<Expr>(op);
   } else {
-    return Select::make(condition, true_value, false_value);
+    return SelectNode::make(condition, true_value, false_value);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Ramp* op) {
+Expr ExprMutator::VisitExpr_(const RampNode* op) {
   Expr base = this->VisitExpr(op->base);
   Expr stride = this->VisitExpr(op->stride);
   if (base.same_as(op->base) &&
       stride.same_as(op->stride)) {
     return GetRef<Expr>(op);
   } else {
-    return Ramp::make(base, stride, op->lanes);
+    return RampNode::make(base, stride, op->lanes);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Broadcast* op) {
+Expr ExprMutator::VisitExpr_(const BroadcastNode* op) {
   Expr value = this->VisitExpr(op->value);
   if (value.same_as(op->value)) {
     return GetRef<Expr>(op);
   } else {
-    return Broadcast::make(value, op->lanes);
+    return BroadcastNode::make(value, op->lanes);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const Shuffle* op) {
+Expr ExprMutator::VisitExpr_(const ShuffleNode* op) {
   auto fexpr = [this](const Expr& e) { return this->VisitExpr(e); };
   auto vectors = MutateArray(op->vectors, fexpr);
   if (vectors.same_as(op->vectors)) {
     return GetRef<Expr>(op);
   } else {
-    return Shuffle::make(vectors, op->indices);
+    return ShuffleNode::make(vectors, op->indices);
   }
 }
 
index 8956a4d..8ecfbff 100644 (file)
@@ -30,23 +30,23 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
   // use reverse iteration
   for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
     Stmt s = *ri;
-    if (const auto* for_ = s.as<For>()) {
-      auto n = make_object<For>(*for_);
+    if (const auto* for_ = s.as<ForNode>()) {
+      auto n = make_object<ForNode>(*for_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (const auto* let = s.as<LetStmt>()) {
-      auto n = make_object<LetStmt>(*let);
+    } else if (const auto* let = s.as<LetStmtNode>()) {
+      auto n = make_object<LetStmtNode>(*let);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (const auto* attr = s.as<AttrStmt>()) {
-      auto n = make_object<AttrStmt>(*attr);
+    } else if (const auto* attr = s.as<AttrStmtNode>()) {
+      auto n = make_object<AttrStmtNode>(*attr);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (const auto* ite = s.as<IfThenElse>()) {
-      auto n = make_object<IfThenElse>(*ite);
+    } else if (const auto* ite = s.as<IfThenElseNode>()) {
+      auto n = make_object<IfThenElseNode>(*ite);
       CHECK(is_no_op(n->then_case));
       CHECK(!n->else_case.defined());
       n->then_case = body;
@@ -56,13 +56,13 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
       n->seq.Set(n->size() - 1, body);
       body = Stmt(n);
-    } else if (const auto* assert_ = s.as<AssertStmt>()) {
-      auto n = make_object<AssertStmt>(*assert_);
+    } else if (const auto* assert_ = s.as<AssertStmtNode>()) {
+      auto n = make_object<AssertStmtNode>(*assert_);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
-    } else if (const auto* alloc = s.as<Allocate>()) {
-      auto n = make_object<Allocate>(*alloc);
+    } else if (const auto* alloc = s.as<AllocateNode>()) {
+      auto n = make_object<AllocateNode>(*alloc);
       CHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
index 900d6d5..74d5781 100644 (file)
@@ -88,7 +88,7 @@ inline Expr TVMStructGet(
     handle,
     make_const(DataType::Int(32), index),
     make_const(DataType::Int(32), static_cast<int>(kind))};
-  return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
+  return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic);
 }
 
 /*!
@@ -98,11 +98,11 @@ inline Expr TVMStructGet(
  * \param offset the offset index.
  */
 inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
-  return Call::make(
+  return CallNode::make(
       DataType::Handle(), intrinsic::tvm_address_of,
-      {Load::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
+      {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()),
                   const_true(dtype.lanes()))},
-      Call::PureIntrinsic);
+      CallNode::PureIntrinsic);
 }
 
 /*!
@@ -114,13 +114,13 @@ inline Expr AddressOffset(Var handle, DataType dtype, int offset) {
 inline Expr AddressOffset(Var handle, DataType dtype, Expr offset) {
   if (dtype.lanes() != 1) {
     offset = offset * make_const(offset.dtype(), dtype.lanes());
-    offset = Ramp::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
+    offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes());
   }
-  return Call::make(
+  return CallNode::make(
       DataType::Handle(), intrinsic::tvm_address_of,
-      {Load::make(dtype, handle, offset,
+      {LoadNode::make(dtype, handle, offset,
                   const_true(dtype.lanes()))},
-      Call::PureIntrinsic);
+      CallNode::PureIntrinsic);
 }
 
 /*!
@@ -139,8 +139,8 @@ inline Stmt TVMStructSet(
     make_const(DataType::Int(32), index),
     make_const(DataType::Int(32), static_cast<int>(kind)),
     value};
-  return Evaluate::make(
-      Call::make(DataType::Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic));
+  return EvaluateNode::make(
+      CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
 }
 
 /*!
@@ -183,7 +183,7 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
  * \return true if pattern match success and store the base to base.
  */
 inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
-  const Ramp* r = index.as<Ramp>();
+  const RampNode* r = index.as<RampNode>();
   if (!r) return false;
   if (!is_one(r->stride)) return false;
   CHECK_EQ(r->lanes, lanes);
index 4f2df7b..9a97031 100644 (file)
@@ -40,23 +40,23 @@ class AttrScopeLifter : public StmtMutator {
   Stmt Lift(Stmt stmt) {
     stmt = operator()(std::move(stmt));
     if (attr_node_.defined()) {
-      stmt = AttrStmt::make(
+      stmt = AttrStmtNode::make(
           attr_node_, attr_key_, attr_value_, stmt);
     }
     return stmt;
   }
 
   // do not go beyond
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<Allocate>();
+    op = stmt.as<AllocateNode>();
     if (attr_node_.defined()) {
-      Stmt body = AttrStmt::make(
+      Stmt body = AttrStmtNode::make(
           attr_node_, attr_key_, attr_value_, op->body);
       // undefine them
       attr_node_ = ObjectRef();
       attr_value_ = Expr();
-      return Allocate::make(
+      return AllocateNode::make(
         op->buffer_var, op->dtype,
         op->extents, op->condition, body,
         op->new_expr, op->free_function);
@@ -65,7 +65,7 @@ class AttrScopeLifter : public StmtMutator {
     }
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr_key_) {
       attr_node_ = op->node;
       attr_value_ = op->value;
@@ -116,7 +116,7 @@ class AttrScopeLifter : public StmtMutator {
       }
       Stmt stmt = SeqStmt::Flatten(seq);
       if (attr_node[begin].defined()) {
-        stmt = AttrStmt::make(
+        stmt = AttrStmtNode::make(
             attr_node[begin], attr_key_, attr_value[begin], stmt);
       }
       reorg.push_back(stmt);
@@ -127,7 +127,7 @@ class AttrScopeLifter : public StmtMutator {
     return SeqStmt::Flatten(reorg);
   }
 
-  Stmt VisitStmt_(const IfThenElse* op) final {
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
     if (!op->else_case.defined()) {
       return StmtMutator::VisitStmt_(op);
     }
@@ -147,15 +147,15 @@ class AttrScopeLifter : public StmtMutator {
           else_case.same_as(op->else_case)) {
         return GetRef<Stmt>(op);
       } else {
-        return IfThenElse::make(op->condition, then_case, else_case);
+        return IfThenElseNode::make(op->condition, then_case, else_case);
       }
     } else {
       if (first_node.defined()) {
-        then_case = AttrStmt::make(
+        then_case = AttrStmtNode::make(
             first_node, attr_key_, first_value, then_case);
       }
       if (attr_node_.defined()) {
-        else_case = AttrStmt::make(
+        else_case = AttrStmtNode::make(
             attr_node_, attr_key_, attr_value_, else_case);
         // undefine them
         attr_node_ = ObjectRef();
@@ -165,7 +165,7 @@ class AttrScopeLifter : public StmtMutator {
           else_case.same_as(op->else_case)) {
         return GetRef<Stmt>(op);
       } else {
-        return IfThenElse::make(op->condition, then_case, else_case);
+        return IfThenElseNode::make(op->condition, then_case, else_case);
       }
     }
   }
@@ -177,11 +177,11 @@ class AttrScopeLifter : public StmtMutator {
     if (!a.defined() || !b.defined()) return false;
     if (a->type_index() != b->type_index()) return false;
     if (a.dtype() != b.dtype()) return false;
-    if (const IntImm* op = a.as<IntImm>()) {
-      return op->value == b.as<IntImm>()->value;
+    if (const IntImmNode* op = a.as<IntImmNode>()) {
+      return op->value == b.as<IntImmNode>()->value;
     }
-    if (const UIntImm* op = a.as<UIntImm>()) {
-      return op->value == b.as<UIntImm>()->value;
+    if (const UIntImmNode* op = a.as<UIntImmNode>()) {
+      return op->value == b.as<UIntImmNode>()->value;
     }
     return false;
   }
index aa8ebe1..7d9ce62 100644 (file)
@@ -49,10 +49,10 @@ struct PartitionKeyHash {
 // condition cond is proven to have value cond_value (true or false) in interval.
 using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
 
-bool ExprUseVars(Expr expr, const std::unordered_set<const Variable*>& vars) {
+bool ExprUseVars(Expr expr, const std::unordered_set<const VarNode*>& vars) {
   bool success = false;
   PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) {
-    if (const Variable* v = node.as<Variable>()) {
+    if (const VarNode* v = node.as<VarNode>()) {
       if (vars.count(v)) {
         success = true;
         return;
@@ -72,10 +72,10 @@ class CandidateSelector final : public StmtExprVisitor {
   explicit CandidateSelector(bool split_const_loop)
       : split_const_loop_(split_const_loop) {}
 
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     // partition const loop when sets split_const_loop_
     if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) {
-      const Variable* var = op->loop_var.get();
+      const VarNode* var = op->loop_var.get();
       record_.insert({var, false});
       StmtExprVisitor::VisitStmt_(op);
       if (record_.at(var) && !no_split_) {
@@ -87,7 +87,7 @@ class CandidateSelector final : public StmtExprVisitor {
     }
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       const IterVarNode *iv = op->node.as<IterVarNode>();
       CHECK(iv);
@@ -118,8 +118,8 @@ class CandidateSelector final : public StmtExprVisitor {
     }
   }
 
-  void VisitExpr_(const Call* op) final {
-    if (op->is_intrinsic(Call::likely)) {
+  void VisitExpr_(const CallNode* op) final {
+    if (op->is_intrinsic(CallNode::likely)) {
       in_likely_ = true;
       StmtExprVisitor::VisitExpr_(op);
       in_likely_ = false;
@@ -132,7 +132,7 @@ class CandidateSelector final : public StmtExprVisitor {
     }
   }
 
-  void VisitExpr_(const Variable* op) final {
+  void VisitExpr_(const VarNode* op) final {
     if (in_likely_ && record_.count(op)) {
       record_.at(op) = true;
     }
@@ -144,7 +144,7 @@ class CandidateSelector final : public StmtExprVisitor {
   bool in_likely_{false};
   bool no_split_{false};
   bool split_const_loop_{false};
-  std::unordered_map<const Variable*, VarIsUsed> record_;
+  std::unordered_map<const VarNode*, VarIsUsed> record_;
 };
 
 // Populate partitions data structure, i.e., for a specific variable,
@@ -153,8 +153,8 @@ class CandidateSelector final : public StmtExprVisitor {
 class PartitionFinder : public StmtExprVisitor {
  public:
   explicit PartitionFinder(VarExpr current_var,
-    const std::unordered_map<const Variable*, IntSet>& hint_map,
-    const std::unordered_map<const Variable*, IntSet>& relax_map)
+    const std::unordered_map<const VarNode*, IntSet>& hint_map,
+    const std::unordered_map<const VarNode*, IntSet>& relax_map)
       : current_var_(current_var), hint_map_(hint_map),  relax_map_(relax_map) {
         for (const auto& kv : hint_map) {
           out_vars_.insert(kv.first);
@@ -164,10 +164,10 @@ class PartitionFinder : public StmtExprVisitor {
         }
       }
 
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
 
-    const Variable* var = op->loop_var.get();
+    const VarNode* var = op->loop_var.get();
     hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
     relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
     StmtExprVisitor::VisitStmt_(op);
@@ -175,12 +175,12 @@ class PartitionFinder : public StmtExprVisitor {
     hint_map_.erase(var);
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     // handle thread_axis
     if (op->attr_key == attr::thread_extent) {
       const IterVarNode* thread_axis = op->node.as<IterVarNode>();
       CHECK(thread_axis);
-      const Variable* var = thread_axis->var.get();
+      const VarNode* var = thread_axis->var.get();
       IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
       hint_map_.insert({var, dom});
       relax_map_.insert({var, dom});
@@ -192,11 +192,11 @@ class PartitionFinder : public StmtExprVisitor {
     }
   }
 
-  void VisitExpr_(const Call* op) final {
-    if (op->is_intrinsic(Call::likely)) {
+  void VisitExpr_(const CallNode* op) final {
+    if (op->is_intrinsic(CallNode::likely)) {
       Expr cond = op->args[0];
       if (ExprUseVars(cond,
-          std::unordered_set<const Variable*>({current_var_.get()}))) {
+          std::unordered_set<const VarNode*>({current_var_.get()}))) {
         // For cond, find out the interval, if exists, in which we can prove that cond is
         // true. Also find the interval, if exists, in which we can prove that cond is
         // false.
@@ -226,32 +226,32 @@ class PartitionFinder : public StmtExprVisitor {
  private:
   Expr InverseCond(const Expr& cond) {
     Expr inverse_cond;
-    if (const LT* op = cond.as<LT>()) {
+    if (const LTNode* op = cond.as<LTNode>()) {
       // a < b -> a >= b
-      inverse_cond = GE::make(op->a, op->b);
-    } else if (const GT* op = cond.as<GT>()) {
+      inverse_cond = GENode::make(op->a, op->b);
+    } else if (const GTNode* op = cond.as<GTNode>()) {
       // a > b -> a <= b
-      inverse_cond = LE::make(op->a, op->b);
-    } else if (const LE* op = cond.as<LE>()) {
+      inverse_cond = LENode::make(op->a, op->b);
+    } else if (const LENode* op = cond.as<LENode>()) {
       // a <= b -> a > b
-      inverse_cond = GT::make(op->a, op->b);
-    } else if (const GE* op = cond.as<GE>()) {
+      inverse_cond = GTNode::make(op->a, op->b);
+    } else if (const GENode* op = cond.as<GENode>()) {
       // a >= b -> a < b
-      inverse_cond = LT::make(op->a, op->b);
-    } else if (const EQ* op = cond.as<EQ>()) {
+      inverse_cond = LTNode::make(op->a, op->b);
+    } else if (const EQNode* op = cond.as<EQNode>()) {
       // a == b -> a != b
-      inverse_cond = NE::make(op->a, op->b);
+      inverse_cond = NENode::make(op->a, op->b);
       // a != b -> a == b
-    } else if (const NE* op = cond.as<NE>()) {
-      inverse_cond = EQ::make(op->a, op->b);
+    } else if (const NENode* op = cond.as<NENode>()) {
+      inverse_cond = EQNode::make(op->a, op->b);
     }
     return inverse_cond;
   }
 
   VarExpr current_var_;
-  std::unordered_set<const Variable*> out_vars_;
-  std::unordered_map<const Variable*, IntSet> hint_map_;
-  std::unordered_map<const Variable*, IntSet> relax_map_;
+  std::unordered_set<const VarNode*> out_vars_;
+  std::unordered_map<const VarNode*, IntSet> hint_map_;
+  std::unordered_map<const VarNode*, IntSet> relax_map_;
 };
 
 // Replace the set of conditions given by ps with cond_value (true or false)
@@ -279,16 +279,16 @@ class ThreadPartitionInserter : public StmtMutator {
   explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps,
     Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       innermost_thread_scope_ = true;
       Stmt stmt = StmtMutator::VisitStmt_(op);
       // add branch code inside the innermost thread scope
       if (innermost_thread_scope_) {
         Stmt simplified_body = ConditionEliminator(ps_)(op->body);
-        Stmt body = IfThenElse::make(cond_, simplified_body, op->body);
+        Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body);
         Expr value = this->VisitExpr(op->value);
-        stmt = AttrStmt::make(op->node, op->attr_key, value, body);
+        stmt = AttrStmtNode::make(op->node, op->attr_key, value, body);
       }
       innermost_thread_scope_ = false;
       return stmt;
@@ -315,7 +315,7 @@ class LoopPartitioner : public StmtMutator {
     return operator()(std::move(stmt));
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     if (selector.candidates.count(op)) {
       Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var,
           op->min, op->min + op->extent - 1, op->body, false);
@@ -331,7 +331,7 @@ class LoopPartitioner : public StmtMutator {
     return res;
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key != attr::thread_extent) {
       return StmtMutator::VisitStmt_(op);
     }
@@ -374,8 +374,8 @@ class LoopPartitioner : public StmtMutator {
   inline Stmt MakeFor(const Object* op, Expr extent, Stmt body);
 
   /* Candidate IRs that may be partitioned potentially */
-  std::unordered_map<const Variable*, IntSet> hint_map_;
-  std::unordered_map<const Variable*, IntSet> relax_map_;
+  std::unordered_map<const VarNode*, IntSet> hint_map_;
+  std::unordered_map<const VarNode*, IntSet> relax_map_;
   arith::Analyzer analyzer_;
   CandidateSelector selector;
 };
@@ -506,7 +506,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
       if (!analyzer_.CanProve(cond)) {
         LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the pre doubt loop";
-        body_begin = Max::make(body_begin, min);
+        body_begin = MaxNode::make(body_begin, min);
         // stop recursing on this interval if we can't prove it has non-negative length
         pre_stmt_recurse = false;
       }
@@ -532,7 +532,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
       if (!analyzer_.CanProve(cond)) {
         LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the post doubt loop";
-        post_doubt_begin = Min::make(post_doubt_begin, max+1);
+        post_doubt_begin = MinNode::make(post_doubt_begin, max+1);
         // stop recursing on this interval if we can't prove it has non-negative length
         post_stmt_recurse = false;
       }
@@ -581,21 +581,21 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
 }
 
 inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) {
-  const For *for_node = static_cast<const For*>(node);
+  const ForNode *for_node = static_cast<const ForNode*>(node);
   CHECK(for_node);
   if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) {
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
   } else {
-    return For::make(for_node->loop_var, 0, extent,
+    return ForNode::make(for_node->loop_var, 0, extent,
                      for_node->for_type, for_node->device_api, body);
   }
 }
 
 class RemoveLikelyTags : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const Call *op) final {
-    if (op->is_intrinsic(Call::likely)) {
+  Expr VisitExpr_(const CallNode *op) final {
+    if (op->is_intrinsic(CallNode::likely)) {
       CHECK_EQ(op->args.size(), 1);
       return StmtExprMutator::VisitExpr(op->args[0]);
     } else {
index 2440b1f..ded17d4 100644 (file)
@@ -41,14 +41,14 @@ class CustomDatatypesLowerer : public StmtExprMutator {
  public:
   explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {}
 
-  inline Expr VisitExpr_(const Cast* op) final {
+  inline Expr VisitExpr_(const CastNode* op) final {
     auto type_code = op->dtype.code();
     auto src_type_code = op->value.dtype().code();
     // If either datatype is a registered custom datatype, we must lower.
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) ||
                        datatype::Registry::Global()->GetTypeRegistered(src_type_code);
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Cast>();
+    op = expr.as<CastNode>();
     if (toBeLowered) {
       auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code);
       CHECK(lower) << "Cast lowering function for target " << target_ << " destination type "
@@ -59,7 +59,7 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;
   }
 
-  inline Expr VisitExpr_(const FloatImm* imm) final {
+  inline Expr VisitExpr_(const FloatImmNode* imm) final {
     auto type_code = imm->dtype.code();
     auto e = GetRef<Expr>(imm);
     if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
@@ -71,37 +71,37 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return e;
   }
 
-  inline Stmt VisitStmt_(const Allocate* allocate) final {
+  inline Stmt VisitStmt_(const AllocateNode* allocate) final {
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->dtype.code());
     Stmt stmt = StmtExprMutator::VisitStmt_(allocate);
-    allocate = stmt.as<Allocate>();
+    allocate = stmt.as<AllocateNode>();
 
     if (toBeLowered) {
       auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
-      return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents,
+      return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
                             allocate->condition, allocate->body, allocate->new_expr,
                             allocate->free_function);
     }
     return stmt;
   }
 
-  inline Expr VisitExpr_(const Load* load) final {
+  inline Expr VisitExpr_(const LoadNode* load) final {
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
     Expr expr = StmtExprMutator::VisitExpr_(load);
-    load = expr.as<Load>();
+    load = expr.as<LoadNode>();
     if (toBeLowered) {
       auto new_load_type = DataType::UInt(load->dtype.bits());
-      return Load::make(new_load_type, load->buffer_var, load->index, load->predicate);
+      return LoadNode::make(new_load_type, load->buffer_var, load->index, load->predicate);
     }
     return expr;
   }
 
-#define DEFINE_MUTATE__(OP)                                                        \
-  inline Expr VisitExpr_(const OP* op) final {                                     \
+#define DEFINE_MUTATE__(OP, NodeName)                                              \
+  inline Expr VisitExpr_(const NodeName* op) final {                                     \
     auto type_code = op->dtype.code();                                             \
     bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \
     Expr expr = StmtExprMutator::VisitExpr_(op);                                   \
-    op = expr.as<OP>();                                                            \
+    op = expr.as<NodeName>();                                                            \
     if (toBeLowered) {                                                             \
       auto lower = datatype::Get##OP##LowerFunc(target_, type_code);               \
       CHECK(lower) << #OP " lowering function for target " << target_ << " type "  \
@@ -111,19 +111,19 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     return expr;                                                                   \
   }
 
-  DEFINE_MUTATE__(Add)
-  DEFINE_MUTATE__(Sub)
-  DEFINE_MUTATE__(Mul)
-  DEFINE_MUTATE__(Div)
-  DEFINE_MUTATE__(Mod)
-  DEFINE_MUTATE__(Min)
-  DEFINE_MUTATE__(Max)
-  DEFINE_MUTATE__(EQ)
-  DEFINE_MUTATE__(NE)
-  DEFINE_MUTATE__(LT)
-  DEFINE_MUTATE__(LE)
-  DEFINE_MUTATE__(GT)
-  DEFINE_MUTATE__(GE)
+  DEFINE_MUTATE__(Add, AddNode);
+  DEFINE_MUTATE__(Sub, SubNode);
+  DEFINE_MUTATE__(Mul, MulNode);
+  DEFINE_MUTATE__(Div, DivNode);
+  DEFINE_MUTATE__(Mod, ModNode);
+  DEFINE_MUTATE__(Min, MinNode);
+  DEFINE_MUTATE__(Max, MaxNode);
+  DEFINE_MUTATE__(EQ, EQNode);
+  DEFINE_MUTATE__(NE, NENode);
+  DEFINE_MUTATE__(LT, LTNode);
+  DEFINE_MUTATE__(LE, LENode);
+  DEFINE_MUTATE__(GT, GTNode);
+  DEFINE_MUTATE__(GE, GENode);
   // Later changes may need to add more mutate functions as we support workloads with more ops.
 
  private:
index 0f49710..b46bf18 100644 (file)
@@ -53,19 +53,19 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     }
   }
 
-  Expr VisitExpr_(const Call* op) final {
-    if (op->call_type == Call::Intrinsic ||
-        op->call_type == Call::PureIntrinsic) {
+  Expr VisitExpr_(const CallNode* op) final {
+    if (op->call_type == CallNode::Intrinsic ||
+        op->call_type == CallNode::PureIntrinsic) {
       Expr r = ApplyPattern(op->name, GetRef<Expr>(op));
       if (r.defined()) return r;
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const Add* op) final {
-    if (const Mul* mb = op->b.as<Mul>()) {
+  Expr VisitExpr_(const AddNode* op) final {
+    if (const MulNode* mb = op->b.as<MulNode>()) {
       return MakeFMA(mb->a, mb->b, op->a, op);
-    } else if (const Mul* ma = op->a.as<Mul>()) {
+    } else if (const MulNode* ma = op->a.as<MulNode>()) {
       return MakeFMA(ma->a, ma->b, op->b, op);
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
@@ -73,10 +73,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
 
   // We use floordiv for integer analysis,
   // but will need to lower them to native truncdiv instructions
-  Expr VisitExpr_(const FloorDiv* op) final {
+  Expr VisitExpr_(const FloorDivNode* op) final {
     auto e = GetRef<Expr>(op);
     Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-    op = ret.as<FloorDiv>();
+    op = ret.as<FloorDivNode>();
     if (op == nullptr) return ret;
     int shift;
     const DataType& dtype = op->dtype;
@@ -104,7 +104,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
           // equivalent to rdiv + (rmod >= 0 ? 0: -1);
           return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
         } else {
-          return ir::Select::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
+          return ir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1));
         }
       }
     } else {
@@ -114,15 +114,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
       // b < 0  => (rmod <= 0 ? rdiv : rdiv - 1)
       Expr rdiv = truncdiv(op->a, op->b);
       Expr rmod = truncmod(op->a, op->b);
-      return ir::Select::make(
+      return ir::SelectNode::make(
           (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
           rdiv, rdiv - make_const(dtype, 1));
     }
   }
 
-  Expr VisitExpr_(const FloorMod* op) final {
+  Expr VisitExpr_(const FloorModNode* op) final {
     Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
-    op = ret.as<FloorMod>();
+    op = ret.as<FloorModNode>();
     if (op == nullptr) return ret;
     // Lower floordiv to native truncdiv.
     int shift;
@@ -153,7 +153,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
           // -> rmod >= 0 ? 0 : b
           return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1)));
         } else {
-          return ir::Select::make(rmod >= 0, rmod, rmod + op->b);
+          return ir::SelectNode::make(rmod >= 0, rmod, rmod + op->b);
         }
       }
     } else {
@@ -164,13 +164,13 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
       // b > 0 && rmod < 0  -> rmod + b
       // b < 0 && rmod < 0 -> rmod
       // b < 0 && rmod > 0 -> rmod + b
-      return ir::Select::make(
+      return ir::SelectNode::make(
           (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
           rmod, rmod + op->b);
     }
   }
 
-  Expr VisitExpr_(const Max* op) final {
+  Expr VisitExpr_(const MaxNode* op) final {
     using namespace arith;
     PVar<Expr> x, y;
     PVar<Integer> c;
@@ -183,7 +183,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const EQ* op) final {
+  Expr VisitExpr_(const EQNode* op) final {
     using namespace arith;
     PVar<Expr> x, y;
     auto e = GetRef<Expr>(op);
@@ -193,7 +193,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const NE* op) final {
+  Expr VisitExpr_(const NENode* op) final {
     using namespace arith;
     PVar<Expr> x, y;
     auto e = GetRef<Expr>(op);
@@ -209,8 +209,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
     // For some targets, LLVM will generate more efficient FMA
     // instruction with the latter. For example, vmla vs. vmlal
     // on ARM.
-    if (const Broadcast* bcast = e.as<Broadcast>()) {
-      if (const Cast* cast = bcast->value.as<Cast>()) {
+    if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
+      if (const CastNode* cast = bcast->value.as<CastNode>()) {
         auto should_swap = [&]() {
           // Maintain behaviour (int8 -> int16, fp16 -> fp32).
           if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
@@ -228,8 +228,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
         };
 
         if (should_swap()) {
-          Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
-          return Cast::make(bcast->dtype, new_bcast);
+          Expr new_bcast = BroadcastNode::make(cast->value, bcast->lanes);
+          return CastNode::make(bcast->dtype, new_bcast);
         }
       }
     }
@@ -237,19 +237,19 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
   }
 
   Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
-               const Add* op) {
+               const AddNode* op) {
     // emit fma instruction: a * b + c
     Expr lhs = SwapBroadcastCast(a);
     Expr rhs = SwapBroadcastCast(b);
 
     if (fma_ != nullptr && op->dtype.is_float()) {
-      Expr r = (*fma_)(Call::make(
-          op->dtype, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
+      Expr r = (*fma_)(CallNode::make(
+          op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic));
       if (r.defined()) return this->VisitExpr(r);
     } else {
       if (!lhs.same_as(a) || !rhs.same_as(b)) {
-        Expr mul = this->VisitExpr(Mul::make(lhs, rhs));
-        return Add::make(mul, this->VisitExpr(c));
+        Expr mul = this->VisitExpr(MulNode::make(lhs, rhs));
+        return AddNode::make(mul, this->VisitExpr(c));
       }
     }
     return IRMutatorWithAnalyzer::VisitExpr_(op);
index 4712bcc..d38d1da 100644 (file)
@@ -37,7 +37,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   explicit ThreadAllreduceBuilder(int warp_size)
       : warp_size_(warp_size) {}
 
-  Stmt VisitStmt_(const AttrStmt *op) final {
+  Stmt VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == attr::thread_extent) {
       thread_extents_.push_back(op);
       Stmt ret = StmtExprMutator::VisitStmt_(op);
@@ -45,8 +45,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return ret;
     } else if (op->attr_key == attr::storage_scope) {
       Stmt ret = StmtExprMutator::VisitStmt_(op);
-      op = ret.as<AttrStmt>();
-      const Variable* v = op->node.as<Variable>();
+      op = ret.as<AttrStmtNode>();
+      const VarNode* v = op->node.as<VarNode>();
       if (alloc_remap_.count(v)) {
         return op->body;
       } else {
@@ -63,37 +63,37 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Stmt VisitStmt_(const Evaluate* op) final {
+  Stmt VisitStmt_(const EvaluateNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Evaluate>();
-    const Call* call = op->value.as<Call>();
+    op = stmt.as<EvaluateNode>();
+    const CallNode* call = op->value.as<CallNode>();
     if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
       return MakeAllreduce(call);
     } else {
       return stmt;
     }
   }
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Allocate>();
+    op = stmt.as<AllocateNode>();
     auto it = alloc_remap_.find(op->buffer_var.get());
     if (it != alloc_remap_.end()) {
-      const Allocate* repl = it->second.as<Allocate>();
+      const AllocateNode* repl = it->second.as<AllocateNode>();
       // use volatile access to shared buffer.
-      stmt = AttrStmt::make(
+      stmt = AttrStmtNode::make(
           repl->buffer_var, attr::volatile_scope, 1, op->body);
-      stmt = Allocate::make(
+      stmt = AllocateNode::make(
           repl->buffer_var, repl->dtype,
           repl->extents, repl->condition, stmt);
-      stmt = AttrStmt::make(
+      stmt = AttrStmtNode::make(
           repl->buffer_var, attr::storage_scope,
-          StringImm::make("shared"), stmt);
+          StringImmNode::make("shared"), stmt);
       return stmt;
     } else {
       return stmt;
     }
   }
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     auto it = load_remap_.find(op->buffer_var.get());
     if (it != load_remap_.end()) {
       CHECK(is_zero(op->index));
@@ -115,12 +115,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
   };
   // make allreduce.
-  Stmt MakeAllreduce(const Call* call) {
+  Stmt MakeAllreduce(const CallNode* call) {
     CHECK(!reduce_combiner_.empty());
     const CommReducerNode *combiner = reduce_combiner_.back();
     size_t size = combiner->result.size();
 
-    const UIntImm *size_of_args = call->args[0].as<UIntImm>();
+    const UIntImmNode *size_of_args = call->args[0].as<UIntImmNode>();
     CHECK(size_of_args) << call->args[0]->GetTypeKey();
     CHECK_EQ(size, size_of_args->value);
     Array<Expr> inits = combiner->identity_element;
@@ -130,26 +130,26 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (size_t idx = 0; idx < size; ++idx) {
       values[idx] = call->args[1+idx];
       if (!is_one(cond)) {
-        values[idx] = Select::make(cond, values[idx], inits[idx]);
+        values[idx] = SelectNode::make(cond, values[idx], inits[idx]);
       }
       types[idx] = values[idx].dtype();
     }
-    std::vector<const Variable*> buffers(size);
+    std::vector<const VarNode*> buffers(size);
     for (size_t idx = 0; idx < size; ++idx) {
-      const Variable* buffer = call->args[2+size+idx].as<Variable>();
+      const VarNode* buffer = call->args[2+size+idx].as<VarNode>();
       CHECK(buffer);
       buffers[idx] = buffer;
     }
 
-    std::unordered_set<const Variable*> reduce_set;
+    std::unordered_set<const VarNode*> reduce_set;
     for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
-      const Variable* v = call->args[i].as<Variable>();
+      const VarNode* v = call->args[i].as<VarNode>();
       CHECK(v);
       reduce_set.insert(v);
     }
     size_t nmatch = 0;
     std::vector<ThreadEntry> vred, vpar;
-    for (const AttrStmt* attr : thread_extents_) {
+    for (const AttrStmtNode* attr : thread_extents_) {
       ThreadEntry e;
       IterVar iv = Downcast<IterVar>(attr->node);
       e.scope = runtime::ThreadScope::make(iv->thread_tag);
@@ -183,7 +183,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       for (size_t i = 0; i < size; ++i) {
         Expr pred = const_true(types[i].lanes());
         Var buffer_var = Downcast<Var>(call->args[2+size+i]);
-        stores[i] = Store::make(buffer_var, values[i], 0, pred);
+        stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
       }
       return SeqStmt::Flatten(stores);
     }
@@ -199,7 +199,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (size_t idx = 0; idx < size; ++idx) {
       shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
       Expr pred = const_true(types[idx].lanes());
-      seq.emplace_back(Store::make(
+      seq.emplace_back(StoreNode::make(
           shared_bufs[idx], values[idx],
           BufIndex(reduce_index, group_index, reduce_extent), pred));
     }
@@ -210,13 +210,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (size_t idx = 0; idx < size; ++idx) {
       CHECK(!load_remap_.count(buffers[idx]));
       Expr pred = const_true(types[idx].lanes());
-      load_remap_[buffers[idx]] = Load::make(
+      load_remap_[buffers[idx]] = LoadNode::make(
         types[idx], shared_bufs[idx],
         BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
-      alloc_remap_[buffers[idx]] = Allocate::make(
+      alloc_remap_[buffers[idx]] = AllocateNode::make(
         shared_bufs[idx], types[idx],
         {Expr(group_extent), Expr(reduce_extent)},
-        pred, Evaluate::make(0));
+        pred, EvaluateNode::make(0));
     }
     return SeqStmt::Flatten(seq);
   }
@@ -242,15 +242,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     auto freduce = [&](int offset) {
       Array<Expr> a, b;
       for (size_t i = 0; i < size; ++i) {
-        b.push_back(Load::make(types[i], shared_bufs[i],
+        b.push_back(LoadNode::make(types[i], shared_bufs[i],
           BufIndex(reduce_index + offset, group_index, reduce_extent),
           const_true()));
-        a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true()));
+        a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true()));
       }
       Array<Expr> ret = (*combiner)(a, b);
       std::vector<Stmt> stores(size);
       for (size_t i = 0; i < size; ++i) {
-        stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true());
+        stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true());
       }
       return SeqStmt::Flatten(stores);
     };
@@ -259,7 +259,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // reduction with the boundary condition
       reduce_align = reduce_align >> 1;
       Expr cond = reduce_index < (reduce_extent - reduce_align);
-      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+      seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
     CHECK(threadx_extent >= 1 && warp_size_ >= 1);
@@ -268,7 +268,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
            reduce_align > warp_size_) {
       reduce_align =  reduce_align >> 1;
       Expr cond = reduce_index < reduce_align;
-      seq.emplace_back(IfThenElse::make(cond, freduce(reduce_align)));
+      seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
     // in warp synchronization.
@@ -281,7 +281,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
     if (in_warp_seq.size() != 0) {
       Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
-      seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body));
+      seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body));
       seq.emplace_back(SyncThread("shared"));
     }
     return SeqStmt::Flatten(seq);
@@ -310,10 +310,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
   // sync thread op.
   static Stmt SyncThread(const std::string& sync) {
-    return Evaluate::make(
-        Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
-                   {StringImm::make(sync)},
-                   Call::Intrinsic));
+    return EvaluateNode::make(
+        CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+                   {StringImmNode::make(sync)},
+                   CallNode::Intrinsic));
   }
   // The local buffer index.
   static Expr BufIndex(Expr reduce_index, Expr group_index, int reduce_extent) {
@@ -327,12 +327,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   int warp_size_{1};
 
   // surrounding scope of thread extent.
-  std::vector<const AttrStmt*> thread_extents_;
+  std::vector<const AttrStmtNode*> thread_extents_;
   std::vector<const CommReducerNode*> reduce_combiner_;
   // The load remap
-  std::unordered_map<const Variable *, Expr> load_remap_;
+  std::unordered_map<const VarNode *, Expr> load_remap_;
   // Allocate remap
-  std::unordered_map<const Variable *, Stmt> alloc_remap_;
+  std::unordered_map<const VarNode *, Stmt> alloc_remap_;
 };
 
 LoweredFunc
index c0b9879..a9b401f 100644 (file)
@@ -37,8 +37,11 @@ inline Expr ConstInt32(size_t index) {
 }
 
 inline Expr StackAlloca(std::string type, size_t num) {
-  Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
-  return Call::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
+  Array<Expr> args = {StringImmNode::make(type), ConstInt32(num)};
+  return CallNode::make(
+      DataType::Handle(),
+      intrinsic::tvm_stack_alloca,
+      args, CallNode::Intrinsic);
 }
 
 // Calculate the statistics of packed function.
@@ -52,17 +55,17 @@ class BuiltinLower : public StmtExprMutator {
     stack_tcode_ = Var("stack_tcode", DataType::Handle());
     stmt = this->VisitStmt(stmt);
     if (max_shape_stack_ != 0) {
-      stmt = LetStmt::make(
+      stmt = LetStmtNode::make(
           stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
     }
     if (max_array_stack_ != 0) {
-      stmt = LetStmt::make(
+      stmt = LetStmtNode::make(
           stack_array_, StackAlloca("array", max_array_stack_), stmt);
     }
     if (max_arg_stack_ != 0) {
-      stmt = LetStmt::make(
+      stmt = LetStmtNode::make(
           stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
-      stmt = LetStmt::make(
+      stmt = LetStmtNode::make(
           stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
     }
     return stmt;
@@ -82,10 +85,10 @@ class BuiltinLower : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Allocate* op) {
+  Stmt VisitStmt_(const AllocateNode* op) {
     // Lower allocate to device allocate when needed.
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Allocate>();
+    op = stmt.as<AllocateNode>();
     if (op->new_expr.defined()) return stmt;
     // Get constant allocation bound.
     int64_t dev_type;
@@ -106,45 +109,48 @@ class BuiltinLower : public StmtExprMutator {
     }
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
-    Stmt throw_last_error = Evaluate::make(Call::make(DataType::Int(32),
-                                           intrinsic::tvm_throw_last_error, {},
-                                           Call::Intrinsic));
+    Stmt throw_last_error = EvaluateNode::make(
+        CallNode::make(DataType::Int(32),
+                       intrinsic::tvm_throw_last_error, {},
+                       CallNode::Intrinsic));
 
     Stmt body = SeqStmt({
-        IfThenElse::make(Call::make(DataType::Bool(1),
-                                    intrinsic::tvm_handle_is_null,
-                                    {op->buffer_var}, Call::PureIntrinsic),
-                         throw_last_error),
+        IfThenElseNode::make(
+            CallNode::make(DataType::Bool(1),
+                           intrinsic::tvm_handle_is_null,
+                           {op->buffer_var}, CallNode::PureIntrinsic),
+            throw_last_error),
         op->body});
 
-    Stmt alloca = LetStmt::make(
+    Stmt alloca = LetStmtNode::make(
         op->buffer_var,
-        Call::make(op->buffer_var.dtype(),
-                   "TVMBackendAllocWorkspace",
-                   {cast(DataType::Int(32), device_type_),
-                    cast(DataType::Int(32), device_id_),
-                    cast(DataType::UInt(64), total_bytes),
-                    IntImm::make(DataType::Int(32), op->dtype.code()),
-                    IntImm::make(DataType::Int(32), op->dtype.bits())},
-                   Call::Extern),
+        CallNode::make(op->buffer_var.dtype(),
+                       "TVMBackendAllocWorkspace",
+                       {cast(DataType::Int(32), device_type_),
+                        cast(DataType::Int(32), device_id_),
+                        cast(DataType::UInt(64), total_bytes),
+                        IntImmNode::make(DataType::Int(32), op->dtype.code()),
+                        IntImmNode::make(DataType::Int(32), op->dtype.bits())},
+                       CallNode::Extern),
         body);
 
-    Expr free_op = Call::make(DataType::Int(32),
-                              "TVMBackendFreeWorkspace",
-                              {cast(DataType::Int(32), device_type_),
-                                    cast(DataType::Int(32), device_id_),
-                                    op->buffer_var},
-                              Call::Extern);
-    Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
+    Expr free_op = CallNode::make(DataType::Int(32),
+                                  "TVMBackendFreeWorkspace",
+                                  {cast(DataType::Int(32), device_type_),
+                                   cast(DataType::Int(32), device_id_),
+                                   op->buffer_var},
+                                  CallNode::Extern);
+    Stmt free_stmt = IfThenElseNode::make(
+        free_op != make_zero(DataType::Int(32)), throw_last_error);
     body = SeqStmt({alloca, free_stmt});
-    body = AttrStmt::make(
+    body = AttrStmtNode::make(
         op->buffer_var, attr::storage_alignment,
         make_const(DataType::Int(32), runtime::kTempAllocaAlignment),
         body);
     return body;
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::device_context_id) {
       CHECK(!device_id_.defined());
       device_id_ = op->value;
@@ -157,7 +163,7 @@ class BuiltinLower : public StmtExprMutator {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
       return MakeCallPacked(op);
     } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
@@ -173,24 +179,24 @@ class BuiltinLower : public StmtExprMutator {
     }
   }
   // call shape
-  Expr MakeShape(const Call* op) {
+  Expr MakeShape(const CallNode* op) {
     size_t stack_begin = run_shape_stack_;
     run_shape_stack_ += op->args.size();
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
+    op = expr.as<CallNode>();
     for (size_t i = 0; i < op->args.size(); ++i) {
       prep_seq_.emplace_back(
-          Store::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
+          StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
                       ConstInt32(stack_begin +i), const_true(1)));
     }
     return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
   }
   // make array
-  Expr MakeArray(const Call* op) {
+  Expr MakeArray(const CallNode* op) {
     size_t idx = run_array_stack_;
     run_array_stack_ += 1;
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
+    op = expr.as<CallNode>();
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
     prep_seq_.emplace_back(
@@ -233,32 +239,32 @@ class BuiltinLower : public StmtExprMutator {
     return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr);
   }
   // call packed.
-  Expr MakeCallPacked(const Call* op) {
+  Expr MakeCallPacked(const CallNode* op) {
     size_t restore_shape_stack = run_shape_stack_;
     size_t restore_array_stack = run_array_stack_;
     size_t arg_stack_begin = run_arg_stack_;
     run_arg_stack_ += op->args.size();
     // Specially handle the buffer packed intrinsic
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
+    op = expr.as<CallNode>();
     for (size_t i = 1; i < op->args.size(); ++i) {
       Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
       Expr arg = op->args[i];
       DataType t = arg.dtype();
       DataType api_type = APIType(t);
       if (t != api_type) {
-        arg = Cast::make(api_type, arg);
+        arg = CastNode::make(api_type, arg);
       }
       prep_seq_.emplace_back(TVMStructSet(
           stack_value_, static_cast<int>(arg_stack_begin + i - 1),
           intrinsic::kTVMValueContent, arg));
       int arg_tcode = api_type.code();
-      if (api_type.is_handle() && arg.as<StringImm>()) {
+      if (api_type.is_handle() && arg.as<StringImmNode>()) {
         arg_tcode = kStr;
       }
       if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
       prep_seq_.emplace_back(
-          Store::make(stack_tcode_,
+          StoreNode::make(stack_tcode_,
                       ConstInt32(arg_tcode),
                       stack_index, const_true(1)));
     }
@@ -276,12 +282,12 @@ class BuiltinLower : public StmtExprMutator {
       ConstInt32(arg_stack_begin),
       ConstInt32(arg_stack_begin + op->args.size() - 1)
     };
-    return Call::make(
+    return CallNode::make(
         DataType::Int(32), intrinsic::tvm_call_packed_lowered,
-        packed_args, Call::Intrinsic);
+        packed_args, CallNode::Intrinsic);
   }
 
-  Expr MakeCallTracePacked(const Call *op) {
+  Expr MakeCallTracePacked(const CallNode *op) {
     size_t restore_shape_stack = run_shape_stack_;
     size_t restore_array_stack = run_array_stack_;
     size_t arg_stack_begin = run_arg_stack_;
@@ -289,14 +295,14 @@ class BuiltinLower : public StmtExprMutator {
     size_t args_size = op->args.size();
     CHECK_GT(args_size, 0);
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
+    op = expr.as<CallNode>();
     for (size_t i = 1; i < op->args.size(); ++i) {
       Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
       Expr arg = op->args[i];
       DataType t = arg.dtype();
       DataType api_type = APIType(t);
       if (t != api_type) {
-        arg = Cast::make(api_type, arg);
+        arg = CastNode::make(api_type, arg);
       }
       prep_seq_.emplace_back(TVMStructSet(
           stack_value_, static_cast<int>(arg_stack_begin + i - 1),
@@ -304,7 +310,7 @@ class BuiltinLower : public StmtExprMutator {
       int arg_tcode = api_type.code();
       CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
       prep_seq_.emplace_back(
-          Store::make(stack_tcode_,
+          StoreNode::make(stack_tcode_,
                       ConstInt32(arg_tcode),
                       stack_index, const_true(1)));
     }
@@ -326,17 +332,17 @@ class BuiltinLower : public StmtExprMutator {
       // Pass traced value.
       op->args[args_size - 1]
     };
-    return Call::make(
+    return CallNode::make(
         op->dtype, intrinsic::tvm_call_trace_packed_lowered,
-        packed_args, Call::Intrinsic);
+        packed_args, CallNode::Intrinsic);
   }
 
  private:
   bool IsArrayHandle(const Expr& arg) {
     // specially set array handle.
-    if (const Call* buf = arg.as<Call>()) {
+    if (const CallNode* buf = arg.as<CallNode>()) {
       if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
-          buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
+          buf->args[2].as<IntImmNode>()->value == intrinsic::kArrAddr) {
         return true;
       }
     }
index 2d24ec4..75f128e 100644 (file)
@@ -76,7 +76,7 @@ namespace ir {
 // store warp_mem[m * warp_index + (warp_size * m) * y + x]
 class WarpStoreCoeffFinder : private StmtVisitor {
  public:
-  WarpStoreCoeffFinder(const Variable* buffer,
+  WarpStoreCoeffFinder(const VarNode* buffer,
                        Var warp_index,
                        arith::Analyzer* analyzer)
       : buffer_(buffer),
@@ -91,7 +91,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
 
  private:
   /// Visitor implementation
-  void VisitStmt_(const Store *op) final {
+  void VisitStmt_(const StoreNode *op) final {
     if (op->buffer_var.get() == buffer_) {
       if (op->value.dtype().lanes() == 1) {
         UpdatePattern(op->index);
@@ -129,7 +129,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
   }
 
   // The buffer variable
-  const Variable* buffer_;
+  const VarNode* buffer_;
   // the warp index
   Var warp_index_;
   // the coefficient
@@ -155,7 +155,7 @@ class WarpIndexFinder : private StmtVisitor {
 
  private:
   /// Visitor implementation
-  void VisitStmt_(const AttrStmt *op) final {
+  void VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       if (iv->thread_tag == "threadIdx.x") {
@@ -190,7 +190,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
       : warp_size_(warp_size), analyzer_(analyzer) {}
   // Rewrite the allocate statement which transforms
   // warp memory to local memory.
-  Stmt Rewrite(const Allocate* op) {
+  Stmt Rewrite(const AllocateNode* op) {
     buffer_ = op->buffer_var.get();
     int alloc_size = op->constant_allocation_size();
     CHECK_GT(alloc_size, 0)
@@ -202,7 +202,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
     CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
         << "Warp memory must be multiple of warp size";
     warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
-    return Allocate::make(
+    return AllocateNode::make(
         op->buffer_var,
         op->dtype,
         {make_const(DataType::Int(32), alloc_size / warp_size_)},
@@ -211,23 +211,23 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
  protected:
-  Expr Mutate_(const Variable* op) {
+  Expr Mutate_(const VarNode* op) {
     CHECK(op != buffer_)
         << "Cannot access address of warp memory directly";
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const Store* op) {
+  Stmt VisitStmt_(const StoreNode* op) {
     if (op->buffer_var.get() == buffer_) {
       Expr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
-      return Store::make(op->buffer_var, op->value, local_index, op->predicate);
+      return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
 
-  Expr Mutate_(const Load* op) {
+  Expr Mutate_(const LoadNode* op) {
     if (op->buffer_var.get() == buffer_) {
       Expr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
@@ -235,12 +235,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
       CHECK(!ExprUseVar(local_index, {warp_index_.get()}))
           << "LowerWarpMemory failed to rewrite load to shuffle for index "
           << op->index << " local_index=" << local_index;
-      Expr load_value = Load::make(
+      Expr load_value = LoadNode::make(
           op->dtype, op->buffer_var, local_index, op->predicate);
-      return Call::make(load_value.dtype(),
+      return CallNode::make(load_value.dtype(),
                         intrinsic::tvm_warp_shuffle,
                         {load_value, group},
-                        Call::Intrinsic);
+                        CallNode::Intrinsic);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
@@ -256,7 +256,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
       CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
       std::tie(local_index, group) = SplitIndexByGroup(base);
       local_index =
-          Ramp::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
+          RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
       return std::make_pair(local_index, group);
     }
     Expr m = make_const(index.dtype(), warp_coeff_);
@@ -281,7 +281,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
   // the warp size
   int warp_size_{0};
   // The buffer variable
-  const Variable* buffer_;
+  const VarNode* buffer_;
   // Warp index
   Var warp_index_;
   // the coefficient m
@@ -301,13 +301,13 @@ class BindVarBoundInfo : public StmtVisitor {
   explicit BindVarBoundInfo(arith::Analyzer* analyzer)
       : analyzer_(analyzer) {}
 
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     const Var& loop_var = op->loop_var;
     analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
     StmtVisitor::VisitStmt_(op);
   }
 
-  void VisitStmt_(const AttrStmt* op) {
+  void VisitStmt_(const AttrStmtNode* op) {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::virtual_thread) {
       IterVar iv = Downcast<IterVar>(op->node);
@@ -325,7 +325,7 @@ class BindVarBoundInfo : public StmtVisitor {
   // internal analyzer.
   arith::Analyzer* analyzer_;
   // variable domain
-  std::unordered_map<const Variable*, Range> var_dom_;
+  std::unordered_map<const VarNode*, Range> var_dom_;
 };
 
 // Mutator to change the read pattern
@@ -345,7 +345,7 @@ class WarpMemoryRewriter : private StmtMutator {
   }
 
  private:
-  Stmt VisitStmt_(const Allocate* op) {
+  Stmt VisitStmt_(const AllocateNode* op) {
     if (warp_buffer_.count(op->buffer_var.get())) {
       WarpAccessRewriter rewriter(warp_size_, &analyzer_);
       return rewriter.Rewrite(op);
@@ -354,27 +354,27 @@ class WarpMemoryRewriter : private StmtMutator {
     }
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) {
+  Stmt VisitStmt_(const AttrStmtNode* op) {
     using runtime::StorageScope;
     if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
-      StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
+      const VarNode* buf = op->node.as<VarNode>();
+      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
       if (scope.rank == runtime::StorageRank::kWarp) {
         warp_buffer_.insert(buf);
         Stmt ret = StmtMutator::VisitStmt_(op);
-        op = ret.as<AttrStmt>();
-        return AttrStmt::make(
-            op->node, op->attr_key, StringImm::make("local"), op->body);
+        op = ret.as<AttrStmtNode>();
+        return AttrStmtNode::make(
+            op->node, op->attr_key, StringImmNode::make("local"), op->body);
       }
     }
     return StmtMutator::VisitStmt_(op);
   }
 
   int warp_size_{0};
-  std::unordered_set<const Variable*> warp_buffer_;
+  std::unordered_set<const VarNode*> warp_buffer_;
   arith::Analyzer analyzer_;
   // variable domain
-  std::unordered_map<const Variable*, Range> var_dom_;
+  std::unordered_map<const VarNode*, Range> var_dom_;
 };
 
 LoweredFunc
index 03a6035..56609bb 100644 (file)
@@ -36,7 +36,7 @@ namespace tvm {
 namespace ir {
 
 inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
-  return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
+  return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
 }
 
 LoweredFunc MakeAPI(Stmt body,
@@ -44,7 +44,7 @@ LoweredFunc MakeAPI(Stmt body,
                     Array<ObjectRef> api_args,
                     int num_unpacked_args,
                     bool is_restricted) {
-  const Stmt nop = Evaluate::make(0);
+  const Stmt nop = EvaluateNode::make(0);
   int num_args = static_cast<int>(api_args.size());
   CHECK_LE(num_unpacked_args, num_args);
   int num_packed_args = num_args - num_unpacked_args;
@@ -62,23 +62,23 @@ LoweredFunc MakeAPI(Stmt body,
   // seq_init gives sequence of initialization
   // seq_check gives sequence of later checks after init
   std::vector<Stmt> seq_init, seq_check;
-  std::unordered_map<const Variable*, Expr> vmap;
+  std::unordered_map<const VarNode*, Expr> vmap;
   ArgBinder binder(&vmap);
   // ---------------------------
   // local function definitions
   // load i-th argument as type t
   auto f_arg_value = [&](DataType t, int i) {
     Array<Expr> call_args{v_packed_args,
-                          IntImm::make(DataType::Int(32), i),
-                          IntImm::make(DataType::Int(32), intrinsic::kTVMValueContent)};
+                          IntImmNode::make(DataType::Int(32), i),
+                          IntImmNode::make(DataType::Int(32), intrinsic::kTVMValueContent)};
     // load 64 bit version
     DataType api_type = APIType(t);
-    Expr res = Call::make(
+    Expr res = CallNode::make(
         api_type, intrinsic::tvm_struct_get, call_args,
-        Call::PureIntrinsic);
+        CallNode::PureIntrinsic);
     // cast to the target version.
     if (api_type != t) {
-      res = Cast::make(t, res);
+      res = CastNode::make(t, res);
     }
     return res;
   };
@@ -86,7 +86,7 @@ LoweredFunc MakeAPI(Stmt body,
   auto f_arg_decl = [&](int i) {
     std::ostringstream os;
     os << "arg" << i;
-    const Variable* v = api_args[i].as<Variable>();
+    const VarNode* v = api_args[i].as<VarNode>();
     return Var(os.str(), v ? v->dtype: DataType::Handle());
   };
   // ---------------------------
@@ -110,40 +110,40 @@ LoweredFunc MakeAPI(Stmt body,
     Var v_arg = f_arg_decl(i);
     if (i < num_packed_args) {
       // Value loads
-      seq_init.emplace_back(LetStmt::make(
+      seq_init.emplace_back(LetStmtNode::make(
           v_arg, f_arg_value(v_arg.dtype(), i), nop));
       // type code checks
       Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
-      seq_init.emplace_back(LetStmt::make(
-          tcode, Load::make(
+      seq_init.emplace_back(LetStmtNode::make(
+          tcode, LoadNode::make(
               DataType::Int(32), v_packed_arg_type_ids,
-              IntImm::make(DataType::Int(32), i), const_true(1)),
+              IntImmNode::make(DataType::Int(32), i), const_true(1)),
           nop));
       DataType t = v_arg.dtype();
       if (t.is_handle()) {
         std::ostringstream msg;
         msg << name << ": Expect arg[" << i << "] to be pointer";
         seq_check.emplace_back(
-            AssertStmt::make(tcode == kHandle ||
+            AssertStmtNode::make(tcode == kHandle ||
                              tcode == kNDArrayContainer ||
                              tcode == kArrayHandle ||
                              tcode == kNull, msg.str(), nop));
       } else if (t.is_int() || t.is_uint()) {
         std::ostringstream msg;
         msg << name << ": Expect arg[" << i << "] to be int";
-        seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
+        seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
       } else {
         CHECK(t.is_float());
         std::ostringstream msg;
         msg << name << ": Expect arg[" << i << "] to be float";
         seq_check.emplace_back(
-            AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
+            AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
       }
     } else {
       args.push_back(v_arg);
     }
     // add checks for functions.
-    if (api_args[i].as<Variable>()) {
+    if (api_args[i].as<VarNode>()) {
       var_defs.emplace_back(std::make_pair(Downcast<Var>(api_args[i]), v_arg));
     } else {
       // Buffer checks
@@ -184,22 +184,22 @@ LoweredFunc MakeAPI(Stmt body,
   n->handle_data_type = binder.def_handle_dtype();
   n->is_packed_func = num_unpacked_args == 0;
   n->is_restricted = is_restricted;
-  body = AttrStmt::make(
+  body = AttrStmtNode::make(
       make_zero(DataType::Int(32)), attr::compute_scope,
-      StringImm::make(name + "_compute_"), body);
+      StringImmNode::make(name + "_compute_"), body);
   // Set device context
   if (vmap.count(device_id.get())) {
-    Expr node = StringImm::make("default");
+    Expr node = StringImmNode::make("default");
     CHECK(vmap.count(device_type.get()));
-    seq_check.push_back(AttrStmt::make(
+    seq_check.push_back(AttrStmtNode::make(
         node, attr::device_context_id, device_id, nop));
-    seq_check.push_back(AttrStmt::make(
+    seq_check.push_back(AttrStmtNode::make(
         node, attr::device_context_type, device_type, nop));
-    Stmt set_device = IfThenElse::make(
-        device_type != kDLCPU, Evaluate::make(Call::make(
+    Stmt set_device = IfThenElseNode::make(
+        device_type != kDLCPU, EvaluateNode::make(CallNode::make(
             DataType::Int(32), intrinsic::tvm_call_packed,
-            {StringImm::make(runtime::symbol::tvm_set_device),
-             device_type, device_id}, Call::Intrinsic)));
+            {StringImmNode::make(runtime::symbol::tvm_set_device),
+             device_type, device_id}, CallNode::Intrinsic)));
     body = SeqStmt({set_device, body});
   }
   n->body = MergeNest(
@@ -222,28 +222,28 @@ class DeviceTypeBinder: public StmtExprMutator {
   explicit DeviceTypeBinder(int device_type)
       : device_type_(device_type) {}
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::device_context_type) {
-      if (const Variable* var = op->value.as<Variable>()) {
+      if (const VarNode* var = op->value.as<VarNode>()) {
         var_ = var;
         Expr value = make_const(op->value.dtype(), device_type_);
         Stmt body = StmtExprMutator::VisitStmt_(op);
         var_ = nullptr;
         std::ostringstream os;
         os << "device_type need to be " << device_type_;
-        return AssertStmt::make(op->value == value, os.str(), body);
+        return AssertStmtNode::make(op->value == value, os.str(), body);
       }
     }
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const IfThenElse* op) final {
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
     // eager simplify if guard.
     Stmt res = StmtExprMutator::VisitStmt_(op);
-    op = res.as<IfThenElse>();
+    op = res.as<IfThenElseNode>();
     if (is_zero(op->condition)) {
       if (op->else_case.defined()) return op->else_case;
-      return Evaluate::make(0);
+      return EvaluateNode::make(0);
     }
     if (is_one(op->condition)) {
       return op->then_case;
@@ -251,17 +251,17 @@ class DeviceTypeBinder: public StmtExprMutator {
     return res;
   }
 
-  Expr VisitExpr_(const NE* op) final {
+  Expr VisitExpr_(const NENode* op) final {
     // eager check NE for device check
     Expr res = StmtExprMutator::VisitExpr_(op);
-    op = res.as<NE>();
+    op = res.as<NENode>();
     if (ir::Equal(op->a, op->b)) {
       return make_const(op->dtype, false);
     }
     return res;
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     if (op == var_) {
       return make_const(op->dtype, device_type_);
     } else {
@@ -270,7 +270,7 @@ class DeviceTypeBinder: public StmtExprMutator {
   }
 
  public:
-  const Variable* var_{nullptr};
+  const VarNode* var_{nullptr};
   int device_type_;
 };
 
index 92b941a..2a486b5 100644 (file)
@@ -42,28 +42,28 @@ class ThreadAxisRewriter : private StmtExprMutator {
   }
 
  private:
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       CHECK_NE(iv->thread_tag.length(), 0U);
       auto it = tmap_.find(iv->thread_tag);
       if (it != tmap_.end()) {
         const IterVar& new_iv = it->second;
-        const Variable* v = iv->var.get();
+        const VarNode* v = iv->var.get();
         if (!vmap_.count(v)) {
           vmap_[v] = new_iv->var;
         } else {
           CHECK(vmap_[v].same_as(new_iv->var));
         }
         Stmt body = this->VisitStmt(op->body);
-        return AttrStmt::make(
+        return AttrStmtNode::make(
             new_iv, op->attr_key, op->value, body);
       }
     }
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = vmap_.find(op);
     if (it != vmap_.end()) return it->second;
     return StmtExprMutator::VisitExpr_(op);
@@ -71,14 +71,14 @@ class ThreadAxisRewriter : private StmtExprMutator {
   // The thread map
   const std::unordered_map<std::string, IterVar>& tmap_;
   // variable map
-  std::unordered_map<const Variable*, Var> vmap_;
+  std::unordered_map<const VarNode*, Var> vmap_;
 };
 
 LoweredFunc
 RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
   std::unordered_map<std::string, IterVar> tmap;
   for (const auto& kv : thread_map) {
-    const StringImm* str = kv.first.as<StringImm>();
+    const StringImmNode* str = kv.first.as<StringImmNode>();
     CHECK(str != nullptr);
     tmap[str->value] = kv.second;
   }
index 6891870..3c9114d 100644 (file)
@@ -32,28 +32,28 @@ namespace ir {
 // Mark the statment of each stage.
 class NoOpRemover : public StmtMutator {
  public:
-  Stmt VisitStmt_(const LetStmt* op) final {
+  Stmt VisitStmt_(const LetStmtNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<LetStmt>();
+    op = stmt.as<LetStmtNode>();
     return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
   }
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_debug_skip_region") {
       return MakeEvaluate(0);
     }
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<AttrStmt>();
+    op = stmt.as<AttrStmtNode>();
     return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
   }
-  Stmt VisitStmt_(const IfThenElse* op) final {
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<IfThenElse>();
+    op = stmt.as<IfThenElseNode>();
     if (op->else_case.defined()) {
       if (is_no_op(op->else_case)) {
         if (is_no_op(op->then_case)) {
           return MakeEvaluate(op->condition);
         } else {
-          return IfThenElse::make(op->condition, op->then_case);
+          return IfThenElseNode::make(op->condition, op->then_case);
         }
       } else {
         return stmt;
@@ -66,32 +66,32 @@ class NoOpRemover : public StmtMutator {
       }
     }
   }
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<For>();
+    op = stmt.as<ForNode>();
     if (is_zero(op->extent)) {
-      return Evaluate::make(0);
+      return EvaluateNode::make(0);
     }
     return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
   }
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<Allocate>();
+    op = stmt.as<AllocateNode>();
     return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
   }
-  Stmt VisitStmt_(const ProducerConsumer* op) final {
+  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<ProducerConsumer>();
+    op = stmt.as<ProducerConsumerNode>();
     return is_no_op(op->body) ? op->body : stmt;
   }
-  Stmt VisitStmt_(const Realize* op) final {
+  Stmt VisitStmt_(const RealizeNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<Realize>();
+    op = stmt.as<RealizeNode>();
     return is_no_op(op->body) ? op->body : stmt;
   }
-  Stmt VisitStmt_(const Evaluate* op) final {
+  Stmt VisitStmt_(const EvaluateNode* op) final {
     if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
-    return Evaluate::make(0);
+    return EvaluateNode::make(0);
   }
 
   Stmt VisitStmt_(const SeqStmtNode* op) final {
@@ -128,9 +128,9 @@ class NoOpRemover : public StmtMutator {
  private:
   Stmt MakeEvaluate(Expr value) {
     if (HasSideEffect(value)) {
-      return Evaluate::make(value);
+      return EvaluateNode::make(value);
     } else {
-      return Evaluate::make(0);
+      return EvaluateNode::make(0);
     }
   }
   Stmt MakeEvaluate(const Array<Expr>& values) {
@@ -138,13 +138,13 @@ class NoOpRemover : public StmtMutator {
     for (Expr e : values) {
       if (HasSideEffect(e)) {
         if (stmt.defined()) {
-          stmt = SeqStmt({stmt, Evaluate::make(e)});
+          stmt = SeqStmt({stmt, EvaluateNode::make(e)});
         } else {
-          stmt = Evaluate::make(e);
+          stmt = EvaluateNode::make(e);
         }
       }
     }
-    return stmt.defined() ? stmt : Evaluate::make(0);
+    return stmt.defined() ? stmt : EvaluateNode::make(0);
   }
 };
 
index 0a27671..c38fac1 100644 (file)
@@ -35,14 +35,14 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
  public:
   // select itself is always considered safe if condition is safe
   // Because we will issue guard to make sure it is.
-  bool VisitExpr_(const Select* op) {
+  bool VisitExpr_(const SelectNode* op) {
     return VisitExpr(op->condition);
   }
-  bool VisitExpr_(const Call* op) {
+  bool VisitExpr_(const CallNode* op) {
     if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
       return VisitExpr(op->args[0]);
     } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-      const Load* l = op->args[0].as<Load>();
+      const LoadNode* l = op->args[0].as<LoadNode>();
       return this->VisitExpr(l->index);
     } else if (op->is_pure()) {
       for (Expr e : op->args) {
@@ -53,53 +53,53 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
       return true;
     }
   }
-  bool VisitExpr_(const Load* op) {
+  bool VisitExpr_(const LoadNode* op) {
     // Load is considered unsafe.
     return true;
   }
-  bool VisitExpr_(const Add* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Sub* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const NE* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const LT* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const LE* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const GT* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const GE* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const And* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Or* op) final { return BinaryOp(op); }
-  bool VisitExpr_(const Not* op) final {
+  bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const NENode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const LENode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const GENode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); }
+  bool VisitExpr_(const NotNode* op) final {
     return VisitExpr(op->a);
   }
-  bool VisitExpr_(const Let* op) final {
+  bool VisitExpr_(const LetNode* op) final {
     return VisitExpr(op->body) || VisitExpr(op->value);
   }
-  bool VisitExpr_(const Cast* op) final {
+  bool VisitExpr_(const CastNode* op) final {
     return VisitExpr(op->value);
   }
-  bool VisitExpr_(const Broadcast* op) final {
+  bool VisitExpr_(const BroadcastNode* op) final {
     return VisitExpr(op->value);
   }
-  bool VisitExpr_(const Ramp* op) final {
+  bool VisitExpr_(const RampNode* op) final {
     return VisitExpr(op->base) && VisitExpr(op->stride);
   }
-  bool VisitExpr_(const Shuffle* op) final {
+  bool VisitExpr_(const ShuffleNode* op) final {
     for (Expr e : op->vectors) {
       if (VisitExpr(e)) return true;
     }
     return false;
   }
-  bool VisitExpr_(const Variable* op) final { return false; }
-  bool VisitExpr_(const UIntImm* op) final { return false; }
-  bool VisitExpr_(const IntImm* op) final { return false; }
-  bool VisitExpr_(const FloatImm* op) final { return false; }
-  bool VisitExpr_(const StringImm* op) final { return false; }
+  bool VisitExpr_(const VarNode* op) final { return false; }
+  bool VisitExpr_(const UIntImmNode* op) final { return false; }
+  bool VisitExpr_(const IntImmNode* op) final { return false; }
+  bool VisitExpr_(const FloatImmNode* op) final { return false; }
+  bool VisitExpr_(const StringImmNode* op) final { return false; }
 
  private:
   template<typename T>
@@ -110,19 +110,19 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
 
 class UnsafeSelectRewriter : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const Select* op) {
+  Expr VisitExpr_(const SelectNode* op) {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Select>();
+    op = expr.as<SelectNode>();
     UnsafeExprDetector unsafe;
     bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar();
     if ((unsafe.VisitExpr(op->true_value) ||
         unsafe.VisitExpr(op->false_value)) &&
         cond_is_scalar_bool) {
-      return Call::make(
+      return CallNode::make(
           op->dtype,
           intrinsic::tvm_if_then_else,
           {op->condition, op->true_value, op->false_value},
-          Call::Intrinsic);
+          CallNode::Intrinsic);
     } else {
       return expr;
     }
index e9ed893..3233e50 100644 (file)
@@ -35,7 +35,7 @@ class IRSideEffect : public ExprVisitor {
     ExprVisitor::VisitExpr(e);
   }
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     if (!op->is_pure()) {
       has_side_effect_ = true; return;
     } else {
@@ -55,11 +55,11 @@ bool HasSideEffect(const Expr& e) {
 class IRSubstitue : public StmtExprMutator {
  public:
   explicit IRSubstitue(
-      const std::unordered_map<const Variable*, Expr>& smap)
+      const std::unordered_map<const VarNode*, Expr>& smap)
       : smap_(smap) {
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = smap_.find(op);
     if (it != smap_.end()) {
       return it->second;
@@ -69,23 +69,23 @@ class IRSubstitue : public StmtExprMutator {
   }
 
  private:
-  const std::unordered_map<const Variable*, Expr>& smap_;
+  const std::unordered_map<const VarNode*, Expr>& smap_;
 };
 
 Stmt Substitute(Stmt stmt,
-                const std::unordered_map<const Variable*, Expr>& value_map) {
+                const std::unordered_map<const VarNode*, Expr>& value_map) {
   if (value_map.size() == 0) return stmt;
   return IRSubstitue(value_map)(std::move(stmt));
 }
 
 Expr Substitute(Expr expr,
-                const std::unordered_map<const Variable*, Expr>& value_map) {
+                const std::unordered_map<const VarNode*, Expr>& value_map) {
   if (value_map.size() == 0) return expr;
   return IRSubstitue(value_map)(std::move(expr));
 }
 
 Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
-  std::unordered_map<const Variable*, Expr> vmap;
+  std::unordered_map<const VarNode*, Expr> vmap;
   for (const auto& kv : value_map) {
     vmap[kv.first.get()] = kv.second;
   }
@@ -93,7 +93,7 @@ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
 }
 
 Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
-  std::unordered_map<const Variable*, Expr> vmap;
+  std::unordered_map<const VarNode*, Expr> vmap;
   for (const auto& kv : value_map) {
     vmap[kv.first.get()] = kv.second;
   }
@@ -107,43 +107,43 @@ class VarTouchVisitor : public ExprVisitor {
     ExprVisitor::VisitExpr(e);
   }
 
-  void VisitExpr_(const Variable* op) final {
+  void VisitExpr_(const VarNode* op) final {
     Handle(op);
   }
 
-  void VisitExpr_(const Load* op) final {
+  void VisitExpr_(const LoadNode* op) final {
     Handle(op->buffer_var.get());
     ExprVisitor::VisitExpr_(op);
   }
 
-  virtual void Handle(const Variable* var) = 0;
+  virtual void Handle(const VarNode* var) = 0;
 
   bool use_var_{false};
 };
 
 class ExprUseVarVisitor : public VarTouchVisitor {
  public:
-  explicit ExprUseVarVisitor(const Variable* var)
+  explicit ExprUseVarVisitor(const VarNode* var)
       : var_(var) {}
 
-  void Handle(const Variable* var) final {
+  void Handle(const VarNode* var) final {
     if (var == var_) use_var_ = true;
   }
  private:
-  const Variable* var_;
+  const VarNode* var_;
 };
 
 class ExprUseVSetVisitor : public VarTouchVisitor {
  public:
   explicit ExprUseVSetVisitor(
-      const std::unordered_set<const Variable*>& vset)
+      const std::unordered_set<const VarNode*>& vset)
       : vset_(vset) {}
 
-  void Handle(const Variable* var) final {
+  void Handle(const VarNode* var) final {
     if (vset_.count(var)) use_var_ = true;
   }
  private:
-  const std::unordered_set<const Variable*>& vset_;
+  const std::unordered_set<const VarNode*>& vset_;
 };
 
 bool ExprUseVar(const Expr& e, const Var& v) {
@@ -153,7 +153,7 @@ bool ExprUseVar(const Expr& e, const Var& v) {
 }
 
 bool ExprUseVar(const Expr& e,
-                const std::unordered_set<const Variable*>& vset) {
+                const std::unordered_set<const VarNode*>& vset) {
   ExprUseVSetVisitor visitor(vset);
   visitor(e);
   return visitor.use_var_;
index 47a158f..d2c4dc7 100644 (file)
@@ -26,9 +26,9 @@ namespace ir {
 
 class AssertSkipper : public StmtMutator {
  public:
-  Stmt VisitStmt_(const AssertStmt* op) final {
+  Stmt VisitStmt_(const AssertStmtNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<AssertStmt>();
+    op = stmt.as<AssertStmtNode>();
     return op->body;
   }
 };
index 2a7c75e..f71f13b 100644 (file)
@@ -34,7 +34,7 @@ namespace ir {
 // use/def analysis, also delete unreferenced lets
 class IRUseDefAnalysis : public StmtExprMutator {
  public:
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
       CHECK_NE(iv->thread_tag.length(), 0U);
@@ -54,13 +54,13 @@ class IRUseDefAnalysis : public StmtExprMutator {
       if (value.same_as(op->value) && body.same_as(op->body)) {
         return GetRef<Stmt>(op);
       }
-      return AttrStmt::make(op->node, op->attr_key, value, body);
+      return AttrStmtNode::make(op->node, op->attr_key, value, body);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
 
-  Stmt VisitStmt_(const LetStmt* op) final {
+  Stmt VisitStmt_(const LetStmtNode* op) final {
     this->HandleDef(op->var.get());
     Stmt body = this->VisitStmt(op->body);
     // eliminate unreferenced let
@@ -73,27 +73,27 @@ class IRUseDefAnalysis : public StmtExprMutator {
           value.same_as(op->value)) {
         return GetRef<Stmt>(op);
       } else {
-        return LetStmt::make(op->var, value, body);
+        return LetStmtNode::make(op->var, value, body);
       }
     }
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     this->HandleDef(op->loop_var.get());
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     this->HandleDef(op->buffer_var.get());
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     this->HandleUse(op->buffer_var);
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Expr VisitExpr_(const Let* op) final {
+  Expr VisitExpr_(const LetNode* op) final {
     this->HandleDef(op->var.get());
     Expr body = this->VisitExpr(op->body);
     // eliminate unreferenced let
@@ -106,22 +106,22 @@ class IRUseDefAnalysis : public StmtExprMutator {
           value.same_as(op->value)) {
         return GetRef<Expr>(op);
       } else {
-        return Let::make(op->var, value, body);
+        return LetNode::make(op->var, value, body);
       }
     }
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     this->HandleUse(GetRef<Expr>(op));
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     this->HandleUse(op->buffer_var);
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  void HandleDef(const Variable* v) {
+  void HandleDef(const VarNode* v) {
     CHECK(!def_count_.count(v))
         << "variable " << v->name_hint
         << " has already been defined, the Stmt is not SSA";
@@ -133,7 +133,7 @@ class IRUseDefAnalysis : public StmtExprMutator {
   }
 
   void HandleUse(const Expr& v) {
-    CHECK(v.as<Variable>());
+    CHECK(v.as<VarNode>());
     Var var = Downcast<Var>(v);
     auto it = use_count_.find(var.get());
     if (it != use_count_.end()) {
@@ -152,18 +152,18 @@ class IRUseDefAnalysis : public StmtExprMutator {
   Array<Var> undefined_;
   Array<IterVar> thread_axis_;
   Array<Expr> thread_extent_;
-  std::unordered_map<const Variable*, int> use_count_;
-  std::unordered_map<const Variable*, int> def_count_;
+  std::unordered_map<const VarNode*, int> use_count_;
+  std::unordered_map<const VarNode*, int> def_count_;
 };
 
 class HostDeviceSplitter : public StmtMutator {
  public:
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
     return StmtMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent ||
         op->attr_key == attr::pipeline_exec_scope ||
         op->attr_key == attr::device_scope) {
@@ -219,7 +219,7 @@ class HostDeviceSplitter : public StmtMutator {
     }
     LoweredFunc f_device(n);
     Array<Expr> call_args;
-    call_args.push_back(StringImm::make(f_device->name));
+    call_args.push_back(StringImmNode::make(f_device->name));
     for (Var arg : n->args) {
       call_args.push_back(arg);
     }
@@ -227,16 +227,16 @@ class HostDeviceSplitter : public StmtMutator {
       call_args.push_back(ext);
     }
     device_funcs_.emplace_back(f_device);
-    return Evaluate::make(Call::make(
+    return EvaluateNode::make(CallNode::make(
         DataType::Int(32), intrinsic::tvm_call_packed,
-        call_args, Call::Intrinsic));
+        call_args, CallNode::Intrinsic));
   }
 
   // function name
   std::string name_;
   // the device functions
   std::vector<LoweredFunc> device_funcs_;
-  std::unordered_map<const Variable*, Expr> handle_data_type_;
+  std::unordered_map<const VarNode*, Expr> handle_data_type_;
 };
 
 
index 94f045d..3dafb40 100644 (file)
@@ -45,118 +45,118 @@ class IRVerifySSA final : public StmtExprVisitor {
     if (!is_ssa) return;
     StmtExprVisitor::VisitStmt(n);
   }
-  void VisitExpr_(const Let* op) final {
+  void VisitExpr_(const LetNode* op) final {
     MarkDef(op->var.get());
     StmtExprVisitor::VisitExpr_(op);
   }
-  void VisitStmt_(const LetStmt* op) final {
+  void VisitStmt_(const LetStmtNode* op) final {
     MarkDef(op->var.get());
     StmtExprVisitor::VisitStmt_(op);
   }
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     MarkDef(op->loop_var.get());
     StmtExprVisitor::VisitStmt_(op);
   }
-  void VisitStmt_(const Allocate* op) final {
+  void VisitStmt_(const AllocateNode* op) final {
     MarkDef(op->buffer_var.get());
     StmtExprVisitor::VisitStmt_(op);
   }
 
  private:
-  void MarkDef(const Variable* v) {
+  void MarkDef(const VarNode* v) {
     if (defined_.count(v) != 0) {
       is_ssa = false; return;
     } else {
       defined_[v] = 1;
     }
   }
-  std::unordered_map<const Variable*, int> defined_;
+  std::unordered_map<const VarNode*, int> defined_;
 };
 
 
 class IRConvertSSA final : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     if (scope_.count(op)) {
       return scope_[op].back();
     } else {
       return GetRef<Expr>(op);
     }
   }
-  Expr VisitExpr_(const Let* op) final {
+  Expr VisitExpr_(const LetNode* op) final {
     const VarExpr& v = op->var;
     if (defined_.count(v.get())) {
       Expr value = this->VisitExpr(op->value);
-      VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Expr body = this->VisitExpr(op->body);
       scope_[v.get()].pop_back();
-      return Let::make(new_var, value, body);
+      return LetNode::make(new_var, value, body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitExpr_(op);
     }
   }
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Load>();
+    op = expr.as<LoadNode>();
     if (scope_.count(op->buffer_var.get())) {
-      return Load::make(
+      return LoadNode::make(
           op->dtype, scope_[op->buffer_var.get()].back(),
           op->index, op->predicate);
     } else {
       return expr;
     }
   }
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Store>();
+    op = stmt.as<StoreNode>();
     if (scope_.count(op->buffer_var.get())) {
-      return Store::make(
+      return StoreNode::make(
           scope_[op->buffer_var.get()].back(), op->value,
           op->index, op->predicate);
     } else {
       return stmt;
     }
   }
-  Stmt VisitStmt_(const LetStmt* op) final {
+  Stmt VisitStmt_(const LetStmtNode* op) final {
     const VarExpr& v = op->var;
     if (defined_.count(v.get())) {
       Expr value = this->VisitExpr(op->value);
-      VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Stmt body = this->VisitStmt(op->body);
       scope_[v.get()].pop_back();
-      return LetStmt::make(new_var, value, body);
+      return LetStmtNode::make(new_var, value, body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     const VarExpr& v = op->loop_var;
     if (defined_.count(v.get())) {
-      VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
-      op = stmt.as<For>();
-      return For::make(
+      op = stmt.as<ForNode>();
+      return ForNode::make(
           new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     const VarExpr& v = op->buffer_var;
     if (defined_.count(v.get())) {
-      VarExpr new_var = Variable::make(v.dtype(), v->name_hint);
+      VarExpr new_var = VarNode::make(v.dtype(), v->name_hint);
       scope_[v.get()].push_back(new_var);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
-      op = stmt.as<Allocate>();
-      return Allocate::make(
+      op = stmt.as<AllocateNode>();
+      return AllocateNode::make(
           new_var, op->dtype, op->extents, op->condition,
           op->body, op->new_expr, op->free_function);
     } else {
@@ -164,23 +164,23 @@ class IRConvertSSA final : public StmtExprMutator {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Stmt VisitStmt_(const AttrStmt* op) final {
-    if (const Variable* v = op->node.as<Variable>()) {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (const VarNode* v = op->node.as<VarNode>()) {
       if (op->attr_key == attr::storage_scope) {
-        const Allocate* alloc = op->body.as<Allocate>();
+        const AllocateNode* alloc = op->body.as<AllocateNode>();
         if (alloc && op->node.same_as(alloc->buffer_var)) {
           Stmt new_alloc = this->VisitStmt(op->body);
           if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
-          alloc = new_alloc.as<Allocate>();
+          alloc = new_alloc.as<AllocateNode>();
           CHECK(alloc);
-          return AttrStmt::make(
+          return AttrStmtNode::make(
               alloc->buffer_var, op->attr_key, op->value, new_alloc);
         }
       }
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      op = stmt.as<AttrStmt>();
+      op = stmt.as<AttrStmtNode>();
       if (scope_.count(v) && scope_[v].size() != 0) {
-        return AttrStmt::make(
+        return AttrStmtNode::make(
             scope_[v].back(), op->attr_key, op->value, op->body);
       } else {
         return stmt;
@@ -191,8 +191,8 @@ class IRConvertSSA final : public StmtExprMutator {
   }
 
  private:
-  std::unordered_map<const Variable*, std::vector<VarExpr> > scope_;
-  std::unordered_set<const Variable*> defined_;
+  std::unordered_map<const VarNode*, std::vector<VarExpr> > scope_;
+  std::unordered_set<const VarNode*> defined_;
 };
 
 }  // namespace
index 0d59404..cb779f9 100644 (file)
@@ -31,8 +31,8 @@
 namespace tvm {
 namespace ir {
 
-void StorageAccessVisitor::VisitExpr_(const Load* op) {
-  const Variable* buf = op->buffer_var.as<Variable>();
+void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
+  const VarNode* buf = op->buffer_var.as<VarNode>();
   StorageScope scope = GetScope(buf);
   if (Enabled(buf, scope)) {
     CHECK(allow_append_);
@@ -49,11 +49,11 @@ void StorageAccessVisitor::VisitExpr_(const Load* op) {
   StmtExprVisitor::VisitExpr_(op);
 }
 
-void StorageAccessVisitor::VisitStmt_(const Store* op) {
+void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
   allow_append_ = true;
   CHECK_EQ(curr_stmt_.access.size(), 0U);
   curr_stmt_.stmt = op;
-  const Variable* buf = op->buffer_var.as<Variable>();
+  const VarNode* buf = op->buffer_var.as<VarNode>();
   StorageScope scope = GetScope(buf);
   if (Enabled(buf, scope)) {
     AccessEntry e;
@@ -74,7 +74,7 @@ void StorageAccessVisitor::VisitStmt_(const Store* op) {
   allow_append_ = false;
 }
 
-void StorageAccessVisitor::VisitStmt_(const Evaluate* op) {
+void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
   allow_append_ = true;
   CHECK_EQ(curr_stmt_.access.size(), 0U);
   curr_stmt_.stmt = op;
@@ -87,15 +87,15 @@ void StorageAccessVisitor::VisitStmt_(const Evaluate* op) {
   allow_append_ = false;
 }
 
-void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) {
+void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::storage_scope) {
-    const Variable* buf = op->node.as<Variable>();
+    const VarNode* buf = op->node.as<VarNode>();
     storage_scope_[buf] =
-        StorageScope::make(op->value.as<StringImm>()->value);
+        StorageScope::make(op->value.as<StringImmNode>()->value);
     StmtExprVisitor::VisitStmt_(op);
   } else if (op->attr_key == attr::double_buffer_write) {
     CHECK(double_buffer_write_ == nullptr);
-    double_buffer_write_ = op->node.as<Variable>();
+    double_buffer_write_ = op->node.as<VarNode>();
     scope_.push_back(std::vector<StmtEntry>());
     StmtExprVisitor::VisitStmt_(op);
     StmtEntry s;
@@ -136,7 +136,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) {
   }
 }
 
-void StorageAccessVisitor::VisitStmt_(const For* op) {
+void StorageAccessVisitor::VisitStmt_(const ForNode* op) {
   scope_.push_back(std::vector<StmtEntry>());
   StmtExprVisitor::VisitStmt_(op);
   StmtEntry s;
@@ -145,7 +145,7 @@ void StorageAccessVisitor::VisitStmt_(const For* op) {
   scope_.pop_back();
   if (s.access.size() != 0) {
     // relax the touched set to contain all ranges in the loop.
-    std::unordered_map<const Variable*, arith::IntSet> relax_map;
+    std::unordered_map<const VarNode*, arith::IntSet> relax_map;
     relax_map[op->loop_var.get()] = arith::IntSet::range(
         Range::make_by_min_extent(op->min, op->extent));
     for (AccessEntry& e : s.access) {
@@ -160,7 +160,7 @@ void StorageAccessVisitor::VisitStmt_(const For* op) {
   }
 }
 
-void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) {
+void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
   ++condition_counter_;
   this->VisitExpr(op->condition);
   scope_.push_back(std::vector<StmtEntry>());
@@ -179,17 +179,17 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) {
   --condition_counter_;
 }
 
-void StorageAccessVisitor::VisitExpr_(const Call* op) {
+void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
   if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-    const Load *l = op->args[0].as<Load>();
+    const LoadNode *l = op->args[0].as<LoadNode>();
     StmtExprVisitor::VisitExpr_(l);
   } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
     CHECK_EQ(op->args.size(), 5U);
     DataType dtype = op->args[0].dtype();
-    const Variable* buffer = op->args[1].as<Variable>();
+    const VarNode* buffer = op->args[1].as<VarNode>();
     Expr offset = op->args[2];
     Expr extent = op->args[3];
-    const IntImm* flag = op->args[4].as<IntImm>();
+    const IntImmNode* flag = op->args[4].as<IntImmNode>();
     StorageScope scope = GetScope(buffer);
     // The buffer scope.
     if (Enabled(buffer, scope)) {
@@ -213,7 +213,7 @@ void StorageAccessVisitor::VisitExpr_(const Call* op) {
     StmtExprVisitor::VisitExpr_(op);
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     CHECK(allow_append_);
-    const std::string& s = op->args[0].as<StringImm>()->value;
+    const std::string& s = op->args[0].as<StringImmNode>()->value;
     if (s != "warp") {
       StorageScope scope = StorageScope::make(s);
       AccessEntry e;
@@ -227,7 +227,7 @@ void StorageAccessVisitor::VisitExpr_(const Call* op) {
   }
 }
 
-StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
+StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
   auto it = storage_scope_.find(buf);
   StorageScope s;
   s.rank = StorageRank::kGlobal;
@@ -238,10 +238,10 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
 
 class StorageAccessInfoLower : public StmtExprMutator {
  public:
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     // Lower allocate to device allocate when needed.
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Allocate>();
+    op = stmt.as<AllocateNode>();
     // For special memory, remove allocate, or use head expr
     auto it = storage_info_.find(op->buffer_var.get());
     if (it != storage_info_.end() && it->second.info.defined()) {
@@ -250,7 +250,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
       CHECK_LE(it->second.alloc_count, 1)
           << "Double allocation of " << it->second.scope.to_string();
       if (info->head_address.defined()) {
-        return Allocate::make(
+        return AllocateNode::make(
             op->buffer_var, op->dtype, op->extents, op->condition,
             op->body, info->head_address, "nop");
       }
@@ -259,14 +259,14 @@ class StorageAccessInfoLower : public StmtExprMutator {
       return stmt;
     }
   }
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
-      StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
+      const VarNode* buf = op->node.as<VarNode>();
+      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
       StorageEntry e;
       e.scope = scope;
       if (scope.tag.length() != 0) {
-        e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
+        e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
         CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
       }
       storage_info_[buf] = e;
@@ -277,7 +277,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       return MakeAccessPtr(op);
     } else {
@@ -287,13 +287,13 @@ class StorageAccessInfoLower : public StmtExprMutator {
 
  private:
   // tvm_access_ptr
-  Expr MakeAccessPtr(const Call* op) {
+  Expr MakeAccessPtr(const CallNode* op) {
     // Specially handle the buffer packed intrinsic
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
+    op = expr.as<CallNode>();
     CHECK_EQ(op->args.size(), 5U);
     DataType dtype = op->args[0].dtype();
-    const Variable* buffer = op->args[1].as<Variable>();
+    const VarNode* buffer = op->args[1].as<VarNode>();
     Var buffer_var = Downcast<Var>(op->args[1]);
     Expr offset = op->args[2];
     auto it = storage_info_.find(buffer);
@@ -333,7 +333,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
     int alloc_count{0};
   };
   // The storage scope of each buffer
-  std::unordered_map<const Variable*, StorageEntry> storage_info_;
+  std::unordered_map<const VarNode*, StorageEntry> storage_info_;
 };
 
 Stmt LowerStorageAccessInfo(Stmt stmt) {
index 12bf9f3..80400ad 100644 (file)
@@ -76,13 +76,13 @@ class StorageAccessVisitor : public StmtExprVisitor {
     std::vector<AccessEntry> access;
   };
   // override visitor pattern
-  void VisitExpr_(const Load* op) final;
-  void VisitStmt_(const Store* op) final;
-  void VisitStmt_(const Evaluate* op) final;
-  void VisitStmt_(const AttrStmt* op) final;
-  void VisitStmt_(const For* op) final;
-  void VisitStmt_(const IfThenElse* op) final;
-  void VisitExpr_(const Call* op) final;
+  void VisitExpr_(const LoadNode* op) final;
+  void VisitStmt_(const StoreNode* op) final;
+  void VisitStmt_(const EvaluateNode* op) final;
+  void VisitStmt_(const AttrStmtNode* op) final;
+  void VisitStmt_(const ForNode* op) final;
+  void VisitStmt_(const IfThenElseNode* op) final;
+  void VisitExpr_(const CallNode* op) final;
 
  protected:
   StorageAccessVisitor() {
@@ -106,7 +106,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
    * \param scope The scope of the buffer.
    * \return Whether the analysis of buffer is enabled.
    */
-  virtual bool Enabled(const Variable* buffer,
+  virtual bool Enabled(const VarNode* buffer,
                        const StorageScope& scope) const {
     return true;
   }
@@ -122,12 +122,12 @@ class StorageAccessVisitor : public StmtExprVisitor {
    *  the parent should taken care of to synchronize.
    */
   virtual std::vector<AccessEntry> Summarize(
-      std::vector<StmtEntry> seq, const For* loop) = 0;
+      std::vector<StmtEntry> seq, const ForNode* loop) = 0;
   /*!
    * \brief Get the scope of the buffer array.
    * \return The scope of the final buffer array.
    */
-  StorageScope GetScope(const Variable* buf) const;
+  StorageScope GetScope(const VarNode* buf) const;
   // access scope
   std::vector<std::vector<StmtEntry> > scope_;
 
@@ -139,13 +139,13 @@ class StorageAccessVisitor : public StmtExprVisitor {
   // Whether we are inside condition.
   int condition_counter_{0};
   // The current double buffer write scope.
-  const Variable* double_buffer_write_{nullptr};
+  const VarNode* double_buffer_write_{nullptr};
   // the current free stmt entry.
   StmtEntry curr_stmt_;
   // The involving threads
   Array<IterVar> env_threads_;
   // The storage scope of each buffer
-  std::unordered_map<const Variable*, StorageScope> storage_scope_;
+  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
 };
 
 }  // namespace ir
index 6bb3fc5..ea828ff 100644 (file)
@@ -63,23 +63,23 @@ class StorageFlattener : public StmtExprMutator {
     cache_line_size_ = cache_line_size;
   }
 
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Store>();
+    op = stmt.as<StoreNode>();
     auto it = var_remap_.find(op->buffer_var.get());
     if (it != var_remap_.end() &&
         !it->second.same_as(op->buffer_var)) {
-      CHECK(it->second.as<Variable>());
+      CHECK(it->second.as<VarNode>());
       VarExpr buf_var = Downcast<VarExpr>(it->second);
-      return Store::make(buf_var, op->value, op->index, op->predicate);
+      return StoreNode::make(buf_var, op->value, op->index, op->predicate);
     } else {
       return stmt;
     }
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::realize_scope) {
-      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+      storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       return this->VisitStmt(op->body);
     } else if (op->attr_key == attr::double_buffer_scope &&
                op->node->IsInstance<OperationNode>()) {
@@ -90,7 +90,7 @@ class StorageFlattener : public StmtExprMutator {
         auto it = buf_map_.find(key);
         CHECK(it != buf_map_.end())
             << "Cannot find allocated buffer for " << key.f;
-        body = AttrStmt::make(
+        body = AttrStmtNode::make(
             it->second.buffer->data, op->attr_key, op->value, body);
       }
       return body;
@@ -105,16 +105,16 @@ class StorageFlattener : public StmtExprMutator {
       return HandleBufferBindScope(op);
     } else if (op->attr_key == attr::buffer_dim_align) {
       Tensor tensor = Downcast<Tensor>(op->node);
-      const Call* tuple = op->value.as<Call>();
+      const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
       TensorKey key{tensor->op, tensor->value_index};
       auto& vinfo = dim_align_[key];
-      int dim = tuple->args[0].as<IntImm>()->value;
+      int dim = tuple->args[0].as<IntImmNode>()->value;
       if (static_cast<size_t>(dim) >= vinfo.size()) {
         vinfo.resize(dim + 1);
       }
-      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
-      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
       return this->VisitStmt(op->body);
     } else if (op->attr_key == attr::opengl_stage_scope) {
       is_opengl_ = true;
@@ -122,11 +122,11 @@ class StorageFlattener : public StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const Provide* op) final {
+  Stmt VisitStmt_(const ProvideNode* op) final {
     if (create_bound_attributes_)
       shape_collector_.clear();
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Provide>();
+    op = stmt.as<ProvideNode>();
     TensorKey key{op->func, op->value_index};
     auto it = buf_map_.find(key);
     CHECK(it != buf_map_.end())
@@ -135,11 +135,11 @@ class StorageFlattener : public StmtExprMutator {
     CHECK(!e.released)
         << "Read a buffer that is already out of scope";
     if (is_opengl_) {
-      return Evaluate::make(Call::make(
+      return EvaluateNode::make(CallNode::make(
           DataType(),
-          Call::glsl_texture_store,
+          CallNode::glsl_texture_store,
           {e.buffer->data, op->value},
-          Call::Intrinsic));
+          CallNode::Intrinsic));
     } else {
       Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
       if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
@@ -149,7 +149,7 @@ class StorageFlattener : public StmtExprMutator {
       // To create bound attribute collector should has at least one item.
       if (create_bound_attributes_ && shape_collector_.size()) {
         for (size_t i = 0; i < shape_collector_.size(); ++i) {
-          body = AttrStmt::make(
+          body = AttrStmtNode::make(
               shape_collector_[i].first, ir::attr::buffer_bound,
               MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
         }
@@ -158,7 +158,7 @@ class StorageFlattener : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Realize* op) final {
+  Stmt VisitStmt_(const RealizeNode* op) final {
     TensorKey key{op->func, op->value_index};
     if (buf_map_.count(key)) {
       CHECK(buf_map_.at(key).external);
@@ -188,7 +188,7 @@ class StorageFlattener : public StmtExprMutator {
       }
 
       // use small alignment for small arrays
-      int32_t const_size = Allocate::constant_allocation_size(shape);
+      int32_t const_size = AllocateNode::constant_allocation_size(shape);
       int align = GetTempAllocaAlignment(op->dtype, const_size);
       if (skey.tag.length() != 0) {
         MemoryInfo info = GetMemoryInfo(skey.to_string());
@@ -237,7 +237,7 @@ class StorageFlattener : public StmtExprMutator {
       }
       if (strides.size() != 0) {
         int first_dim = 0;
-        ret = Allocate::make(
+        ret = AllocateNode::make(
             e.buffer->data, storage_type,
             {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
             make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
@@ -246,37 +246,37 @@ class StorageFlattener : public StmtExprMutator {
         if (shape.size() == 0) {
           shape.push_back(make_const(DataType::Int(32), 1));
         }
-        ret = Allocate::make(
+        ret = AllocateNode::make(
             e.buffer->data, storage_type, shape,
             make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
       }
-      ret = AttrStmt::make(
+      ret = AttrStmtNode::make(
           e.buffer->data, attr::storage_scope,
-          StringImm::make(e.buffer->scope), ret);
+          StringImmNode::make(e.buffer->scope), ret);
 
       if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
-        ret = AttrStmt::make(e.buffer->data, ir::attr::buffer_bound,
+        ret = AttrStmtNode::make(e.buffer->data, ir::attr::buffer_bound,
                              MakeBound(e.buffer->dtype, e.buffer->shape), ret);
       }
       return ret;
     }
   }
 
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Load>();
+    op = expr.as<LoadNode>();
     auto it = var_remap_.find(op->buffer_var.get());
     if (it != var_remap_.end() &&
         !it->second.same_as(op->buffer_var)) {
-      CHECK(it->second.as<Variable>());
+      CHECK(it->second.as<VarNode>());
       VarExpr buf_var = Downcast<VarExpr>(it->second);
-      return Load::make(op->dtype, buf_var, op->index, op->predicate);
+      return LoadNode::make(op->dtype, buf_var, op->index, op->predicate);
     } else {
       return expr;
     }
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = var_remap_.find(op);
     if (it != var_remap_.end()) {
       return it->second;
@@ -285,10 +285,10 @@ class StorageFlattener : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Call>();
-    if (op != nullptr && op->call_type == Call::Halide) {
+    op = expr.as<CallNode>();
+    if (op != nullptr && op->call_type == CallNode::Halide) {
       TensorKey key{op->func, op->value_index};
       auto it = buf_map_.find(key);
       CHECK(it != buf_map_.end())
@@ -307,9 +307,9 @@ class StorageFlattener : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Prefetch *op) final {
+  Stmt VisitStmt_(const PrefetchNode *op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Prefetch>();
+    op = stmt.as<PrefetchNode>();
     CHECK(op != nullptr);
     TensorKey key{op->func, op->value_index};
     auto it = buf_map_.find(key);
@@ -351,15 +351,17 @@ class StorageFlattener : public StmtExprMutator {
     }
     for (int i = starts; i >= 0; --i) {
       if (i < starts) {
-        stmt = For::make(
+        stmt = ForNode::make(
             vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
       } else {
         Expr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
-        Expr address = Call::make(DataType::Handle(), tvm_address_of, {load}, Call::PureIntrinsic);
-        Expr prefetch = Call::make(op->dtype, Call::prefetch, {address, 0, 3, 1}, Call::Intrinsic);
-        stmt = Evaluate::make(prefetch);
+        Expr address = CallNode::make(
+            DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
+        Expr prefetch = CallNode::make(
+            op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
+        stmt = EvaluateNode::make(prefetch);
         Expr extent = (op->bounds[i]->extent - 1) / stride + 1;
-        stmt = For::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
+        stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
       }
     }
     return stmt;
@@ -400,12 +402,12 @@ class StorageFlattener : public StmtExprMutator {
   //
   // We do support a few relaxed case, such as bindingx
   // region with shape [1, 1, n, m] to buffer with shape [n, m]
-  Stmt HandleBufferBindScope(const AttrStmt* op) {
+  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
     Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
     CHECK_EQ(arr.size(), 2U);
     const BufferNode* buffer = arr[0].as<BufferNode>();
     const TensorNode* tensor = arr[1].as<TensorNode>();
-    const Call* tuple = op->value.as<Call>();
+    const CallNode* tuple = op->value.as<CallNode>();
     CHECK(buffer && tensor);
     CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
     TensorKey key{tensor->op, tensor->value_index};
@@ -495,17 +497,17 @@ class StorageFlattener : public StmtExprMutator {
 
   Expr MakeBound(const DataType &type, const Array<Expr> &shape) {
     // We have already checked the shape size to be greater then 0.
-    Expr bound = Mul::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
+    Expr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]);
     for (size_t i = 1; i < shape.size(); ++i) {
-      bound = Mul::make(
-          bound, Mul::make(make_const(bound.dtype(), type.lanes()), shape[i]));
+      bound = MulNode::make(
+          bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i]));
     }
     return bound;
   }
 
   // The buffer assignment map
   // Variable remap
-  std::unordered_map<const Variable*, Expr> var_remap_;
+  std::unordered_map<const VarNode*, Expr> var_remap_;
   // Buffer map
   std::unordered_map<TensorKey, BufferEntry> buf_map_;
   // Dimension alignment
index c820c47..928be4b 100644 (file)
@@ -65,7 +65,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     // if offset < 0, means this is the end, the begin entry is current_index + offset
     int64_t scope_pair_offset{0};
     // The buffer variables this statment touched.
-    std::vector<const Variable*> touched;
+    std::vector<const VarNode*> touched;
   };
   // The scope of each allocation
   struct AllocEntry {
@@ -74,12 +74,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     // scope level
     size_t level{0};
     // allocation stmt
-    const Allocate* alloc{nullptr};
+    const AllocateNode* alloc{nullptr};
   };
 
-  void VisitStmt_(const Allocate* op) final {
+  void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
-    const Variable* buf = op->buffer_var.get();
+    const VarNode* buf = op->buffer_var.get();
     auto it = alloc_info_.find(buf);
     CHECK(it != alloc_info_.end());
     CHECK(it->second.alloc == nullptr);
@@ -87,12 +87,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     it->second.level = level;
     StmtExprVisitor::VisitStmt_(op);
   }
-  void VisitStmt_(const Store* op) final {
+  void VisitStmt_(const StoreNode* op) final {
     scope_.push_back(StmtEntry());
     // visit subexpr
     StmtExprVisitor::VisitStmt_(op);
     // Add write access.
-    const Variable* buf = op->buffer_var.get();
+    const VarNode* buf = op->buffer_var.get();
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       CHECK_LT(it->second.level, scope_.size());
@@ -105,7 +105,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       linear_seq_.push_back(e);
     }
   }
-  void VisitStmt_(const Evaluate* op) final {
+  void VisitStmt_(const EvaluateNode* op) final {
     scope_.push_back(StmtEntry());
     // visit subexpr
     StmtExprVisitor::VisitStmt_(op);
@@ -116,10 +116,10 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       linear_seq_.push_back(e);
     }
   }
-  void VisitExpr_(const Load* op) final {
+  void VisitExpr_(const LoadNode* op) final {
     // Add write access.
     StmtExprVisitor::VisitExpr_(op);
-    const Variable* buf = op->buffer_var.get();
+    const VarNode* buf = op->buffer_var.get();
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       CHECK_LT(it->second.level, scope_.size())
@@ -127,15 +127,15 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       scope_[it->second.level].touched.push_back(buf);
     }
   }
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_address_of)) {
-      const Load* l = op->args[0].as<Load>();
+      const LoadNode* l = op->args[0].as<LoadNode>();
       this->VisitExpr(l->index);
     } else {
       StmtExprVisitor::VisitExpr_(op);
     }
   }
-  void VisitExpr_(const Variable* buf) final {
+  void VisitExpr_(const VarNode* buf) final {
     // Directly reference to the variable count as a read.
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
@@ -164,7 +164,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     CHECK_NE(end_index, 0U);
     linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
   }
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     // Only record the outer most thread extent.
     if (op->attr_key == attr::thread_extent && !in_thread_env_) {
       in_thread_env_ = true;
@@ -175,30 +175,30 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     } else if (op->attr_key == attr::virtual_thread) {
       VisitNewScope(op);
     } else if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
+      const VarNode* buf = op->node.as<VarNode>();
       alloc_info_[buf].storage_scope =
-          StorageScope::make(op->value.as<StringImm>()->value);
+          StorageScope::make(op->value.as<StringImmNode>()->value);
       StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
     }
   }
-  void VisitStmt_(const IfThenElse* op) final {
+  void VisitStmt_(const IfThenElseNode* op) final {
     VisitNewScope(op);
   }
 
-  void VisitStmt_(const For* op) final {
+  void VisitStmt_(const ForNode* op) final {
     VisitNewScope(op);
   }
 
-  void VisitStmt_(const AssertStmt* op) final {
+  void VisitStmt_(const AssertStmtNode* op) final {
     VisitNewScope(op);
   }
 
   // linearized access sequence.
   std::vector<StmtEntry> linear_seq_;
   // The storage scope of each buffer
-  std::unordered_map<const Variable*, AllocEntry> alloc_info_;
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
 
  private:
   // Whether already in thread env.
@@ -236,19 +236,19 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
 class InplaceOpVerifier : public StmtExprVisitor {
  public:
   bool Check(const Object* stmt,
-             const Variable* dst,
-             const Variable* src) {
+             const VarNode* dst,
+             const VarNode* src) {
     dst_ = dst;
     src_ = src;
     result_ = true;
-    if (stmt->IsInstance<AttrStmt>()) {
-      VisitStmt_(static_cast<const AttrStmt*>(stmt));
-    } else if (stmt->IsInstance<For>()) {
-      VisitStmt_(static_cast<const For*>(stmt));
-    } else if (stmt->IsInstance<IfThenElse>()) {
-      VisitStmt_(static_cast<const IfThenElse*>(stmt));
-    } else if (stmt->IsInstance<Store>()) {
-      VisitStmt_(static_cast<const Store*>(stmt));
+    if (stmt->IsInstance<AttrStmtNode>()) {
+      VisitStmt_(static_cast<const AttrStmtNode*>(stmt));
+    } else if (stmt->IsInstance<ForNode>()) {
+      VisitStmt_(static_cast<const ForNode*>(stmt));
+    } else if (stmt->IsInstance<IfThenElseNode>()) {
+      VisitStmt_(static_cast<const IfThenElseNode*>(stmt));
+    } else if (stmt->IsInstance<StoreNode>()) {
+      VisitStmt_(static_cast<const StoreNode*>(stmt));
     } else {
       return false;
     }
@@ -266,14 +266,14 @@ class InplaceOpVerifier : public StmtExprVisitor {
     StmtExprVisitor::VisitExpr(n);
   }
 
-  void VisitExpr_(const Variable* op) final {
+  void VisitExpr_(const VarNode* op) final {
     // assume all opaque access is unsafe
     if (op == dst_ || op == src_) {
       result_ = false; return;
     }
   }
 
-  void VisitStmt_(const Store* op) final {
+  void VisitStmt_(const StoreNode* op) final {
     ++mem_nest_;
     this->VisitExpr(op->index);
     --mem_nest_;
@@ -288,7 +288,7 @@ class InplaceOpVerifier : public StmtExprVisitor {
     }
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     // always reject extern code
     if (op->attr_key == attr::extern_scope ||
         op->attr_key == attr::volatile_scope) {
@@ -297,8 +297,8 @@ class InplaceOpVerifier : public StmtExprVisitor {
     StmtExprVisitor::VisitStmt_(op);
   }
 
-  void VisitExpr_(const Load* op) final {
-    const Variable* buf = op->buffer_var.get();
+  void VisitExpr_(const LoadNode* op) final {
+    const VarNode* buf = op->buffer_var.get();
     // cannot read from dst_ (no reduction)
     if (buf == dst_) {
       result_ = false; return;
@@ -324,14 +324,14 @@ class InplaceOpVerifier : public StmtExprVisitor {
   // result of the check
   bool result_{true};
   // destination memory
-  const Variable* dst_;
+  const VarNode* dst_;
   // source variable
-  const Variable* src_;
+  const VarNode* src_;
   // counter of load,
   // it is not safe to inplace when there is nested load like A[B[i]]
   int mem_nest_{0};
   // The current store to be inspected
-  const Store* store_{nullptr};
+  const StoreNode* store_{nullptr};
 };
 
 // Planner to plan and rewrite memory allocation.
@@ -355,10 +355,10 @@ class StoragePlanRewriter : public StmtExprMutator {
       for (StorageEntry* e : attach_map_.at(nullptr)) {
         // CHECK_EQ(e->scope.rank, 0);
         if (e->new_alloc.defined()) {
-          nest.emplace_back(AttrStmt::make(
+          nest.emplace_back(AttrStmtNode::make(
               e->alloc_var, attr::storage_scope,
-              StringImm::make(e->scope.to_string()),
-              Evaluate::make(0)));
+              StringImmNode::make(e->scope.to_string()),
+              EvaluateNode::make(0)));
           nest.push_back(e->new_alloc);
         }
       }
@@ -366,27 +366,27 @@ class StoragePlanRewriter : public StmtExprMutator {
     }
     return stmt;
   }
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Store>();
+    op = stmt.as<StoreNode>();
     auto it = alloc_map_.find(op->buffer_var.get());
     if (it == alloc_map_.end()) return stmt;
-    return Store::make(it->second->alloc_var,
+    return StoreNode::make(it->second->alloc_var,
                        op->value,
                        RemapIndex(op->value.dtype(), op->index, it->second),
                        op->predicate);
   }
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Load>();
+    op = expr.as<LoadNode>();
     auto it = alloc_map_.find(op->buffer_var.get());
     if (it == alloc_map_.end()) return expr;
-    return Load::make(op->dtype,
+    return LoadNode::make(op->dtype,
                       it->second->alloc_var,
                       RemapIndex(op->dtype, op->index, it->second),
                       op->predicate);
   }
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = alloc_map_.find(op);
     if (it != alloc_map_.end()) {
       if (it->second->bits_offset != 0) {
@@ -397,11 +397,11 @@ class StoragePlanRewriter : public StmtExprMutator {
       return GetRef<Expr>(op);
     }
   }
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       CHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
-      const Variable* buffer = op->args[1].as<Variable>();
+      const VarNode* buffer = op->args[1].as<VarNode>();
       auto it = alloc_map_.find(buffer);
       if (it == alloc_map_.end()) {
         return StmtExprMutator::VisitExpr_(op);
@@ -414,7 +414,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       if (se->bits_offset != 0) {
         offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
       }
-      return Call::make(
+      return CallNode::make(
           op->dtype, op->name,
           {op->args[0], se->alloc_var, offset, extent, op->args[4]},
           op->call_type);
@@ -423,7 +423,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
       return this->VisitStmt(op->body);
     } else if (op->attr_key == attr::thread_extent ||
@@ -433,8 +433,8 @@ class StoragePlanRewriter : public StmtExprMutator {
       if (attach_map_.count(op)) {
         auto& svec = attach_map_[op];
         Stmt stmt = StmtExprMutator::VisitStmt_(op);
-        op = stmt.as<AttrStmt>();
-        return AttrStmt::make(
+        op = stmt.as<AttrStmtNode>();
+        return AttrStmtNode::make(
             op->node, op->attr_key, op->value,
             MakeAttach(svec, op->body));
       } else {
@@ -442,24 +442,24 @@ class StoragePlanRewriter : public StmtExprMutator {
       }
     } else if (op->attr_key == attr::volatile_scope) {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      op = stmt.as<AttrStmt>();
-      auto it = alloc_map_.find(op->node.as<Variable>());
+      op = stmt.as<AttrStmtNode>();
+      auto it = alloc_map_.find(op->node.as<VarNode>());
       if (it == alloc_map_.end()) return stmt;
-      return AttrStmt::make(
+      return AttrStmtNode::make(
           it->second->alloc_var, op->attr_key, op->value, op->body);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     CHECK(op->for_type != ForType::Vectorized)
         << "VectorizeLoop before LiftStorageAlloc";
     // remake all the allocation at the attach scope.
     if (attach_map_.count(op)) {
       auto& svec = attach_map_[op];
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      op = stmt.as<For>();
-      return For::make(
+      op = stmt.as<ForNode>();
+      return ForNode::make(
           op->loop_var, op->min, op->extent, op->for_type, op->device_api,
           MakeAttach(svec, op->body));
     } else {
@@ -467,7 +467,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     return this->VisitStmt(op->body);
   }
 
@@ -482,7 +482,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     // The storage scope.
     StorageScope scope;
     // Allocs that shares this entry.
-    std::vector<const Allocate*> allocs;
+    std::vector<const AllocateNode*> allocs;
     // The children of this entry, not including itself.
     std::vector<StorageEntry*> merged_children;
     // The replacement allocation, if any.
@@ -509,9 +509,9 @@ class StoragePlanRewriter : public StmtExprMutator {
   // Event entry in liveness analysis
   struct EventEntry {
     // variables we generate
-    std::vector<const Variable*> gen;
+    std::vector<const VarNode*> gen;
     // variables we kill
-    std::vector<const Variable*> kill;
+    std::vector<const VarNode*> kill;
   };
 
   Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
@@ -519,10 +519,10 @@ class StoragePlanRewriter : public StmtExprMutator {
     std::vector<Stmt> nest;
     for (StorageEntry* e : svec) {
       if (e->new_alloc.defined()) {
-        nest.emplace_back(AttrStmt::make(
+        nest.emplace_back(AttrStmtNode::make(
             e->alloc_var, attr::storage_scope,
-            StringImm::make(e->scope.to_string()),
-            Evaluate::make(0)));
+            StringImmNode::make(e->scope.to_string()),
+            EvaluateNode::make(0)));
         nest.push_back(e->new_alloc);
       }
     }
@@ -570,18 +570,18 @@ class StoragePlanRewriter : public StmtExprMutator {
         // Get the allocation size;
         e->alloc_var = e->allocs[0]->buffer_var;
         DataType alloc_type = e->allocs[0]->dtype;
-        for (const Allocate* op : e->allocs) {
+        for (const AllocateNode* op : e->allocs) {
           if (op->dtype.lanes() > alloc_type.lanes()) {
             alloc_type = op->dtype;
           }
         }
         if (e->allocs.size() == 1) {
           // simply use the original allocation.
-          Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
+          Expr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
                                               make_const(DataType::Int(32), 1));
-          e->new_alloc = Allocate::make(
+          e->new_alloc = AllocateNode::make(
               e->alloc_var, alloc_type, {sz},
-              e->allocs[0]->condition, Evaluate::make(0));
+              e->allocs[0]->condition, EvaluateNode::make(0));
           if (e->scope.tag.length() != 0) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -591,10 +591,10 @@ class StoragePlanRewriter : public StmtExprMutator {
         } else {
           // Build a merged allocation
           Expr combo_size;
-          for (const Allocate* op : e->allocs) {
-            Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(DataType::Int(32), 1));
+          for (const AllocateNode* op : e->allocs) {
+            Expr sz = arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
             auto nbits = op->dtype.bits() * op->dtype.lanes();
-            if (const auto* imm = sz.as<IntImm>()) {
+            if (const auto* imm = sz.as<IntImmNode>()) {
               if (imm->value > std::numeric_limits<int>::max() / nbits) {
                 LOG(WARNING) << "The allocation requires : " << imm->value
                              << " * " << nbits
@@ -621,9 +621,9 @@ class StoragePlanRewriter : public StmtExprMutator {
             combo_size = combo_size + make_const(DataType::Int(32), 1);
           }
           combo_size = ir::Simplify(combo_size);
-          e->new_alloc = Allocate::make(
+          e->new_alloc = AllocateNode::make(
               e->alloc_var, alloc_type, {combo_size}, const_true(),
-              Evaluate::make(0));
+              EvaluateNode::make(0));
           if (e->scope.tag.length() != 0) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -665,9 +665,9 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
     Expr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
                                  (total_bits + type_bits - 1) / type_bits);
-    e->new_alloc = Allocate::make(
+    e->new_alloc = AllocateNode::make(
         e->alloc_var, e->elem_type, {alloc_size}, const_true(),
-        Evaluate::make(0));
+        EvaluateNode::make(0));
     if (info.defined()) {
       CHECK_LE(total_bits, info->max_num_bits)
           << "Allocation exceed bound of memory tag " << e->scope.to_string();
@@ -676,10 +676,10 @@ class StoragePlanRewriter : public StmtExprMutator {
   // Liveness analysis to find gen and kill point of each variable.
   void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
     // find kill point, do a reverse linear scan.
-    std::unordered_set<const Variable*> touched;
+    std::unordered_set<const VarNode*> touched;
     for (size_t i = seq.size(); i != 0; --i) {
       const StmtEntry& s = seq[i - 1];
-      for (const Variable* buffer : s.touched) {
+      for (const VarNode* buffer : s.touched) {
         if (!touched.count(buffer)) {
           touched.insert(buffer);
           event_map_[s.stmt].kill.push_back(buffer);
@@ -692,7 +692,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       int64_t offset = seq[i].scope_pair_offset;
       if (offset < 0) continue;
       const StmtEntry& s = seq[i + offset];
-      for (const Variable* buffer : s.touched) {
+      for (const VarNode* buffer : s.touched) {
         if (!touched.count(buffer)) {
           touched.insert(buffer);
           event_map_[s.stmt].gen.push_back(buffer);
@@ -726,8 +726,8 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
-                  const std::unordered_map<const Variable*, AllocEntry>& alloc_info) {
-    std::unordered_set<const Variable*> inplace_flag;
+                  const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
+    std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
       const StmtEntry& s = seq[i];
@@ -742,7 +742,7 @@ class StoragePlanRewriter : public StmtExprMutator {
         // specially handle this
         bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
 
-        for (const Variable* var : it->second.gen) {
+        for (const VarNode* var : it->second.gen) {
           CHECK(alloc_info.count(var));
           const AllocEntry& ae = alloc_info.at(var);
           StorageEntry* dst_entry = nullptr;
@@ -750,7 +750,7 @@ class StoragePlanRewriter : public StmtExprMutator {
           if (detect_inplace) {
             // only one inplace var for s.stmt
             bool inplace_found = false;
-            for (const Variable* src : it->second.kill) {
+            for (const VarNode* src : it->second.kill) {
               if (!inplace_flag.count(src) && alloc_map_.count(src)) {
                 InplaceOpVerifier visitor;
                 StorageEntry* src_entry = alloc_map_.at(src);
@@ -780,8 +780,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         }
       }
       // enter/exit new scope
-      if (s.stmt->IsInstance<AttrStmt>()) {
-        const auto* op = static_cast<const AttrStmt*>(s.stmt);
+      if (s.stmt->IsInstance<AttrStmtNode>()) {
+        const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
         if (op->attr_key == attr::thread_extent ||
             op->attr_key == attr::virtual_thread ||
             attr::IsPragmaKey(op->attr_key)) {
@@ -789,8 +789,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         } else {
           CHECK(op->attr_key == attr::extern_scope);
         }
-      } else if (s.stmt->IsInstance<For>()) {
-        const auto* op = static_cast<const For*>(s.stmt);
+      } else if (s.stmt->IsInstance<ForNode>()) {
+        const auto* op = static_cast<const ForNode*>(s.stmt);
         if (op->for_type == ForType::Parallel) {
           if (thread_scope_ == nullptr || thread_scope_ == op) {
             PlanNewScope(op);
@@ -802,7 +802,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       // - end of scope(offset < 0)
       // In both cases, we need to handle the kill event correctly
       if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
-        for (const Variable* var : it->second.kill) {
+        for (const VarNode* var : it->second.kill) {
           // skip space which are already replaced by inplace
           if (!inplace_flag.count(var)) {
             this->Free(var);
@@ -812,7 +812,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     }
   }
   // Allocate new storage entry.
-  StorageEntry* NewAlloc(const Allocate* op,
+  StorageEntry* NewAlloc(const AllocateNode* op,
                          const Object* attach_scope,
                          const StorageScope& scope,
                          size_t const_nbits) {
@@ -828,7 +828,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     return e;
   }
 
-  StorageEntry* FindAlloc(const Allocate* op,
+  StorageEntry* FindAlloc(const AllocateNode* op,
                           const Object* attach_scope,
                           const StorageScope& scope) {
     CHECK(op != nullptr);
@@ -890,7 +890,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     return NewAlloc(op, attach_scope, scope, const_nbits);
   }
   // simulated free.
-  void Free(const Variable* var) {
+  void Free(const VarNode* var) {
     auto it = alloc_map_.find(var);
     CHECK(it != alloc_map_.end());
     StorageEntry* e = it->second;
@@ -925,7 +925,7 @@ class StoragePlanRewriter : public StmtExprMutator {
   // The allocation attach map
   std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
   // The allocation assign map
-  std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
+  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
   // The allocations
   std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
   // analyzer
@@ -936,27 +936,27 @@ class StoragePlanRewriter : public StmtExprMutator {
 // if all its access is the same vector type.
 class VectorAllocRewriter : public StmtExprMutator {
  public:
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     UpdateTypeMap(op->buffer_var.get(), op->dtype);
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     UpdateTypeMap(op->buffer_var.get(), op->value.dtype());
     return StmtExprMutator::VisitStmt_(op);
   }
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       DataType dtype = op->args[0].dtype();
-      const Variable* buffer = op->args[1].as<Variable>();
+      const VarNode* buffer = op->args[1].as<VarNode>();
       UpdateTypeMap(buffer, dtype);
     }
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Allocate>();
+    op = stmt.as<AllocateNode>();
     const auto& tvec = acc_map_[op->buffer_var.get()];
 
     if (tvec.size() == 1 &&
@@ -969,7 +969,7 @@ class VectorAllocRewriter : public StmtExprMutator {
       if (me->base % factor == 0 && me->coeff % factor == 0) {
         extents.Set(extents.size() - 1,
                     extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
-        return Allocate::make(
+        return AllocateNode::make(
             op->buffer_var, tvec[0], extents,
             op->condition, op->body);
       }
@@ -977,7 +977,7 @@ class VectorAllocRewriter : public StmtExprMutator {
     return stmt;
   }
 
-  void UpdateTypeMap(const Variable* buffer, DataType t) {
+  void UpdateTypeMap(const VarNode* buffer, DataType t) {
     auto& tvec = acc_map_[buffer];
     if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
       tvec.push_back(t);
@@ -985,7 +985,7 @@ class VectorAllocRewriter : public StmtExprMutator {
   }
 
   // Internal access map
-  std::unordered_map<const Variable*, std::vector<DataType> > acc_map_;
+  std::unordered_map<const VarNode*, std::vector<DataType> > acc_map_;
   // internal analyzer
   arith::Analyzer analyzer_;
 };
index 85cf2b9..7edf98b 100644 (file)
@@ -41,13 +41,13 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
   std::unordered_set<const Object*> syncs_inserted_;
 
  protected:
-  bool Enabled(const Variable* buf,
+  bool Enabled(const VarNode* buf,
                const StorageScope& scope) const final {
     return in_device_env() && scope == sync_scope_;
   }
   // Plan the sync
   std::vector<AccessEntry> Summarize(
-      std::vector<StmtEntry> seq, const For* loop) final {
+      std::vector<StmtEntry> seq, const ForNode* loop) final {
     // Unsynced reads and writes
     std::vector<AccessEntry> reads;
     std::vector<AccessEntry> writes;
@@ -209,10 +209,10 @@ class ThreadSyncInserter : public StmtExprMutator {
       if (sync_scope_.rank == StorageRank::kGlobal) {
         barrier = MakeGlobalBarrier();
       } else {
-        barrier = Evaluate::make(
-                Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
-                           {StringImm::make(sync_scope_.to_string())},
-                           Call::Intrinsic));
+        barrier = EvaluateNode::make(
+                CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+                           {StringImmNode::make(sync_scope_.to_string())},
+                           CallNode::Intrinsic));
       }
       // Mutate after query, to avoid stmt change.
       auto ret = StmtExprMutator::VisitStmt(stmt);
@@ -222,21 +222,21 @@ class ThreadSyncInserter : public StmtExprMutator {
       return StmtExprMutator::VisitStmt(stmt);
     }
   }
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
         GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
     }
     return StmtExprMutator::VisitExpr_(op);
   }
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
         GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].write_count;
     }
     return StmtExprMutator::VisitStmt_(op);
   }
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       bool temp = true;
       std::swap(temp, in_thread_env_);
@@ -246,29 +246,29 @@ class ThreadSyncInserter : public StmtExprMutator {
       std::swap(temp, in_thread_env_);
       // first thread scope.
       if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
-        ret = InitGlobalBarrier(ret.as<AttrStmt>());
+        ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
         num_blocks_ = Expr();
         is_lead_ = Expr();
       }
       return ret;
     } else if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
+      const VarNode* buf = op->node.as<VarNode>();
       storage_scope_[buf] =
-          StorageScope::make(op->value.as<StringImm>()->value);
+          StorageScope::make(op->value.as<StringImmNode>()->value);
       return StmtExprMutator::VisitStmt_(op);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
 
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
       Expr expr = StmtExprMutator::VisitExpr_(op);
-      op = expr.as<Call>();
+      op = expr.as<CallNode>();
       CHECK_EQ(op->args.size(), 5U);
-      const Variable* buffer_var = op->args[1].as<Variable>();
+      const VarNode* buffer_var = op->args[1].as<VarNode>();
       Var var(GetRef<Var>(buffer_var));
-      const IntImm* flag = op->args[4].as<IntImm>();
+      const IntImmNode* flag = op->args[4].as<IntImmNode>();
       if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
           GetScope(buffer_var).rank == StorageRank::kGlobal) {
         ++rw_stats_[var].read_count;
@@ -290,7 +290,7 @@ class ThreadSyncInserter : public StmtExprMutator {
     int write_count{0};
   };
   // Get current storage scope.
-  StorageScope GetScope(const Variable* buf) const {
+  StorageScope GetScope(const VarNode* buf) const {
     auto it = storage_scope_.find(buf);
     StorageScope s;
     s.rank = StorageRank::kGlobal;
@@ -298,23 +298,25 @@ class ThreadSyncInserter : public StmtExprMutator {
     return it->second;
   }
   // private functions.
-  Stmt InitGlobalBarrier(const AttrStmt* op) {
+  Stmt InitGlobalBarrier(const AttrStmtNode* op) {
     CHECK(op != nullptr);
-    Array<Expr> pargs = {StringImm::make(runtime::symbol::tvm_prepare_global_barrier)};
-    Stmt prep = Evaluate::make(
-        Call::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, Call::Intrinsic));
+    Array<Expr> pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)};
+    Stmt prep = EvaluateNode::make(
+        CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
     Stmt body = op->body;
     for (const auto& kv : rw_stats_) {
       const auto& e = kv.second;
       if (e.read_count != 0 && e.write_count != 0) {
-        body = AttrStmt::make(kv.first, attr::volatile_scope, 1, body);
+        body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body);
       }
     }
     rw_stats_.clear();
-    Stmt kinit = Evaluate::make(
-        Call::make(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic));
+    Stmt kinit = EvaluateNode::make(
+        CallNode::make(
+            DataType::Int(32),
+            intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
     body = SeqStmt({kinit, body});
-    body = AttrStmt::make(
+    body = AttrStmtNode::make(
         op->node, op->attr_key, op->value, body);
     return SeqStmt({prep, body});
   }
@@ -323,7 +325,7 @@ class ThreadSyncInserter : public StmtExprMutator {
     if (!num_blocks_.defined()) {
       CHECK(!is_lead_.defined());
       num_work_dim_ = thread_extents_.size();
-      for (const AttrStmt* attr : thread_extents_) {
+      for (const AttrStmtNode* attr : thread_extents_) {
         IterVar iv = Downcast<IterVar>(attr->node);
         runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag);
         if (s.rank == 0) {
@@ -337,23 +339,23 @@ class ThreadSyncInserter : public StmtExprMutator {
     } else {
       CHECK_EQ(num_work_dim_, thread_extents_.size());
     }
-    return Evaluate::make(
-        Call::make(DataType::Int(32), intrinsic::tvm_storage_sync,
-                   {StringImm::make(sync_scope_.to_string()),
+    return EvaluateNode::make(
+        CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync,
+                   {StringImmNode::make(sync_scope_.to_string()),
                     is_lead_, num_blocks_},
-                   Call::Intrinsic));
+                   CallNode::Intrinsic));
   }
   // data structure.
   StorageScope sync_scope_;
   const std::unordered_set<const Object*>& syncs_;
   // The storage scope of each buffer
-  std::unordered_map<const Variable*, StorageScope> storage_scope_;
+  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
   // The read write statistics of storage
   std::unordered_map<VarExpr, Entry, ObjectHash, ObjectEqual> rw_stats_;
   // The statistics for global barrier
   bool in_thread_env_{false};
   // memorized results
-  std::vector<const AttrStmt*> thread_extents_;
+  std::vector<const AttrStmtNode*> thread_extents_;
   size_t num_work_dim_{0};
   Expr num_blocks_;
   Expr is_lead_;
index a3890cd..b2658d9 100644 (file)
@@ -60,7 +60,7 @@ std::string simplify_name(std::string input) {
 }
 
 Expr unpack_type_cast(const Expr &input, const DataType &target_type) {
-  auto cast = input.as<Cast>();
+  auto cast = input.as<CastNode>();
   if (cast == nullptr) {
     return input;
   } else if (cast->dtype == target_type) {
@@ -84,19 +84,19 @@ class MMAMatcher: public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::pragma_tensor_core) {
       tensor_core_on_ = true;
       StmtVisitor::VisitStmt_(op);
     } else if (op->attr_key == attr::realize_scope) {
-      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+      storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
     } else {
       StmtVisitor::VisitStmt_(op);
     }
   }
 
-  void VisitStmt_(const Provide* op) final {
+  void VisitStmt_(const ProvideNode* op) final {
     StmtVisitor::VisitStmt_(op);
     auto it = buf_map_.find(TensorKey{op->func, op->value_index});
     if (it == buf_map_.end()) {
@@ -111,7 +111,7 @@ class MMAMatcher: public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const Realize* op) final {
+  void VisitStmt_(const RealizeNode* op) final {
     TensorKey key{op->func, op->value_index};
     if (buf_map_.count(key)) {
       if (!buf_map_.at(key).external) {
@@ -149,8 +149,8 @@ class MMAMatcher: public StmtVisitor {
   };
 
   // Check whether the storage scope is local
-  bool check_local_buffer_(const Call* op, BufferInfo* bi) {
-    if (op->call_type == Call::Halide) {
+  bool check_local_buffer_(const CallNode* op, BufferInfo* bi) {
+    if (op->call_type == CallNode::Halide) {
       auto it = storage_scope_.find(op->func.get());
       if (it == storage_scope_.end()) {
         return false;
@@ -173,13 +173,13 @@ class MMAMatcher: public StmtVisitor {
   }
 
   // Do the pattern matching
-  bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) {
-    auto* add = op->value.as<Add>();
+  bool mma_sync_match_(const ProvideNode* op, BufferInfo store_buffer) {
+    auto* add = op->value.as<AddNode>();
     if (add == nullptr) {
       return false;
     }
 
-    auto* load_c = add->a.as<Call>();
+    auto* load_c = add->a.as<CallNode>();
     BufferInfo buffer_c;
     if (!check_local_buffer_(load_c, &buffer_c)
         || !buffer_c.same_as(store_buffer)
@@ -188,13 +188,13 @@ class MMAMatcher: public StmtVisitor {
       return false;
     }
 
-    auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<Mul>();
+    auto mul = unpack_type_cast(add->b, buffer_c.dtype).as<MulNode>();
     if (mul == nullptr) {
       return false;
     }
 
     auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype);
-    auto load_a = load_a_expr.as<Call>();
+    auto load_a = load_a_expr.as<CallNode>();
     BufferInfo buffer_a;
     if (!check_local_buffer_(load_a, &buffer_a)
         || !(buffer_a.dtype == DataType::Float(16) ||
@@ -203,7 +203,7 @@ class MMAMatcher: public StmtVisitor {
     }
 
     auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype);
-    auto load_b = load_b_expr.as<Call>();
+    auto load_b = load_b_expr.as<CallNode>();
     BufferInfo buffer_b;
     if (!check_local_buffer_(load_b, &buffer_b)
         || !(buffer_b.dtype == DataType::Float(16) ||
@@ -224,7 +224,7 @@ class MMAMatcher: public StmtVisitor {
 
   std::unordered_map<TensorKey, BufferInfo> buf_map_;
   std::unordered_map<const Object*, std::string> storage_scope_;
-  std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+  std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
   std::unordered_map<const Object*, std::string> buf_name_;
   std::unordered_set<std::string> frag_reg_;
   bool matched_{false};
@@ -238,14 +238,14 @@ class BodyVisitor : public StmtExprVisitor {
  public:
   BodyVisitor() {}
 
-  void VisitExpr_(const Reduce* op) final {
-    auto* comm_add = op->combiner->result[0].as<Add>();
+  void VisitExpr_(const ReduceNode* op) final {
+    auto* comm_add = op->combiner->result[0].as<AddNode>();
     if (comm_add == nullptr || op->combiner->result.size() > 1) {
       return;
     }
     for (Expr source : op->source) {
-      auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<Mul>();
-      auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<Mul>();
+      auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as<MulNode>();
+      auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as<MulNode>();
       if (mul_0 == nullptr && mul_1 == nullptr) {
         continue;
       }
@@ -255,7 +255,7 @@ class BodyVisitor : public StmtExprVisitor {
     }
   }
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
     args_.insert(std::make_pair(op->name, op->args));
   }
@@ -287,11 +287,11 @@ class ScheduleAnalyser {
       if (axis.size() < 2 || reduce_axis.size() != 1) {
         continue;
       }
-      const Variable* axis_var[2];
-      const Variable* reduce_axis_var;
-      axis_var[0] = axis[axis.size()-2]->var.as<Variable>();
-      axis_var[1] = axis[axis.size()-1]->var.as<Variable>();
-      reduce_axis_var = reduce_axis[0]->var.as<Variable>();
+      const VarNode* axis_var[2];
+      const VarNode* reduce_axis_var;
+      axis_var[0] = axis[axis.size()-2]->var.as<VarNode>();
+      axis_var[1] = axis[axis.size()-1]->var.as<VarNode>();
+      reduce_axis_var = reduce_axis[0]->var.as<VarNode>();
 
       BodyVisitor body_visitor;
       for (Expr expr : compute->body) {
@@ -306,8 +306,8 @@ class ScheduleAnalyser {
         if (args.size() < 2) {
           continue;
         }
-        const Variable* var0 = args[args.size() - 2].as<Variable>();
-        const Variable* var1 = args[args.size() - 1].as<Variable>();
+        const VarNode* var0 = args[args.size() - 2].as<VarNode>();
+        const VarNode* var1 = args[args.size() - 1].as<VarNode>();
         if (var0 == nullptr || var1 == nullptr) {
           continue;
         }
@@ -334,8 +334,8 @@ class ScheduleAnalyser {
 
     for (auto &mma_sync : mma_sync_) {
       auto &operands = mma_sync.second;
-      auto* load_a = operands[0].as<Call>();
-      auto* load_b = operands[1].as<Call>();
+      auto* load_a = operands[0].as<CallNode>();
+      auto* load_b = operands[1].as<CallNode>();
       auto input0 = simplify_name(buf_name_.find(load_a)->second);
       auto input1 = simplify_name(buf_name_.find(load_b)->second);
       auto it0 = matrix_abc_.find(input0);
@@ -361,7 +361,7 @@ class ScheduleAnalyser {
  private:
   std::unordered_map<std::string, std::string> matrix_abc_;
   std::unordered_map<std::string, std::string> matrix_major_;
-  std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+  std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
   std::unordered_map<const Object*, std::string> buf_name_;
 };
 
@@ -371,7 +371,7 @@ class IndexVisitor : public StmtExprVisitor {
  public:
   IndexVisitor() {}
 
-  void VisitExpr_(const Variable* op) final {
+  void VisitExpr_(const VarNode* op) final {
     loop_scaling_.insert(std::make_pair(op, scaling_factor_));
   }
 
@@ -379,7 +379,7 @@ class IndexVisitor : public StmtExprVisitor {
   friend class TensorCoreIRMutator;
 
  private:
-  std::unordered_map<const Variable*, unsigned> loop_scaling_;
+  std::unordered_map<const VarNode*, unsigned> loop_scaling_;
   unsigned scaling_factor_{0};
 };
 
@@ -404,9 +404,9 @@ class BufferAnalyser : public StmtExprVisitor {
     }
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
-      if (const IntImm* value = op->value.as<IntImm>()) {
+      if (const IntImmNode* value = op->value.as<IntImmNode>()) {
         thread_extent_.insert(
             std::make_pair(
                 op->node.as<IterVarNode>()->var->name_hint,
@@ -414,26 +414,26 @@ class BufferAnalyser : public StmtExprVisitor {
       }
       StmtExprVisitor::VisitStmt_(op);
     } else if (op->attr_key == attr::realize_scope) {
-      storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
+      storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       this->VisitStmt(op->body);
     } else if (op->attr_key == attr::buffer_dim_align) {
       Tensor tensor = Downcast<Tensor>(op->node);
-      const Call* tuple = op->value.as<Call>();
+      const CallNode* tuple = op->value.as<CallNode>();
       CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
       auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}];
-      size_t dim = tuple->args[0].as<IntImm>()->value;
+      size_t dim = tuple->args[0].as<IntImmNode>()->value;
       if (dim >= vinfo.size()) {
         vinfo.resize(dim + 1);
       }
-      vinfo[dim].align_factor = tuple->args[1].as<IntImm>()->value;
-      vinfo[dim].align_offset = tuple->args[2].as<IntImm>()->value;
+      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
+      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
       this->VisitStmt(op->body);
     } else {
       StmtExprVisitor::VisitStmt_(op);
     }
   }
 
-  void VisitStmt_(const Provide* op) final {
+  void VisitStmt_(const ProvideNode* op) final {
     StmtExprVisitor::VisitStmt_(op);
     TensorKey key{op->func, op->value_index};
     auto it = buf_map_.find(key);
@@ -449,7 +449,7 @@ class BufferAnalyser : public StmtExprVisitor {
         return;
       }
       for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
-        const IntImm* shape = bi.shape[i].as<IntImm>();
+        const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
         if (shape == nullptr || shape->value % 16 != 0) {
           invalid_ = true;
           return;
@@ -462,9 +462,9 @@ class BufferAnalyser : public StmtExprVisitor {
       strides = bi.strides;
     } else {
       for (size_t i = 1; i < bi.shape.size(); ++i) {
-        Expr stride = IntImm::make(DataType::Int(32), 1);
+        Expr stride = IntImmNode::make(DataType::Int(32), 1);
         for (size_t j = bi.shape.size() - 1; j >= i; --j) {
-          stride = Mul::make(stride, bi.shape[j]);
+          stride = MulNode::make(stride, bi.shape[j]);
         }
         strides.push_back(stride);
       }
@@ -473,10 +473,10 @@ class BufferAnalyser : public StmtExprVisitor {
     strides_.insert(std::make_pair(key.GetName(), strides));
 
     if (frag_reg_.count(bi.name)) {
-      Expr dst = Call::make(bi.dtype,
+      Expr dst = CallNode::make(bi.dtype,
                             bi.name,
                             op->args,
-                            Call::Halide,
+                            CallNode::Halide,
                             op->func,
                             0);
       frag_load_.insert(std::make_pair(op, dst));
@@ -489,7 +489,7 @@ class BufferAnalyser : public StmtExprVisitor {
       std::vector<int> tile_size;
       for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
         index_visitor.scaling_factor_ = 16;
-        if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
+        if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
           tile_size.push_back(shape->value);
           index_visitor.scaling_factor_ = shape->value;
         } else {
@@ -533,21 +533,21 @@ class BufferAnalyser : public StmtExprVisitor {
       }
     }
 
-    const Call* value = op->value.as<Call>();
+    const CallNode* value = op->value.as<CallNode>();
     if (value != nullptr && frag_reg_.count(value->name)) {
-      Expr dst = Call::make(bi.dtype,
+      Expr dst = CallNode::make(bi.dtype,
                             bi.name,
                             op->args,
-                            Call::Halide,
+                            CallNode::Halide,
                             op->func,
                             0);
       frag_store_.insert(std::make_pair(op, dst));
     }
   }
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
-    if (op->call_type == Call::Halide) {
+    if (op->call_type == CallNode::Halide) {
       TensorKey key{op->func, op->value_index};
       auto it = buf_map_.find(key);
       CHECK(it != buf_map_.end())
@@ -562,7 +562,7 @@ class BufferAnalyser : public StmtExprVisitor {
           return;
         }
         for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) {
-          const IntImm* shape = bi.shape[i].as<IntImm>();
+          const IntImmNode* shape = bi.shape[i].as<IntImmNode>();
           if (shape == nullptr || shape->value % 16 != 0) {
             invalid_ = true;
             return;
@@ -575,9 +575,9 @@ class BufferAnalyser : public StmtExprVisitor {
         strides = bi.strides;
       } else {
         for (size_t i = 1; i < bi.shape.size(); ++i) {
-          Expr stride = IntImm::make(DataType::Int(32), 1);
+          Expr stride = IntImmNode::make(DataType::Int(32), 1);
           for (size_t j = bi.shape.size() - 1; j >= i; --j) {
-            stride = Mul::make(stride, bi.shape[j]);
+            stride = MulNode::make(stride, bi.shape[j]);
           }
           strides.push_back(stride);
         }
@@ -596,7 +596,7 @@ class BufferAnalyser : public StmtExprVisitor {
       }
       for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) {
         index_visitor.scaling_factor_ = 16;
-        if (const IntImm* shape = bi.shape[i].as<IntImm>()) {
+        if (const IntImmNode* shape = bi.shape[i].as<IntImmNode>()) {
           index_visitor.scaling_factor_ = shape->value;
         }
         auto index = rel_index[i];
@@ -606,7 +606,7 @@ class BufferAnalyser : public StmtExprVisitor {
     }
   }
 
-  void VisitStmt_(const Realize* op) final {
+  void VisitStmt_(const RealizeNode* op) final {
     TensorKey key{op->func, op->value_index};
     if (buf_map_.count(key)) {
       CHECK(buf_map_.at(key).external);
@@ -745,8 +745,8 @@ class BufferAnalyser : public StmtExprVisitor {
   std::unordered_map<std::string, std::string> matrix_major_;
   std::unordered_set<std::string> frag_reg_;
   std::unordered_map<std::string, Array<Expr>> strides_;
-  std::unordered_map<const Provide*, Expr> frag_load_;
-  std::unordered_map<const Provide*, Expr> frag_store_;
+  std::unordered_map<const ProvideNode*, Expr> frag_load_;
+  std::unordered_map<const ProvideNode*, Expr> frag_store_;
   std::unordered_map<std::string, int> thread_extent_;
   IndexVisitor index_visitor;
   Tile warp_tile_;
@@ -760,17 +760,17 @@ class ThreadIdxMutator : public StmtExprMutator {
  public:
   explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {}
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Variable>();
+    op = expr.as<VarNode>();
     if (op != nullptr) {
       if (op->name_hint == "threadIdx.x") {
-        Expr zero = IntImm::make(DataType::Int(32), 0);
+        Expr zero = IntImmNode::make(DataType::Int(32), 0);
         return zero;
       }
       if (op->name_hint == "threadIdx.y") {
-        Expr div = Div::make(expr, warp_y_);
-        Expr mul = Mul::make(div, warp_y_);
+        Expr div = DivNode::make(expr, warp_y_);
+        Expr mul = MulNode::make(div, warp_y_);
         return mul;
       }
     }
@@ -798,11 +798,11 @@ class TensorCoreIRMutator : public StmtExprMutator {
       warp_tile_(buffer_analyser.warp_tile_),
       warp_threads_y_(buffer_analyser.warp_threads_y_) {}
 
-  Stmt VisitStmt_(const Realize* op) final {
+  Stmt VisitStmt_(const RealizeNode* op) final {
     TensorKey key{op->func, op->value_index};
     bounds_[key] = op->bounds;
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Realize>();
+    op = stmt.as<RealizeNode>();
     if (op != nullptr) {
       if (!frag_reg_.count(key.GetName())) {
         return stmt;
@@ -821,14 +821,14 @@ class TensorCoreIRMutator : public StmtExprMutator {
       new_bounds.push_back(Range::make_by_min_extent(
           op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
 
-      return Realize::make(op->func, op->value_index,
+      return RealizeNode::make(op->func, op->value_index,
                            op->dtype, new_bounds,
                            op->condition, op->body);
     }
     return stmt;
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     if (op->attr_key == attr::realize_scope) {
       auto node = op->node.as<OperationNode>();
@@ -842,7 +842,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
               << "Cannot find matrix info for " << node->name;
         auto matrix_abc = "wmma." + it->second;
         Stmt body = this->VisitStmt(op->body);
-        return AttrStmt::make(op->node,
+        return AttrStmtNode::make(op->node,
                               op->attr_key,
                               matrix_abc,
                               body);
@@ -851,17 +851,17 @@ class TensorCoreIRMutator : public StmtExprMutator {
     return stmt;
   }
 
-  Stmt VisitStmt_(const Provide* op) final {
+  Stmt VisitStmt_(const ProvideNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     auto it = mma_sync_.find(op);
     if (it != mma_sync_.end()) {
       const auto &operands = it->second;
       Expr a = operands[0];
-      auto ca = a.as<Call>();
+      auto ca = a.as<CallNode>();
       Expr b = operands[1];
-      auto cb = b.as<Call>();
+      auto cb = b.as<CallNode>();
       Expr c = operands[2];
-      auto cc = c.as<Call>();
+      auto cc = c.as<CallNode>();
 
       ObjectPtr<BufferNode> buffer_node_a = make_object<BufferNode>();
       ObjectPtr<BufferNode> buffer_node_b = make_object<BufferNode>();
@@ -872,14 +872,14 @@ class TensorCoreIRMutator : public StmtExprMutator {
         (const Buffer &buffer) {
           Buffer buffer_a(buffer_node_a);
           Buffer buffer_b(buffer_node_b);
-          return Evaluate::make(
-                  Call::make(DataType::Handle(),
+          return EvaluateNode::make(
+                  CallNode::make(DataType::Handle(),
                         intrinsic::tvm_mma_sync,
                         {buffer->data, buffer->elem_offset,
                         buffer_a->data, buffer_a->elem_offset,
                         buffer_b->data, buffer_b->elem_offset,
                         buffer->data, buffer->elem_offset},
-                        Call::Intrinsic));
+                        CallNode::Intrinsic));
         };
 
       auto call_add_c =
@@ -901,19 +901,19 @@ class TensorCoreIRMutator : public StmtExprMutator {
     auto it2 = frag_load_.find(op);
     if (it2 != frag_load_.end()) {
       Expr dst = it2->second;
-      if (op->value.as<FloatImm>() != nullptr ||
-          op->value.as<IntImm>() != nullptr) {
-        auto call = dst.as<Call>();
+      if (op->value.as<FloatImmNode>() != nullptr ||
+          op->value.as<IntImmNode>() != nullptr) {
+        auto call = dst.as<CallNode>();
 
         auto fill_fragment_call =
           [this, &op](const Buffer &buffer) {
-            return Evaluate::make(
-                    Call::make(DataType::Handle(),
+            return EvaluateNode::make(
+                    CallNode::make(DataType::Handle(),
                               intrinsic::tvm_fill_fragment,
                               {buffer->data,
                               warp_tile_.m, warp_tile_.n, warp_tile_.k,
                               buffer->elem_offset, op->value},
-                              Call::Intrinsic));
+                              CallNode::Intrinsic));
           };
 
         ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -922,7 +922,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
                                       fill_fragment_call, call->dtype);
       }
 
-      const Call* value = op->value.as<Call>();
+      const CallNode* value = op->value.as<CallNode>();
       CHECK(value != nullptr)
           << "Can only load fragment from a buffer";
 
@@ -934,36 +934,36 @@ class TensorCoreIRMutator : public StmtExprMutator {
       Expr stride = strides[strides.size()-2];
 
       // thread index unification inside a warp
-      Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_);
+      Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       Expr mutated_value = thread_idx_mutator(op->value);
-      Expr src = Call::make(value->dtype,
+      Expr src = CallNode::make(value->dtype,
                             "&",
                             {mutated_value},
-                            Call::Extern);
+                            CallNode::Extern);
 
-      auto call = dst.as<Call>();
+      auto call = dst.as<CallNode>();
       Expr matrix_major;
       auto iter2 = matrix_major_.find(simplify_name(call->name));
       CHECK(iter2 != matrix_major_.end())
           << "Can not determine matrix major for " << call->name;
       if (iter2->second == "col_major") {
-        matrix_major = StringImm::make("col_major");
+        matrix_major = StringImmNode::make("col_major");
       } else if (iter2->second == "row_major") {
-        matrix_major = StringImm::make("row_major");
+        matrix_major = StringImmNode::make("row_major");
       } else {
         LOG(FATAL) << "invalid matrix major for " << call->name;
       }
 
       auto load_matrix_call =
         [this, &src, &stride, &matrix_major](const Buffer &buffer) {
-        return Evaluate::make(
-                Call::make(DataType::Handle(),
+        return EvaluateNode::make(
+                CallNode::make(DataType::Handle(),
                           intrinsic::tvm_load_matrix_sync,
                           {buffer->data,
                           warp_tile_.m, warp_tile_.n, warp_tile_.k,
                           buffer->elem_offset, src, stride, matrix_major},
-                          Call::Intrinsic));
+                          CallNode::Intrinsic));
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -984,26 +984,26 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
       Expr dst = it3->second;
       // thread index unification inside a warp
-      Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_);
+      Expr warp_y = IntImmNode::make(DataType::Int(32), warp_threads_y_);
       ThreadIdxMutator thread_idx_mutator(warp_y);
       dst = thread_idx_mutator(dst);
-      dst = Call::make(DataType::Handle(),
+      dst = CallNode::make(DataType::Handle(),
                        "&",
                        {dst},
-                       Call::Extern);
+                       CallNode::Extern);
 
-      auto call = op->value.as<Call>();
+      auto call = op->value.as<CallNode>();
 
       auto store_matrix_call =
         [this, &dst, &stride](const Buffer &buffer) {
-          return Evaluate::make(
-                  Call::make(DataType::Handle(),
+          return EvaluateNode::make(
+                  CallNode::make(DataType::Handle(),
                             intrinsic::tvm_store_matrix_sync,
                             {buffer->data,
                             warp_tile_.m, warp_tile_.n, warp_tile_.k,
                             buffer->elem_offset, dst, stride,
-                            StringImm::make("col_major")},
-                            Call::Intrinsic));
+                            StringImmNode::make("col_major")},
+                            CallNode::Intrinsic));
         };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -1015,20 +1015,20 @@ class TensorCoreIRMutator : public StmtExprMutator {
     return stmt;
   }
 
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<For>();
+    op = stmt.as<ForNode>();
     if (op != nullptr) {
       auto it = loop_scaling_.find(op->loop_var.get());
       if (it != loop_scaling_.end()) {
         int scale_factor = it->second;
         int scaled_extent_value = 1;
-        if (const IntImm *ori_extent = op->extent.as<IntImm>()) {
+        if (const IntImmNode *ori_extent = op->extent.as<IntImmNode>()) {
           int ori_extent_value = ori_extent->value;
           scaled_extent_value = ori_extent_value / scale_factor;
         }
         Expr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
-        stmt = For::make(op->loop_var, op->min, scaled_extent, op->for_type,
+        stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type,
           op->device_api, op->body);
       }
     }
@@ -1067,7 +1067,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       return tile_size;
   }
 
-  Stmt add_buffer_bind_scope_(const Call* call,
+  Stmt add_buffer_bind_scope_(const CallNode* call,
       const ObjectPtr<BufferNode> &buffer_node, const TensorKey &key,
       const std::function<Stmt(const Buffer &buffer)> &call_back,
       DataType datatype) {
@@ -1089,26 +1089,26 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
     Array<Expr> strides;
     for (size_t i = 1; i < shape.size(); ++i) {
-      Expr stride = IntImm::make(DataType::Int(32), 1);
+      Expr stride = IntImmNode::make(DataType::Int(32), 1);
       for (size_t j = shape.size() - 1; j >= i; --j) {
-        stride = Mul::make(stride, shape[j]);
+        stride = MulNode::make(stride, shape[j]);
       }
       strides.push_back(stride);
     }
     strides.push_back(make_const(DataType::Int(32), 1));
 
-    Expr elem_offset = IntImm::make(DataType::Int(32), 0);
+    Expr elem_offset = IntImmNode::make(DataType::Int(32), 0);
     CHECK_EQ(call->args.size(), min_bound.size());
     for (size_t i = 0; i < min_bound.size(); i++) {
-      elem_offset = Add::make(
-        elem_offset, Mul::make(
-          strides[i], Sub::make(call->args[i], min_bound[i])));
+      elem_offset = AddNode::make(
+        elem_offset, MulNode::make(
+          strides[i], SubNode::make(call->args[i], min_bound[i])));
     }
 
     auto it2 = matrix_abc_.find(simplify_name(call->name));
     CHECK(it2 != matrix_abc_.end())
           << "Cannot find matrix info for " << call->name;
-    buffer_node->data = Variable::make(DataType::Handle(), call->name);
+    buffer_node->data = VarNode::make(DataType::Handle(), call->name);
     buffer_node->name = call->name;
     buffer_node->scope = "wmma." + it2->second;
     buffer_node->dtype = datatype;
@@ -1131,12 +1131,12 @@ class TensorCoreIRMutator : public StmtExprMutator {
       args.push_back(call->args[i]);
       args.push_back(shape[i]);
     }
-    auto tuple = Call::make(DataType::Handle(),
+    auto tuple = CallNode::make(DataType::Handle(),
                             intrinsic::tvm_tuple,
                             args,
-                            Call::Intrinsic);
+                            CallNode::Intrinsic);
     Array<ObjectRef> node = {buffer, tensor};
-    return AttrStmt::make(node,
+    return AttrStmtNode::make(node,
                           "buffer_bind_scope",
                           tuple,
                           call_back(buffer));
@@ -1144,12 +1144,12 @@ class TensorCoreIRMutator : public StmtExprMutator {
 
   std::unordered_map<std::string, std::string> matrix_abc_;
   std::unordered_map<std::string, std::string> matrix_major_;
-  std::unordered_map<const Provide*, Array<Expr>> mma_sync_;
+  std::unordered_map<const ProvideNode*, Array<Expr>> mma_sync_;
   std::unordered_map<std::string, Array<Expr>> strides_;
   std::unordered_set<std::string> frag_reg_;
-  std::unordered_map<const Variable*, unsigned> loop_scaling_;
-  std::unordered_map<const Provide*, Expr> frag_load_;
-  std::unordered_map<const Provide*, Expr> frag_store_;
+  std::unordered_map<const VarNode*, unsigned> loop_scaling_;
+  std::unordered_map<const ProvideNode*, Expr> frag_load_;
+  std::unordered_map<const ProvideNode*, Expr> frag_store_;
   std::unordered_map<TensorKey, Region> bounds_;
   Tile warp_tile_;
   int warp_threads_y_{-1};
index 7826a9b..e2e7ad0 100644 (file)
@@ -45,7 +45,7 @@ class LoopUnroller : public StmtExprMutator {
         explicit_unroll_(explicit_unroll) {
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_auto_unroll_max_step") {
       int value = 0;
       CHECK(arith::GetConstInt(op->value, &value));
@@ -66,9 +66,9 @@ class LoopUnroller : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const For* op) {
+  Stmt VisitStmt_(const ForNode* op) {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<For>();
+    op = stmt.as<ForNode>();
     int value = GetExtent(op);
     // condition for auto unroll
     bool auto_unroll = (
@@ -101,7 +101,7 @@ class LoopUnroller : public StmtExprMutator {
     } else {
       if (auto_unroll) {
         if (op->for_type != ForType::Unrolled) {
-          return For::make(
+          return ForNode::make(
               op->loop_var, op->min, op->extent,
               ForType::Unrolled, op->device_api, op->body);
         }
@@ -110,12 +110,12 @@ class LoopUnroller : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     ++step_count_;
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const Evaluate* op) final {
+  Stmt VisitStmt_(const EvaluateNode* op) final {
     ++step_count_;
     return StmtExprMutator::VisitStmt_(op);
   }
@@ -137,11 +137,11 @@ class LoopUnroller : public StmtExprMutator {
     return StmtMutator::VisitSeqStmt_(op, false, fmutate);
   }
 
-  Stmt Unroll(const For* op) {
+  Stmt Unroll(const ForNode* op) {
     int value = GetExtent(op);
     // For loop must have a constant integer extent
     CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
-    if (value == 0) return Evaluate::make(0);
+    if (value == 0) return EvaluateNode::make(0);
     Stmt body = op->body;
     Map<Var, Expr> vmap;
     Array<Stmt> unrolled;
@@ -155,11 +155,11 @@ class LoopUnroller : public StmtExprMutator {
 
  private:
   // returns the extent of the loop if it's a constant integer, otherwise return -1
-  int GetExtent(const For* op) {
+  int GetExtent(const ForNode* op) {
     // constant folding.
     Expr extent = ir::Simplify(op->extent);
-    const IntImm  *v1 = extent.as<IntImm>();
-    const UIntImm *v2 = extent.as<UIntImm>();
+    const IntImmNode  *v1 = extent.as<IntImmNode>();
+    const UIntImmNode *v2 = extent.as<UIntImmNode>();
     int value = -1;
     if (v1 != nullptr) {
       value = static_cast<int>(v1->value);
@@ -204,7 +204,7 @@ Stmt UnrollLoop(Stmt stmt,
 }
 
 Stmt UnrollLoopExplicitly(Stmt stmt) {
-  const For* op = stmt.as<For>();
+  const ForNode* op = stmt.as<ForNode>();
   if (!op) {
     LOG(FATAL) << "attempted to unroll a non-loop statement";
   }
index c22243c..450c6ba 100644 (file)
@@ -35,15 +35,15 @@ namespace ir {
 
 inline Expr BroadcastTo(Expr e, int lanes) {
   if (e.dtype().lanes() == lanes) return e;
-  if (const Broadcast* op = e.as<Broadcast>()) {
+  if (const BroadcastNode* op = e.as<BroadcastNode>()) {
     if (lanes % op->lanes == 0) {
-      return Broadcast::make(op->value, lanes);
+      return BroadcastNode::make(op->value, lanes);
     }
   }
   CHECK_EQ(e.dtype().lanes(), 1)
       << "Cannot broadcast lane=" << e.dtype().lanes()
       << " to " << lanes;
-  return Broadcast::make(e, lanes);
+  return BroadcastNode::make(e, lanes);
 }
 
 // Rewrite vectorized allocation access
@@ -56,14 +56,14 @@ inline Expr BroadcastTo(Expr e, int lanes) {
 //
 class VecAllocAccess : public StmtExprMutator {
  public:
-  VecAllocAccess(const Variable* buf, Var var, int var_lanes)
+  VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
       : buf_(buf), var_(var), var_lanes_(var_lanes) {}
   // Load
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<Load>();
+    op = expr.as<LoadNode>();
     if (op->buffer_var.get() == buf_) {
-      return Load::make(op->dtype, op->buffer_var,
+      return LoadNode::make(op->dtype, op->buffer_var,
                         op->index * var_lanes_ + var_,
                         op->predicate);
     } else {
@@ -71,11 +71,11 @@ class VecAllocAccess : public StmtExprMutator {
     }
   }
   // Store
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<Store>();
+    op = stmt.as<StoreNode>();
     if (op->buffer_var.get() == buf_) {
-      return Store::make(op->buffer_var,
+      return StoreNode::make(op->buffer_var,
                          op->value,
                          op->index * var_lanes_ + var_,
                          op->predicate);
@@ -86,7 +86,7 @@ class VecAllocAccess : public StmtExprMutator {
 
  private:
   // buffer var
-  const Variable* buf_;
+  const VarNode* buf_;
   // variable to be replaced
   Var var_;
   // the lanes.
@@ -97,7 +97,7 @@ class Vectorizer : public StmtExprMutator {
  public:
   Vectorizer(Var var, int var_lanes)
       : var_(var), var_lanes_(var_lanes) {
-    ramp_ = Ramp::make(0, 1, var_lanes);
+    ramp_ = RampNode::make(0, 1, var_lanes);
   }
 
   Stmt VisitStmt(const Stmt& stmt) final {
@@ -111,13 +111,13 @@ class Vectorizer : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Add* op) final {
+  Expr VisitExpr_(const AddNode* op) final {
     return AddSubVec(op);
   }
-  Expr VisitExpr_(const Sub* op) final {
+  Expr VisitExpr_(const SubNode* op) final {
     return AddSubVec(op);
   }
-  Expr VisitExpr_(const Mul* op) final {
+  Expr VisitExpr_(const MulNode* op) final {
     Expr a = this->VisitExpr(op->a);
     Expr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) &&
@@ -126,70 +126,70 @@ class Vectorizer : public StmtExprMutator {
     } else {
       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
       if (lanes != 1) {
-        const Ramp* b_ramp = b.as<Ramp>();
-        const Ramp* a_ramp = a.as<Ramp>();
+        const RampNode* b_ramp = b.as<RampNode>();
+        const RampNode* a_ramp = a.as<RampNode>();
         if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
-          return Ramp::make(
+          return RampNode::make(
               a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
         }
         if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
-          return Ramp::make(
+          return RampNode::make(
               b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
         }
       }
-      return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+      return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
     }
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const Div* op) final {
+  Expr VisitExpr_(const DivNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const Mod* op) final {
+  Expr VisitExpr_(const ModNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const FloorDiv* op) final {
+  Expr VisitExpr_(const FloorDivNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const FloorMod* op) final {
+  Expr VisitExpr_(const FloorModNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const Min* op) final {
+  Expr VisitExpr_(const MinNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const Max* op) final {
+  Expr VisitExpr_(const MaxNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const EQ* op) final {
+  Expr VisitExpr_(const EQNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const NE* op) final {
+  Expr VisitExpr_(const NENode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const LT* op) final {
+  Expr VisitExpr_(const LTNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const LE* op) final {
+  Expr VisitExpr_(const LENode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const GT* op) final {
+  Expr VisitExpr_(const GTNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const GE* op) final {
+  Expr VisitExpr_(const GENode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const And* op) final {
+  Expr VisitExpr_(const AndNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const Or* op) final {
+  Expr VisitExpr_(const OrNode* op) final {
     return BinaryVec(op);
   }
-  Expr VisitExpr_(const Ramp* op) final {
+  Expr VisitExpr_(const RampNode* op) final {
     Expr base = this->VisitExpr(op->base);
     Expr stride = this->VisitExpr(op->stride);
     if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
-      const Ramp* base_ramp = base.as<Ramp>();
+      const RampNode* base_ramp = base.as<RampNode>();
       if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
-        return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
+        return RampNode::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
       }
     }
     int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
@@ -198,13 +198,13 @@ class Vectorizer : public StmtExprMutator {
     Array<Expr> elems;
     for (int i = 0; i < lanes; ++i) {
       elems.push_back(
-          Ramp::make(Shuffle::make_extract_element(base, i),
-                     Shuffle::make_extract_element(stride, i),
+          RampNode::make(ShuffleNode::make_extract_element(base, i),
+                     ShuffleNode::make_extract_element(stride, i),
                      op->lanes));
     }
-    return Shuffle::make_concat(elems);
+    return ShuffleNode::make_concat(elems);
   }
-  Expr VisitExpr_(const Select *op) final {
+  Expr VisitExpr_(const SelectNode *op) final {
     Expr cond = this->VisitExpr(op->condition);
     Expr t = this->VisitExpr(op->true_value);
     Expr f = this->VisitExpr(op->false_value);
@@ -216,19 +216,19 @@ class Vectorizer : public StmtExprMutator {
       int lanes = std::max(std::max(
           cond.dtype().lanes(),
           t.dtype().lanes()), f.dtype().lanes());
-      return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
+      return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
     }
   }
-  Expr VisitExpr_(const Cast *op) final {
+  Expr VisitExpr_(const CastNode *op) final {
     Expr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
       return GetRef<Expr>(op);
     } else {
-      return Cast::make(op->dtype.with_lanes(value.dtype().lanes()), value);
+      return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value);
     }
   }
   // Variable
-  Expr VisitExpr_(const Variable* v) final {
+  Expr VisitExpr_(const VarNode* v) final {
     if (v == var_.get()) {
       return ramp_;
     } else if (lets_.count(v)) {
@@ -238,7 +238,7 @@ class Vectorizer : public StmtExprMutator {
     }
   }
   // IfThenElse expr
-  Expr MutateIfThenElseExpr_(const Call *op) {
+  Expr MutateIfThenElseExpr_(const CallNode *op) {
     Expr cond = this->VisitExpr(op->args[0]);
     if (cond.dtype().is_vector())  {
       need_scalarize_ = true;
@@ -254,13 +254,13 @@ class Vectorizer : public StmtExprMutator {
       int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
       t = BroadcastTo(t, lanes);
       f = BroadcastTo(f, lanes);
-      return Call::make(
+      return CallNode::make(
           op->dtype.with_lanes(lanes), op->name,
           {cond, t, f}, op->call_type, op->func, op->value_index);
     }
   }
   // Call
-  Expr VisitExpr_(const Call* op) final {
+  Expr VisitExpr_(const CallNode* op) final {
     if (op->name == intrinsic::tvm_if_then_else) {
       return MutateIfThenElseExpr_(op);
     }
@@ -278,7 +278,7 @@ class Vectorizer : public StmtExprMutator {
       if (op->args.same_as(new_args)) {
         return GetRef<Expr>(op);
       } else {
-        return Call::make(
+        return CallNode::make(
             op->dtype, op->name, new_args, op->call_type, op->func, op->value_index);
       }
     } else {
@@ -288,21 +288,21 @@ class Vectorizer : public StmtExprMutator {
       if (op->args.same_as(new_args)) {
         return GetRef<Expr>(op);
       } else {
-        return Call::make(
+        return CallNode::make(
             op->dtype.with_lanes(lane), op->name, new_args,
             op->call_type, op->func, op->value_index);
       }
     }
   }
   // Load
-  Expr VisitExpr_(const Load* op) final {
+  Expr VisitExpr_(const LoadNode* op) final {
     Expr index = this->VisitExpr(op->index);
     Expr pred = this->VisitExpr(op->predicate);
     if (index.same_as(op->index) && pred.same_as(op->predicate)) {
       return GetRef<Expr>(op);
     } else {
       int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
-      return Load::make(
+      return LoadNode::make(
           op->dtype.with_lanes(lanes),
           op->buffer_var,
           BroadcastTo(index, lanes),
@@ -310,25 +310,25 @@ class Vectorizer : public StmtExprMutator {
     }
   }
   // Let
-  Expr VisitExpr_(const Let* op) final {
+  Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->VisitExpr(op->value);
     CHECK(!lets_.count(op->var.get())) << "not SSA";
     if (value.dtype().lanes() != op->value.dtype().lanes()) {
       Var v(op->var->name_hint, value.dtype());
       lets_[op->var.get()] = v;
-      return Let::make(v, value, this->VisitExpr(op->body));
+      return LetNode::make(v, value, this->VisitExpr(op->body));
     } else {
       Expr body = this->VisitExpr(op->body);
       if (value.same_as(op->value) &&
           body.same_as(op->body)) {
         return GetRef<Expr>(op);
       } else {
-        return Let::make(op->var, value, body);
+        return LetNode::make(op->var, value, body);
       }
     }
   }
   // Provide
-  Stmt VisitStmt_(const Provide* op) final {
+  Stmt VisitStmt_(const ProvideNode* op) final {
     Expr new_value = this->VisitExpr(op->value);
     int lane = new_value.dtype().lanes();
     Array<Expr> new_args = MutateArray(op->args, &lane);
@@ -336,11 +336,11 @@ class Vectorizer : public StmtExprMutator {
       return GetRef<Stmt>(op);
     } else {
       new_value = BroadcastTo(new_value, lane);
-      return Provide::make(op->func, op->value_index, new_value, new_args);
+      return ProvideNode::make(op->func, op->value_index, new_value, new_args);
     }
   }
   // Store
-  Stmt VisitStmt_(const Store* op) final {
+  Stmt VisitStmt_(const StoreNode* op) final {
     Expr value = this->VisitExpr(op->value);
     Expr index = this->VisitExpr(op->index);
     Expr pred = this->VisitExpr(op->predicate);
@@ -349,14 +349,14 @@ class Vectorizer : public StmtExprMutator {
     } else {
       int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
       lanes = std::max(lanes, pred.dtype().lanes());
-      return Store::make(op->buffer_var,
+      return StoreNode::make(op->buffer_var,
                          BroadcastTo(value, lanes),
                          BroadcastTo(index, lanes),
                          BroadcastTo(pred, lanes));
     }
   }
   // For
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Vectorized) {
       LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
     }
@@ -371,13 +371,13 @@ class Vectorizer : public StmtExprMutator {
         body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return For::make(
+      return ForNode::make(
           op->loop_var, op->min, extent,
           op->for_type, op->device_api, body);
     }
   }
   // IfThenElse
-  Stmt VisitStmt_(const IfThenElse* op) final {
+  Stmt VisitStmt_(const IfThenElseNode* op) final {
     CHECK(!op->condition.dtype().is_vector());
     Expr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_vector()) {
@@ -393,16 +393,16 @@ class Vectorizer : public StmtExprMutator {
         else_case.same_as(op->else_case)) {
       return GetRef<Stmt>(op);
     } else {
-      return IfThenElse::make(condition, then_case, else_case);
+      return IfThenElseNode::make(condition, then_case, else_case);
     }
   }
   // LetStmt
-  Stmt VisitStmt_(const LetStmt* op) final {
+  Stmt VisitStmt_(const LetStmtNode* op) final {
     LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
     return Scalarize(GetRef<Stmt>(op));
   }
   // Allocate
-  Stmt VisitStmt_(const Allocate* op) final {
+  Stmt VisitStmt_(const AllocateNode* op) final {
     if (op->new_expr.defined()) {
       LOG(WARNING) << "Cannot vectorize with new expr";
       return Scalarize(GetRef<Stmt>(op));
@@ -427,7 +427,7 @@ class Vectorizer : public StmtExprMutator {
     Stmt body = VecAllocAccess(
         op->buffer_var.get(), var_, var_lanes_)(op->body);
     body = this->VisitStmt(body);
-    return Allocate::make(
+    return AllocateNode::make(
         op->buffer_var, op->dtype,
         extents, condition, body,
         op->new_expr, op->free_function);
@@ -437,7 +437,7 @@ class Vectorizer : public StmtExprMutator {
     Var idx(var_->name_hint + ".s", var_->dtype);
     Map<Var, Expr> values{{var_, idx}};
     stmt = Substitute(stmt, values);
-    return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
+    return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
   }
 
  private:
@@ -452,7 +452,7 @@ class Vectorizer : public StmtExprMutator {
   // flag to mark requirment of scalarization.
   bool need_scalarize_{false};
   // The lets
-  std::unordered_map<const Variable*, Expr> lets_;
+  std::unordered_map<const VarNode*, Expr> lets_;
   // mutate array, with given lane requirement
   // when finished, p_lane updates the lane requirement.
   Array<Expr> MutateArray(Array<Expr> arr, int* p_lanes) {
@@ -499,16 +499,16 @@ class Vectorizer : public StmtExprMutator {
     } else {
       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
       if (lanes != 1) {
-        const Ramp* b_ramp = b.as<Ramp>();
-        const Ramp* a_ramp = a.as<Ramp>();
+        const RampNode* b_ramp = b.as<RampNode>();
+        const RampNode* a_ramp = a.as<RampNode>();
         if (a.dtype().lanes() == 1 && b_ramp) {
-          return Ramp::make(
+          return RampNode::make(
               arith::Compute<T>(a, b_ramp->base),
               arith::Compute<T>(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
               b_ramp->lanes);
         }
         if (b.dtype().lanes() == 1 && a_ramp) {
-          return Ramp::make(
+          return RampNode::make(
               arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
         }
       }
@@ -519,7 +519,7 @@ class Vectorizer : public StmtExprMutator {
 
 class LoopVectorizer : public StmtMutator {
  public:
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     if (op->for_type == ForType::Vectorized) {
       CHECK(is_zero(op->min));
       int lanes = 0;
@@ -540,11 +540,11 @@ Stmt VectorizeLoop(Stmt stmt) {
 
 class VectorizeSkipper : public StmtMutator {
  public:
-  Stmt VisitStmt_(const For* op) final {
+  Stmt VisitStmt_(const ForNode* op) final {
     Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<For>();
+    op = stmt.as<ForNode>();
     if (op->for_type == ForType::Vectorized) {
-      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
+      return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
                        op->body);
     } else {
        return stmt;
index 671b4a0..f6c454d 100644 (file)
@@ -39,7 +39,7 @@ class VerifyBuffer : public StmtVisitor {
     return is_compact_;
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     StmtVisitor::VisitStmt_(op);
     if (op->attr_key == attr::buffer_bind_scope) {
       is_compact_ = true;
index 08ec413..96f231e 100644 (file)
@@ -56,7 +56,7 @@ class GPUCodeVerifier : public StmtVisitor {
     return valid_;
   }
 
-  void VisitStmt_(const ProducerConsumer* op) final {
+  void VisitStmt_(const ProducerConsumerNode* op) final {
     if (nest_level_ == 0) {
       // enter a new kernel, reset statistics
       Reset_();
@@ -79,7 +79,7 @@ class GPUCodeVerifier : public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const Allocate* op) final {
+  void VisitStmt_(const AllocateNode* op) final {
     StmtVisitor::VisitStmt_(op);
     // visit an allocation of a buffer in shared memory, record its size
     if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
@@ -91,17 +91,17 @@ class GPUCodeVerifier : public StmtVisitor {
     }
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
-      std::string op_value = op->value.as<StringImm>()->value;
+      std::string op_value = op->value.as<StringImmNode>()->value;
       if (op_value == "local") {
-        visited_local_buffers_.insert(op->node.as<tvm::Variable>());
+        visited_local_buffers_.insert(op->node.as<tvm::VarNode>());
       } else if (op_value == "shared") {
-        visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
+        visited_shared_buffers_.insert(op->node.as<tvm::VarNode>());
       }
     } else if (op->attr_key == attr::thread_extent) {
       VarExpr var = op->node.as<tvm::IterVarNode>()->var;
-      const auto *extent = op->value.as<IntImm>();
+      const auto *extent = op->value.as<IntImmNode>();
       CHECK(extent);
 
       // record the number of threads in a block
@@ -140,8 +140,8 @@ class GPUCodeVerifier : public StmtVisitor {
  private:
   int nest_level_{0};
 
-  std::unordered_set<const tvm::Variable *> visited_local_buffers_;
-  std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
+  std::unordered_set<const tvm::VarNode *> visited_local_buffers_;
+  std::unordered_set<const tvm::VarNode *> visited_shared_buffers_;
   std::unordered_set<std::string> visited_threads_;
 
   size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
@@ -180,7 +180,7 @@ bool VerifyGPUCode(Stmt stmt,
   int64_t max_thread_z = INT64_MAX;
 
   for (auto iter : constraints) {
-    const IntImm* val = iter.second.as<IntImm>();
+    const IntImmNode* val = iter.second.as<IntImmNode>();
     if (iter.first == "max_local_memory_per_block")
       max_local_memory_per_block = val->value;
     else if (iter.first == "max_shared_memory_per_block")
index 415841d..25e7258 100644 (file)
@@ -75,13 +75,13 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
     StmtExprVisitor::VisitStmt(n);
   }
 
-  void VisitStmt_(const LetStmt* op) final {
+  void VisitStmt_(const LetStmtNode* op) final {
     // Book keep definitions
     defs_[op->var.get()] = op->value;
     return StmtExprVisitor::VisitStmt_(op);
   }
 
-  void VisitStmt_(const AttrStmt* op) final {
+  void VisitStmt_(const AttrStmtNode* op) final {
     if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
                            op->attr_key == attr::pipeline_exec_scope)) {
       EnterThreadEnv();
@@ -92,26 +92,26 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
     }
   }
 
-  void VisitStmt_(const ProducerConsumer* op) final {
+  void VisitStmt_(const ProducerConsumerNode* op) final {
     EnterProducerConsumer(op);
     StmtExprVisitor::VisitStmt_(op);
     ExitProducerConsumer();
   }
 
-  void VisitExpr_(const Load* op) final {
+  void VisitExpr_(const LoadNode* op) final {
     HandleLoadStoreToVariable(op->buffer_var);
     return StmtExprVisitor::VisitExpr_(op);
   }
 
-  void VisitStmt_(const Store* op) final {
+  void VisitStmt_(const StoreNode* op) final {
     HandleLoadStoreToVariable(op->buffer_var);
     return StmtExprVisitor::VisitStmt_(op);
   }
   //@}
 
   /// Check if the value of a Variable comes from function argument.
-  bool IsFromFunctionArgs(const Variable *var) const {
-    const Variable *V = var;
+  bool IsFromFunctionArgs(const VarNode *var) const {
+    const VarNode *V = var;
     while (true) {
       CHECK(V) << "Invalid Variable\n";
 
@@ -122,9 +122,9 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
       // Get the first argument of tvm_struct_get, and continue.
       const auto &iter = defs_.find(V);
       if (iter == defs_.end()) return false;
-      const Call *C = iter->second.as<const Call>();
+      const CallNode *C = iter->second.as<const CallNode>();
       if (!C || C->name != intrinsic::tvm_struct_get) return false;
-      V = C->args[0].as<Variable>();
+      V = C->args[0].as<VarNode>();
     }
     return false;
   }
@@ -155,8 +155,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   void EnterThreadEnv() { in_thread_env_ = true; }
   void ExitThreadEnv() { in_thread_env_ = false; }
   bool InProducerConsumer() const { return pc_ != nullptr; }
-  const ProducerConsumer *GetCurrentProducerConsumer() const { return pc_; }
-  void EnterProducerConsumer(const ProducerConsumer *pc) { this->pc_ = pc; }
+  const ProducerConsumerNode *GetCurrentProducerConsumer() const { return pc_; }
+  void EnterProducerConsumer(const ProducerConsumerNode *pc) { this->pc_ = pc; }
   void ExitProducerConsumer() { pc_ = nullptr; }
   void SetFailure() { failure_ = true; }
   //@}
@@ -176,12 +176,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   /// Status of visitor
   //@{
   bool in_thread_env_{false};
-  const ProducerConsumer *pc_{nullptr};
+  const ProducerConsumerNode *pc_{nullptr};
   bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
   //@}
   LoweredFunc func_{nullptr};  ///< Function to be verified.
   int dev_type_{kDLCPU};       ///< Device type
-  std::unordered_map<const Variable *, Expr> defs_;  ///< Variable definitions
+  std::unordered_map<const VarNode *, Expr> defs_;  ///< Variable definitions
 };
 }  // namespace
 
index 102e4c2..69731ea 100644 (file)
@@ -85,7 +85,7 @@ struct GraphCodegen {
     std::unordered_map<std::string, tvm::runtime::NDArray> ret;
     auto names = CallFunc<Array<tvm::Expr> >("list_params_name", nullptr);
     for (auto expr : names) {
-      auto key = expr.as<ir::StringImm>()->value;
+      auto key = expr.as<ir::StringImmNode>()->value;
       ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
     }
     return ret;
@@ -193,7 +193,7 @@ class RelayBuildModule : public runtime::ModuleNode {
   Array<tvm::Expr> ListParamNames() {
     Array<tvm::Expr> ret;
     for (const auto& kv : params_) {
-      ret.push_back(ir::StringImm::make(kv.first));
+      ret.push_back(ir::StringImmNode::make(kv.first));
     }
     return ret;
   }
index 96cd5a1..6c511ae 100644 (file)
@@ -88,9 +88,9 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
     if (pval != nullptr) {
       CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
       CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
-      res.push_back(ir::IntImm::make(DataType::Int(32), *pval));
-    } else if (val->IsInstance<ir::Any>()) {
-      res.push_back(val.as<ir::Any>()->ToVar());
+      res.push_back(ir::IntImmNode::make(DataType::Int(32), *pval));
+    } else if (val->IsInstance<ir::AnyNode>()) {
+      res.push_back(val.as<ir::AnyNode>()->ToVar());
     } else {
       res.push_back(val);
     }
@@ -395,7 +395,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     // set inputs
     for (auto param : prim_func->params) {
       int state = param_states_[param];
-      cache_node->shape_func_param_states.push_back(IntImm::make(DataType::Int(32), state));
+      cache_node->shape_func_param_states.push_back(IntImmNode::make(DataType::Int(32), state));
       if (state & kNeedInputData) {
         for (auto t : param_data_[param]) {
           cache_node->inputs.push_back(t);
@@ -528,7 +528,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
     auto ret_type = call_node->checked_type();
     Array<IndexExpr> out_ndims;
     if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
-      out_ndims.push_back(IntImm::make(DataType::Int(32), ttype->shape.size()));
+      out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
     } else {
       auto rtype = ret_type.as<TupleTypeNode>();
       // TODO(@icemelon): Allow recursive tuple
@@ -536,7 +536,7 @@ class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
       for (size_t i = 0; i < rtype->fields.size(); ++i) {
         auto ttype = rtype->fields[i].as<TensorTypeNode>();
         CHECK(ttype);
-        out_ndims.push_back(IntImm::make(DataType::Int(32), ttype->shape.size()));
+        out_ndims.push_back(IntImmNode::make(DataType::Int(32), ttype->shape.size()));
       }
     }
     // Call shape function
@@ -620,13 +620,13 @@ class CompileEngineImpl : public CompileEngineNode {
       CHECK(src_func.defined());
       if (!src_func->UseDefaultCompiler()) {
         auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
-        const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
+        const tvm::ir::StringImmNode* code_gen = compiler.as<tvm::ir::StringImmNode>();
         CHECK(code_gen) << "No external codegen is set";
         if (ext_mods.find(code_gen->value) == ext_mods.end()) {
           ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
         }
         auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
-        const tvm::ir::StringImm* symbol_name = ext_symbol.as<tvm::ir::StringImm>();
+        const tvm::ir::StringImmNode* symbol_name = ext_symbol.as<tvm::ir::StringImmNode>();
         CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
         auto gv = GlobalVarNode::make(symbol_name->value);
         ext_mods[code_gen->value]->Add(gv, src_func);
@@ -697,7 +697,7 @@ class CompileEngineImpl : public CompileEngineNode {
     if (!key->source_func->UseDefaultCompiler()) {
       auto cache_node = make_object<CachedFuncNode>();
       const auto name_node =
-          FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
+          FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImmNode>();
       CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
       cache_node->func_name = name_node->value;
       cache_node->target = tvm::target::ext_dev();
index d97f5dc..f6fb222 100644 (file)
@@ -59,7 +59,8 @@ class CSourceModuleCodegenBase {
    * \return An external symbol.
    */
   std::string GetExtSymbol(const Function& func) const {
-    const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
+    const auto name_node =
+      FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImmNode>();
     CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
     std::string ext_symbol = name_node->value;
     return ext_symbol;
@@ -176,7 +177,7 @@ class CodegenCBase {
     CHECK(ttype) << "Expect TensorTypeNode";
     std::vector<int> shape;
     for (size_t i = 0; i < ttype->shape.size(); ++i) {
-      auto* val = ttype->shape[i].as<IntImm>();
+      auto* val = ttype->shape[i].as<IntImmNode>();
       CHECK(val);
       shape.push_back(val->value);
     }
index 4c0fe34..9c24944 100644 (file)
@@ -142,12 +142,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
     // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
     args.push_back(std::to_string(wshape[0]));
     args.push_back(std::to_string(conv2d_attr->groups));
-    args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImm>()->value));
-    args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImm>()->value));
+    args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
+    args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
     args.push_back(std::to_string(wshape[2]));
     args.push_back(std::to_string(wshape[3]));
-    args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImm>()->value));
-    args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImm>()->value));
+    args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
+    args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
 
     return args;
   }
index 5f21043..618e135 100644 (file)
@@ -623,7 +623,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
          Map<Integer, tvm::Target> tmp = args[1];
          TargetsMap targets;
          for (const auto& it : tmp) {
-           auto dev_type = it.first.as<ir::IntImm>();
+           auto dev_type = it.first.as<ir::IntImmNode>();
            CHECK(dev_type);
            targets[dev_type->value] = it.second;
          }
@@ -643,7 +643,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         Array<tvm::Expr> ret;
         for (const auto &kv : this->output_.params) {
-          tvm::Expr name = ir::StringImm::make(kv.first);
+          tvm::Expr name = ir::StringImmNode::make(kv.first);
           ret.push_back(name);
         }
         *rv = ret;
index b6cb1aa..b7ecadc 100644 (file)
@@ -44,7 +44,7 @@ inline std::string GenerateName(const Function& func) {
 
 bool IsClosure(const Function& func) {
   ObjectRef res = FunctionGetAttr(func, attr::kClosure);
-  const ir::IntImm* pval = res.as<ir::IntImm>();
+  const ir::IntImmNode* pval = res.as<ir::IntImmNode>();
   return pval && pval->value != 0;
 }
 
index c6fe490..cea6115 100644 (file)
@@ -103,7 +103,7 @@ Module RemoveUnusedFunctions(const Module& module,
                              Array<tvm::Expr> entry_funcs) {
   std::unordered_set<std::string> called_funcs{};
   for (auto entry : entry_funcs) {
-    auto* str_name = entry.as<ir::StringImm>();
+    auto* str_name = entry.as<ir::StringImmNode>();
     auto funcs = CallTracer(module).Trace(str_name->value);
     called_funcs.insert(funcs.cbegin(), funcs.cend());
   }
index 248f06b..b41d381 100644 (file)
@@ -196,7 +196,7 @@ class AlphaEqualHandler:
     }
   }
   using AttrsEqualHandler::VisitAttr_;
-  bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final {
+  bool VisitAttr_(const tvm::VarNode* lhs, const ObjectRef& other) final {
     return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
   }
 
index 81c0c25..a6f44ce 100644 (file)
@@ -56,7 +56,7 @@ TensorType ConstantNode::tensor_type() const {
     CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
     CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
     shape.push_back(
-        tvm::ir::IntImm::make(DataType::Int(32), data->shape[i]));
+        tvm::ir::IntImmNode::make(DataType::Int(32), data->shape[i]));
   }
 
   return TensorTypeNode::make(shape, dtype);
@@ -158,7 +158,7 @@ FuncType FunctionNode::func_type_annotation() const {
 
 bool FunctionNode::IsPrimitive() const {
   ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
-  const ir::IntImm* pval = res.as<ir::IntImm>();
+  const ir::IntImmNode* pval = res.as<ir::IntImmNode>();
   return pval && pval->value != 0;
 }
 
@@ -184,7 +184,7 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
 
 bool FunctionNode::UseDefaultCompiler() const {
   ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
-  const ir::StringImm* pval = res.as<ir::StringImm>();
+  const ir::StringImmNode* pval = res.as<ir::StringImmNode>();
   return pval == nullptr || pval->value == "default";
 }
 
index cf1e280..d179d7e 100644 (file)
@@ -125,8 +125,8 @@ class RelayHashHandler:
   }
 
   using AttrsHashHandler::VisitAttr_;
-  size_t VisitAttr_(const Variable* var) final {
-    size_t hash = std::hash<std::string>()(Variable::_type_key);
+  size_t VisitAttr_(const tvm::VarNode* var) final {
+    size_t hash = std::hash<std::string>()(VarNode::_type_key);
     auto it = hash_map_.find(GetRef<VarExpr>(var));
     if (it != hash_map_.end()) {
       return it->second;
index 362bbf0..a9d788d 100644 (file)
@@ -813,7 +813,7 @@ class PrettyPrinter :
   Doc PrintAttr(const ObjectRef& value, bool meta = false) {
     if (value.defined()) {
       Doc printed_attr;
-      if (value.as<tvm::ir::Any>()) {
+      if (value.as<tvm::ir::AnyNode>()) {
         printed_attr << "?";
       } else if (meta) {
         printed_attr = meta_.GetMetaNode(Downcast<ObjectRef>(value));
@@ -842,19 +842,19 @@ class PrettyPrinter :
     return doc;
   }
 
-  Doc VisitAttr_(const ir::IntImm* op) final {
+  Doc VisitAttr_(const ir::IntImmNode* op) final {
     return PrintConstScalar(op->dtype, &(op->value));
   }
 
-  Doc VisitAttr_(const ir::UIntImm* op) final {
+  Doc VisitAttr_(const ir::UIntImmNode* op) final {
     return PrintConstScalar(op->dtype, &(op->value));
   }
 
-  Doc VisitAttr_(const ir::FloatImm* op) final {
+  Doc VisitAttr_(const ir::FloatImmNode* op) final {
     return PrintConstScalar(op->dtype, &(op->value));
   }
 
-  Doc VisitAttr_(const ir::StringImm* op) final {
+  Doc VisitAttr_(const ir::StringImmNode* op) final {
     return PrintString(op->value);
   }
 
index 42f016d..bd8fb42 100644 (file)
@@ -141,7 +141,7 @@ bool AllocTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   // Second argument should be shape tensor.
   auto tt = types[1].as<TensorTypeNode>();
   CHECK(tt != nullptr) << "must be tensor type";
-  auto rank = tt->shape[0].as<tvm::IntImm>();
+  auto rank = tt->shape[0].as<tvm::IntImmNode>();
   CHECK(rank != nullptr);
   auto dims = rank->value;
 
index 2a7d6b3..627c420 100644 (file)
@@ -548,13 +548,13 @@ bool Conv2DWinogradRel(const Array<Type>& types,
 
   IndexExpr pad_h, pad_w;
   GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  if (!dshape_nchw[2].as<ir::Any>()) {
+  if (!dshape_nchw[2].as<ir::AnyNode>()) {
     oshape.Set(2, (dshape_nchw[2] + pad_h
                    - dilated_ksize_y) / param->strides[0] + 1);
   } else {
     oshape.Set(2, dshape_nchw[2]);
   }
-  if (!dshape_nchw[3].as<ir::Any>()) {
+  if (!dshape_nchw[3].as<ir::AnyNode>()) {
     oshape.Set(3, (dshape_nchw[3] + pad_w
                    - dilated_ksize_x) / param->strides[1] + 1);
   } else {
index 913c5b0..b61942d 100644 (file)
@@ -119,14 +119,14 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   IndexExpr pad_h, pad_w;
   GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  if (!dshape_nchw[2].as<ir::Any>()) {
+  if (!dshape_nchw[2].as<ir::AnyNode>()) {
     oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
                            param->strides[0]) + 1);
   } else {
     oshape.Set(2, dshape_nchw[2]);
   }
 
-  if (!dshape_nchw[3].as<ir::Any>()) {
+  if (!dshape_nchw[3].as<ir::AnyNode>()) {
     oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
                            param->strides[1]) + 1);
   } else {
@@ -232,21 +232,21 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   IndexExpr pad_d, pad_h, pad_w;
   GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
-  if (!dshape_ncdhw[2].as<ir::Any>()) {
+  if (!dshape_ncdhw[2].as<ir::AnyNode>()) {
     oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
                            param->strides[0]) + 1);
   } else {
     oshape.Set(2, dshape_ncdhw[2]);
   }
 
-  if (!dshape_ncdhw[3].as<ir::Any>()) {
+  if (!dshape_ncdhw[3].as<ir::AnyNode>()) {
     oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
                            param->strides[1]) + 1);
   } else {
     oshape.Set(3, dshape_ncdhw[3]);
   }
 
-  if (!dshape_ncdhw[4].as<ir::Any>()) {
+  if (!dshape_ncdhw[4].as<ir::AnyNode>()) {
     oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
                            param->strides[2]) + 1);
   } else {
index bab59f7..aeb40fd 100644 (file)
@@ -408,7 +408,7 @@ bool BatchFlattenRel(const Array<Type>& types,
   auto target_dim = make_const(DataType::Int(32), 1);
 
   for (uint32_t i = 1; i < data->shape.size(); ++i) {
-    if (!data->shape[i].as<ir::Any>()) {
+    if (!data->shape[i].as<ir::AnyNode>()) {
       target_dim = target_dim * data->shape[i];
     } else {
       target_dim = data->shape[i];
index 0d5810f..f9d753f 100644 (file)
@@ -80,7 +80,7 @@ Array<Array<Layout> > PadInferCorrectLayout(
 
         // If any pad_width element is not zero, do not change the layout.
         for (auto width : axis_pad_width.at(dual_axis_name)) {
-          if (auto* width_imm = width.as<IntImm>()) {
+          if (auto* width_imm = width.as<IntImmNode>()) {
             if (width_imm->value != 0) {
               is_layout_modified = false;
             }
@@ -147,7 +147,7 @@ bool PadRel(const Array<Type>& types,
       << "Param width elements should be positive but first pad width at "
       << "index " << i << " is " << *width2 << ".";
 
-    if (!data->shape[i].as<ir::Any>()) {
+    if (!data->shape[i].as<ir::AnyNode>()) {
       auto padding = make_const(data->shape[i].dtype(), *width1 + *width2);
       oshape.push_back(data->shape[i] + padding);
     } else {
index 529435d..a4e60f4 100644 (file)
@@ -139,7 +139,7 @@ bool Pool2DRel(const Array<Type>& types,
     oshape.push_back(e);
   }
 
-  if (dshape[hidx].as<ir::Any>()) {
+  if (dshape[hidx].as<ir::AnyNode>()) {
     oshape[hidx] = dshape[hidx];
   } else {
     if (param->ceil_mode) {
@@ -149,7 +149,7 @@ bool Pool2DRel(const Array<Type>& types,
       oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
     }
   }
-  if (dshape[widx].as<ir::Any>()) {
+  if (dshape[widx].as<ir::AnyNode>()) {
     oshape[widx] = dshape[widx];
   } else {
     if (param->ceil_mode) {
@@ -796,7 +796,7 @@ bool Pool3DRel(const Array<Type>& types,
   std::vector<int> idxes = {didx, hidx, widx};
   for (int i = 0; i < 3; i++) {
     int ii = idxes[i];
-    if (dshape[ii].as<ir::Any>()) {
+    if (dshape[ii].as<ir::AnyNode>()) {
       oshape[ii] = dshape[ii];
     } else {
       if (param->ceil_mode) {
index 1f2a016..73cb5a1 100644 (file)
@@ -83,8 +83,8 @@ bool UpSamplingRel(const Array<Type>& types,
     << " But got " << in_layout;
 
   auto oshape = layout_converter.ForwardShape(data->shape);
-  oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
-  oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
+  oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
+  oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
 
   // assign output type
   reporter->Assign(types[1],
@@ -162,9 +162,9 @@ bool UpSampling3DRel(const Array<Type>& types,
     << " But got " << in_layout;
 
   auto oshape = layout_converter.ForwardShape(data->shape);
-  oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
-  oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
-  oshape.Set(4, ir::Cast::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
+  oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
+  oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
+  oshape.Set(4, ir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
 
   // assign output type
   reporter->Assign(types[1],
index ae8e62c..dde3ef2 100644 (file)
@@ -212,7 +212,7 @@ inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr> &in_s
   auto max_shape = make_const(DataType::Int(64), 1);
   bool is_dynamic_input = false;
   for (int64_t axis : r_axes) {
-    if (in_shape[axis].as<IntImm>()) {
+    if (in_shape[axis].as<IntImmNode>()) {
       max_shape *= in_shape[axis];
     } else {
       is_dynamic_input = true;
index 5885a00..b8ee7e7 100644 (file)
@@ -42,7 +42,7 @@
 
 namespace tvm {
 namespace relay {
-using ir::IntImm;
+using ir::IntImmNode;
 
 // relay.cast
 TVM_REGISTER_NODE_TYPE(CastAttrs);
@@ -695,8 +695,8 @@ Array<Tensor> ReshapeCompute(const Attrs& attrs,
   CHECK(out_ttype != nullptr);
   Array<IndexExpr> newshape;
   for (auto val : out_ttype->shape) {
-    if (val->IsInstance<ir::Any>()) {
-      newshape.push_back(val.as<ir::Any>()->ToVar());
+    if (val->IsInstance<ir::AnyNode>()) {
+      newshape.push_back(val.as<ir::AnyNode>()->ToVar());
     } else {
       newshape.push_back(val);
     }
@@ -800,7 +800,7 @@ bool ReshapeLikeRel(const Array<Type>& types,
   // Only check When input data has static shape.
   bool is_static_shape = true;
   for (size_t i = 0; i < data->shape.size(); ++i) {
-    if (!data->shape[i].as<IntImm>()) {
+    if (!data->shape[i].as<IntImmNode>()) {
       is_static_shape = false;
       break;
     }
@@ -852,7 +852,7 @@ bool ArgWhereRel(const Array<Type>& types,
   const auto& input_rank = input_shape.size();
   std::vector<IndexExpr> result_shape;
   result_shape.push_back(Any::make());
-  result_shape.push_back(IntImm::make(DataType::Int(32), input_rank));
+  result_shape.push_back(IntImmNode::make(DataType::Int(32), input_rank));
   reporter->Assign(types[1], TensorTypeNode::make(result_shape, DataType::Int(32)));
   return true;
 }
@@ -1384,7 +1384,7 @@ bool TileRel(const Array<Type>& types,
     << "repetition array is not defined. data.ndim = " << ndim;
   const size_t rndim = reps.size();
   for (size_t i = 0; i < rndim; ++i) {
-    if (const tvm::ir::IntImm* val = reps[i].as<tvm::ir::IntImm>()) {
+    if (const tvm::ir::IntImmNode* val = reps[i].as<tvm::ir::IntImmNode>()) {
       CHECK_GT(val->value, 0)
           << "Tile reps value should always be larger than 0, but get: " << val->value;
     }
@@ -1425,7 +1425,7 @@ bool TileRel(const Array<Type>& types,
   oshape.reserve(tndim);
   for (size_t i = 0; i < tndim; ++i) {
     // Save Any if it is dynamic shape
-    if (!data_shape[i].as<IntImm>()) {
+    if (!data_shape[i].as<IntImmNode>()) {
       oshape.emplace_back(Any::make());
     } else {
       oshape.emplace_back(data_shape[i] * reps_shape[i]);
@@ -1649,7 +1649,7 @@ bool SqueezeRel(const Array<Type>& types,
   // if axes is None, squeeze all axes of dimension 1
   if (!param->axis.defined()) {
     for (const auto& e : data->shape) {
-      if (!e.as<IntImm>()) {
+      if (!e.as<IntImmNode>()) {
         LOG(FATAL) << "axis needs to be defined for dynamic input.";
       }
       const int64_t* axis_ptr = as_const_int(e);
@@ -1838,7 +1838,7 @@ RELAY_REGISTER_OP("broadcast_to_like")
 // Adapter function to make int array.
 Array<Integer> GetIntArray(Array<IndexExpr> arr) {
   for (size_t i = 0; i < arr.size(); ++i) {
-    CHECK(!arr[i].defined() || arr[i].as<IntImm>())
+    CHECK(!arr[i].defined() || arr[i].as<IntImmNode>())
       << "Expect an int array";
   }
   return Downcast<Array<Integer> >(arr);
@@ -1988,7 +1988,7 @@ Array<Array<Layout> > StridedSliceInferCorrectLayout(
         }
         int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
         int64_t end = params->end[i].defined() ? params->end[i]->value :
-            shape[i].as<IntImm>()->value;
+            shape[i].as<IntImmNode>()->value;
         if (begin % factor || end % factor) {
           // transform to original layout
           return {{Layout::Undef()}, {Layout::Undef()}};
@@ -2139,7 +2139,7 @@ bool SplitRel(const Array<Type>& types,
   CHECK_GE(axis, 0)
     << "axis should be within the input dimension range.";
 
-  if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
+  if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
     CHECK(reporter->Assert(indexmod(data->shape[axis],
                                     sections->value) == make_zero(DataType::Int(64))))
         << "indices_or_sections need to be able to divide input.shape[axis]";
@@ -2182,7 +2182,7 @@ Array<Tensor> SplitCompute(const Attrs& attrs,
   const auto param = attrs.as<SplitAttrs>();
   CHECK(param != nullptr);
 
-  if (const IntImm* sections = param->indices_or_sections.as<IntImm>()) {
+  if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
     int64_t num_sections = sections->value;
     return Array<Tensor>{
       topi::split_sections(inputs[0], num_sections, param->axis) };
@@ -2489,7 +2489,7 @@ bool GatherNDRel(const Array<Type>& types,
     return false;
   }
   const size_t ndim = data->shape.size();
-  const IntImm* mdim = indices->shape[0].as<IntImm>();
+  const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
   const size_t kdim = indices->shape.size() - 1;
   CHECK(size_t(mdim->value) <= ndim)
         << "GatherND: indices shape does satisfy.";
index 630c25e..20a57fa 100644 (file)
@@ -123,7 +123,7 @@ Pass AlterOpLayout() {
       return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
   };
   return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
index 861efb4..7a52dcf 100644 (file)
@@ -134,7 +134,7 @@ Pass CanonicalizeCast() {
     return Downcast<Function>(CanonicalizeCast(f));
   };
   return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
index 78001bb..2226516 100644 (file)
@@ -74,7 +74,7 @@ Pass CanonicalizeOps() {
     return Downcast<Function>(CanonicalizeOps(f));
   };
   return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
index 869aa28..5cb4c45 100644 (file)
@@ -221,7 +221,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
       return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
   };
   return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
index af43225..81a4806 100644 (file)
@@ -81,7 +81,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
       return Downcast<Function>(CombineParallelDense(f, min_num_branches));
   };
   return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
index d8152f6..b240ba7 100644 (file)
@@ -194,7 +194,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
                                                        min_num_branches));
   };
   return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
index da0c28f..df711bf 100644 (file)
@@ -134,8 +134,8 @@ Pass ConvertLayout(const std::string& desired_layout) {
       };
   return CreateFunctionPass(
       pass_func, 3, "ConvertLayout",
-      {ir::StringImm::make("InferType"),
-       ir::StringImm::make("CanonicalizeOps")});
+      {ir::StringImmNode::make("InferType"),
+       ir::StringImmNode::make("CanonicalizeOps")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
index 1229324..3ef501d 100644 (file)
@@ -577,7 +577,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
     return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
   };
   return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
index f9a303b..04aef0e 100644 (file)
@@ -92,7 +92,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
       return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
   };
   return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
index b830de0..d36733d 100644 (file)
@@ -248,9 +248,9 @@ class ConstantFolder : public ExprMutator {
       std::vector<int64_t> cshape = { static_cast<int64_t>(ishape.size()) };
       value = runtime::NDArray::Empty(cshape, cdtype, ctx);
       int32_t* dims = static_cast<int32_t*>(value->data);
-      using ::tvm::ir::IntImm;
+      using ::tvm::ir::IntImmNode;
       for (size_t i = 0; i < ishape.size(); ++i) {
-        if (const IntImm* dim = ishape[i].as<IntImm>()) {
+        if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
           dims[i] = dim->value;
         } else {
           return expr;
index fea5cdb..ddb3ac0 100644 (file)
@@ -955,7 +955,7 @@ Pass ForwardFoldScaleAxis() {
           relay::fold_scale_axis::ForwardFoldScaleAxis(f));
   };
   return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
@@ -968,7 +968,7 @@ Pass BackwardFoldScaleAxis() {
           relay::fold_scale_axis::BackwardFoldScaleAxis(f));
     };
   return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
index 3af99a2..c217d06 100644 (file)
@@ -983,7 +983,7 @@ Pass FuseOps(int fuse_opt_level) {
     return Downcast<Function>(FuseOps(f, opt_level, m));
   };
   return CreateFunctionPass(pass_func, 1, "FuseOps",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
index 94eeba1..c5202b5 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -64,7 +64,7 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o
 
       // 4) a) Check if this shape element is 1.
       bool is_shape_one = false;
-      if (auto* shape_int = shape_val.as<IntImm>()) {
+      if (auto* shape_int = shape_val.as<IntImmNode>()) {
         if (shape_int->value == 1) {
           new_layout += "1";
           is_shape_one = true;
index 8f3830e..654c91e 100644 (file)
@@ -102,7 +102,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
       };
-  return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImm::make("InferType")});
+  return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
index a5cd93a..9e3e95e 100644 (file)
@@ -41,7 +41,7 @@ namespace mac_count {
 inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
   int64_t ret = 1;
   for (size_t i = 0; i < arr.size(); i++) {
-    const auto* intImm = arr[i].as<IntImm>();
+    const auto* intImm = arr[i].as<IntImmNode>();
     ret *= static_cast<int64_t>(intImm->value);
   }
   return ret;
@@ -75,9 +75,9 @@ int64_t ConvMacCount(const Call& call_node) {
   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
   CHECK_NE(C_ind, -1)
     << "There is no input channel dimension.";
-  int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
+  int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
   if (c_ind != -1)
-    input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
+    input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
   Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
   CHECK_EQ(kernel_size.size(), 2)
     << "The dimension of the kernel in Conv 2D should be 2.";
@@ -108,9 +108,9 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
   CHECK_NE(C_ind, -1)
     << "There is no input channel dimension.";
-  int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
+  int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImmNode>()->value);
   if (c_ind != -1)
-    input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
+    input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImmNode>()->value);
   Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
   CHECK_EQ(kernel_size.size(), 2)
     << "The dimension of the kernel in Conv 2D Transpose should be 2.";
@@ -139,10 +139,10 @@ int64_t DenseMacCount(const Call& call_node) {
   Array<IndexExpr> weight_shape = weight_type->shape;
   CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
     << "The dimension of an input tensor to Dense node should be 2.";
-  int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImm>()->value);
-  int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
-  int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
-  int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
+  int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImmNode>()->value);
+  int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImmNode>()->value);
+  int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImmNode>()->value);
+  int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImmNode>()->value);
   CHECK_EQ(d2, d4)
     << "The dimensions of input arguments do not match.";
   int64_t count = d1 * d2 * d3;
@@ -158,10 +158,10 @@ int64_t BatchMatmulMacCount(const Call& call_node) {
   CHECK_EQ(args.size(), 2);
   Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
   Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
-  int64_t batch = x_shape[0].as<IntImm>()->value;
-  int64_t m = x_shape[1].as<IntImm>()->value;
-  int64_t k = x_shape[2].as<IntImm>()->value;
-  int64_t n = y_shape[1].as<IntImm>()->value;
+  int64_t batch = x_shape[0].as<IntImmNode>()->value;
+  int64_t m = x_shape[1].as<IntImmNode>()->value;
+  int64_t k = x_shape[2].as<IntImmNode>()->value;
+  int64_t n = y_shape[1].as<IntImmNode>()->value;
   return batch * m * k * n;
 }
 
index e02dcc0..fea463d 100644 (file)
@@ -337,7 +337,7 @@ Module FunctionPassNode::operator()(const Module& mod,
 
 bool FunctionPassNode::SkipFunction(const Function& func) const {
   ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
-  const ir::IntImm* pval = skip_opt.as<ir::IntImm>();
+  const ir::IntImmNode* pval = skip_opt.as<ir::IntImmNode>();
   return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
 }
 
@@ -373,7 +373,7 @@ void SequentialNode::ResolveDependency(const Module& mod) {
 inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
                               const std::string& pass_name) {
   for (auto x : pass_array) {
-    auto* str_name = x.as<ir::StringImm>();
+    auto* str_name = x.as<ir::StringImmNode>();
     CHECK(str_name) << "pass name must be str";
     if (str_name->value == pass_name) return true;
   }
@@ -415,7 +415,7 @@ Module SequentialNode::operator()(const Module& module,
     if (!PassEnabled(pass_info))  continue;
     // resolve dependencies
     for (const auto& it : pass_info->required) {
-      const auto* name = it.as<tvm::ir::StringImm>();
+      const auto* name = it.as<tvm::ir::StringImmNode>();
       CHECK(name);
       mod = GetPass(name->value)(mod, pass_ctx);
     }
@@ -461,7 +461,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
   p->stream << "opt_level: " << node->opt_level;
   p->stream << "required passes: [" << "\n";
   for (const auto& it : node->required) {
-    const auto* str = it.as<tvm::ir::StringImm>();
+    const auto* str = it.as<tvm::ir::StringImmNode>();
     p->stream << str->value << ", ";
   }
   p->stream << "]\n";
index 5b2c3ae..801fc17 100644 (file)
@@ -208,7 +208,7 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
 inline bool IsScalar(const Expr& expr) {
   if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
     for (auto dim_index_expr : tensor_type->shape) {
-      if (auto dim_index = dim_index_expr.as<IntImm>()) {
+      if (auto dim_index = dim_index_expr.as<IntImmNode>()) {
         if (dim_index->value != 1) {
           return false;
         }
index 5e67085..108edfc 100644 (file)
@@ -188,7 +188,7 @@ Pass SimplifyInference() {
     return Downcast<Function>(SimplifyInference(f));
   };
   return CreateFunctionPass(pass_func, 0, "SimplifyInference",
-                            {ir::StringImm::make("InferType")});
+                            {ir::StringImmNode::make("InferType")});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
index 221f2c1..1f47b20 100644 (file)
@@ -182,22 +182,22 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       return Any::make();
     }
 
-    auto left_index0 = ulhs.as<tvm::Variable>();
-    auto right_index0 = urhs.as<tvm::IntImm>();
+    auto left_index0 = ulhs.as<tvm::VarNode>();
+    auto right_index0 = urhs.as<tvm::IntImmNode>();
     if (left_index0 && right_index0) {
       solver_->shape_uf_.Set(ulhs, urhs);
       return urhs;
     }
 
-    auto left_index1 = ulhs.as<tvm::IntImm>();
-    auto right_index1 = urhs.as<tvm::Variable>();
+    auto left_index1 = ulhs.as<tvm::IntImmNode>();
+    auto right_index1 = urhs.as<tvm::VarNode>();
     if (left_index1 && right_index1) {
       solver_->shape_uf_.Set(urhs, ulhs);
       return ulhs;
     }
 
-    auto left_index2 = ulhs.as<tvm::IntImm>();
-    auto right_index2 = urhs.as<tvm::IntImm>();
+    auto left_index2 = ulhs.as<tvm::IntImmNode>();
+    auto right_index2 = urhs.as<tvm::IntImmNode>();
     if (left_index2 && right_index2 && left_index2->value == right_index2->value) {
       return ulhs;
     }
index f5cd6c1..6c8df8b 100644 (file)
@@ -38,7 +38,7 @@ class ElemWiseDetector : public ir::ExprVisitor {
     ExprVisitor::VisitExpr(e);
   }
 
-  void VisitExpr_(const Call* op) final {
+  void VisitExpr_(const CallNode* op) final {
     Array<Expr> axis = op->args;
     if (axis_.size() != axis.size()) {
       is_elem_wise_ = false;
index 7cf5cff..ce2397b 100644 (file)
@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage,
   Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
   // The parent set.
   for (const Operation& op : consumers) {
-    std::unordered_map<const Variable*, IntSet> relax_set;
+    std::unordered_map<const VarNode*, IntSet> relax_set;
     std::unordered_map<IterVar, IntSet> up_state;
     bool found_attach = false;
     CHECK(ctx.op2stage_.count(op.get()));
@@ -188,7 +188,7 @@ void InferRootBound(const Stage& stage,
     // Get the domain of the consumer
     PassUpDomain(op_stage, *rmap, &up_state);
     // Relax if needed.
-    std::unordered_map<const Variable*, IntSet> dom_map;
+    std::unordered_map<const VarNode*, IntSet> dom_map;
     arith::Analyzer analyzer;
     for (auto iv : op->root_iter_vars()) {
       Range r;
index a5ed436..82ee8ff 100644 (file)
@@ -37,7 +37,7 @@ struct TensorDimKey {
   int value_index;
   int dim;
   TensorDimKey() {}
-  TensorDimKey(const ir::Call* op, int dim)
+  TensorDimKey(const ir::CallNode* op, int dim)
       : f(op->func), value_index(op->value_index), dim(dim) {
   }
   TensorDimKey(const Tensor& t, int dim)
@@ -263,13 +263,13 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
         reach[TensorDimKey(t, i)] = {};
       }
       auto fvisit = [&vmap, &reach, &bset](const ObjectRef& n) {
-        const ir::Call *call = n.as<ir::Call>();
+        const ir::CallNode *call = n.as<ir::CallNode>();
         if (call != nullptr && call->func.defined()) {
           if (!bset.count(call->func.get())) return;
           for (size_t i = 0; i < call->args.size(); ++i) {
             TensorDimKey dkey(call, static_cast<int>(i));
             auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) {
-              const Variable *v = node.as<Variable>();
+              const VarNode *v = node.as<VarNode>();
               auto it = vmap.find(v);
               if (it != vmap.end()) {
                 reach[it->second].push_back(dkey);
@@ -353,7 +353,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
       }
       auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
           const ObjectRef& n) {
-        const ir::Call *call = n.as<ir::Call>();
+        const ir::CallNode *call = n.as<ir::CallNode>();
         if (call != nullptr && call->func.defined()) {
           for (size_t i = 0; i < call->args.size(); ++i) {
             auto it = vmap.find(call->args[i].get());
index f917e7f..d08b4be 100644 (file)
@@ -501,7 +501,7 @@ std::vector<Expr> MakeBoundCheck(
   PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
 
   std::vector<Expr> preds;
-  std::unordered_map<const Variable*, IntSet> iset_dmap;
+  std::unordered_map<const VarNode*, IntSet> iset_dmap;
 
   // setup domain map for set analysis
   for (const auto& kv : dom_map) {
index 9aef563..a6500ca 100644 (file)
@@ -45,9 +45,9 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) {
 class VarReplacer : public ir::StmtExprMutator {
  public:
   explicit VarReplacer(
-      const std::unordered_map<const Variable*, Expr>& vsub)
+      const std::unordered_map<const VarNode*, Expr>& vsub)
       : vsub_(vsub) {}
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = vsub_.find(op);
     if (it != vsub_.end()) return it->second;
     return GetRef<Expr>(op);
@@ -71,14 +71,14 @@ class VarReplacer : public ir::StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const ir::Reduce* op) final {
+  Expr VisitExpr_(const ir::ReduceNode* op) final {
     Expr new_e = StmtExprMutator::VisitExpr_(op);
-    const ir::Reduce* new_reduce = new_e.as<ir::Reduce>();
+    const ir::ReduceNode* new_reduce = new_e.as<ir::ReduceNode>();
     ir::CommReducer new_combiner = MutateCommReducer(op->combiner);
     if (op->combiner.same_as(new_combiner)) {
       return new_e;
     } else {
-      return ir::Reduce::make(
+      return ir::ReduceNode::make(
         new_combiner,
         new_reduce->source,
         new_reduce->axis,
@@ -88,21 +88,21 @@ class VarReplacer : public ir::StmtExprMutator {
   }
 
  private:
-  const std::unordered_map<const Variable*, Expr>& vsub_;
+  const std::unordered_map<const VarNode*, Expr>& vsub_;
 };
 
 Expr InjectPredicate(const Array<Expr>& predicates,
                      Expr body) {
-  using ir::Reduce;
-  using ir::Select;
+  using ir::ReduceNode;
+  using ir::SelectNode;
   if (predicates.size() == 0) return body;
-  const Reduce* reduce = body.as<Reduce>();
+  const ReduceNode* reduce = body.as<ReduceNode>();
   if (reduce) {
-    auto n = make_object<Reduce>(*reduce);
-    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates, Expr());
+    auto n = make_object<ReduceNode>(*reduce);
+    n->condition = n->condition && arith::ComputeReduce<ir::AndNode>(predicates, Expr());
     return Expr(n);
   }
-  return Select::make(arith::ComputeReduce<ir::And>(predicates, Expr()),
+  return SelectNode::make(arith::ComputeReduce<ir::AndNode>(predicates, Expr()),
                       body,
                       make_zero(body.dtype()));
 }
@@ -130,7 +130,7 @@ void ReplaceDataFlow(const Array<Stage>& stages,
   }
 }
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+inline bool ReduceEqual(const ir::ReduceNode* a, const ir::ReduceNode* b) {
   return (a->combiner.same_as(b->combiner)) &&
          (a->source.same_as(b->source)) &&
          (a->axis.same_as(b->axis)) &&
@@ -193,8 +193,8 @@ void PrepareAxisMapping(Stage orig_stage,
                         std::unordered_set<IterVar>* p_red_axis,
                         Array<IterVar>* p_new_axis,
                         std::unordered_map<IterVar, Range>* p_dom_map,
-                        std::unordered_map<const Variable*, Expr>* p_vsub,
-                        std::unordered_map<const Variable*, Expr>* p_vsub2newvar,
+                        std::unordered_map<const VarNode*, Expr>* p_vsub,
+                        std::unordered_map<const VarNode*, Expr>* p_vsub2newvar,
                         std::vector<Expr>* p_predicates) {
   auto& red_axis = *p_red_axis;
   auto& new_axis = *p_new_axis;
@@ -305,8 +305,8 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
   Array<IterVar> new_axis;
   std::unordered_map<IterVar, Range> dom_map;
 
-  std::unordered_map<const Variable*, Expr> vsub;
-  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::unordered_map<const VarNode*, Expr> vsub;
+  std::unordered_map<const VarNode*, Expr> vsub2newvar;
   std::vector<Expr> predicates;
 
   PrepareAxisMapping(orig_stage, compute,
@@ -314,18 +314,18 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch,
 
   Expr body;
   Array<Expr> body_list;
-  const ir::Reduce* first_reduce = nullptr;
+  const ir::ReduceNode* first_reduce = nullptr;
   for (auto cbody : compute->body) {
     body = VarReplacer(vsub)(cbody);
     body = InjectPredicate(predicates, body);
     body = VarReplacer(vsub2newvar)(body);
     // Reduce nodes in ONE computeOp must be the same except value_index
     // This is right only if the original body ensures Reduce nodes are the same
-    if (body->IsInstance<ir::Reduce>()) {
-      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
+    if (body->IsInstance<ir::ReduceNode>()) {
+      const ir::ReduceNode* reduce_body = body.as<ir::ReduceNode>();
       if (first_reduce != nullptr) {
         CHECK(ReduceEqual(reduce_body, first_reduce));
-        body = ir::Reduce::make(first_reduce->combiner,
+        body = ir::ReduceNode::make(first_reduce->combiner,
                                 first_reduce->source,
                                 first_reduce->axis,
                                 first_reduce->condition,
@@ -386,8 +386,8 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch,
   Array<IterVar> new_axis;
   std::unordered_map<IterVar, Range> dom_map;
 
-  std::unordered_map<const Variable*, Expr> vsub;
-  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::unordered_map<const VarNode*, Expr> vsub;
+  std::unordered_map<const VarNode*, Expr> vsub2newvar;
   std::vector<Expr> predicates;
 
   PrepareAxisMapping(orig_stage, tensor_op,
@@ -573,25 +573,25 @@ void InjectInline(ScheduleNode* sch) {
           if (!new_body[j].size()) {
             new_body[j] = compute->body;
           }
-          if (new_body[j][0]->IsInstance<ir::Reduce>()) {
+          if (new_body[j][0]->IsInstance<ir::ReduceNode>()) {
             // specially handle reduction inline for multiplre reductions.
-            const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
+            const ir::ReduceNode* reduce = new_body[j][0].as<ir::ReduceNode>();
             for (size_t k = 1; k < new_body[j].size(); ++k) {
-              const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
+              const ir::ReduceNode* reduce_ = new_body[j][k].as<ir::ReduceNode>();
               CHECK(reduce_);
               CHECK(ReduceEqual(reduce_, reduce))
                   << "The Reduce inputs of ComputeOp should "
                   << "have the same attribute except value_index";
             }
-            Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
-                                        stage->op, args, body).as<ir::Evaluate>()->value;
+            Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][0]),
+                                        stage->op, args, body).as<ir::EvaluateNode>()->value;
             if (!new_value.same_as(new_body[j][0])) {
               changed[j] = true;
-              const ir::Reduce* r = new_value.as<ir::Reduce>();
+              const ir::ReduceNode* r = new_value.as<ir::ReduceNode>();
               CHECK_EQ(new_body[j].size(), r->source.size());
               CHECK(r != nullptr);
               for (size_t k = 0; k < new_body[j].size(); ++k) {
-                auto n = make_object<ir::Reduce>(*r);
+                auto n = make_object<ir::ReduceNode>(*r);
                 n->value_index = static_cast<int>(k);
                 n->dtype = r->source[k].dtype();
                 new_body[j].Set(k, Expr(n));
@@ -599,8 +599,8 @@ void InjectInline(ScheduleNode* sch) {
             }
           } else {
             for (size_t k = 0; k < new_body[j].size(); ++k) {
-              Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
-                                          stage->op, args, body).as<ir::Evaluate>()->value;
+              Expr new_value = ir::Inline(ir::EvaluateNode::make(new_body[j][k]),
+                                          stage->op, args, body).as<ir::EvaluateNode>()->value;
               if (!new_value.same_as(new_body[j][k])) {
                 new_body[j].Set(k, new_value);
                 changed[j] = true;
@@ -677,7 +677,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
                                 const IterVar& axis,
                                 int factor_axis) {
   (*this)->InvalidateCache();
-  using ir::Reduce;
+  using ir::ReduceNode;
   CHECK_EQ(axis->iter_type, kCommReduce)
       << "Can only factor reduction axis";
   Stage reduce_stage = operator[](tensor->op);
@@ -758,12 +758,12 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   }
   // predicate generation, copy not touched axis.
   int idx = tensor->value_index;
-  const Reduce* reduce = compute_op->body[idx].as<Reduce>();
+  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   predicates.push_back(reduce->condition);
-  Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));
+  Expr predicate = likely(arith::ComputeReduce<ir::AndNode>(predicates, Expr()));
 
-  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const VarNode*, Expr> vsub;
 
   for (IterVar iv : compute_op->reduce_axis) {
     if (!touch_map.count(iv)) {
@@ -792,7 +792,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
 
   std::vector<Expr> body;
   for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
-    body.emplace_back(Reduce::make(reduce->combiner,
+    body.emplace_back(ReduceNode::make(reduce->combiner,
                                    new_source,
                                    n->reduce_axis,
                                    new_pred,
@@ -861,7 +861,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
       Array<IterVar> axis = {repl_red_axis};
       Expr cond = const_true();
       for (int idx = 0; idx < size; ++idx) {
-        reductions.push_back(Reduce::make(reduce->combiner,
+        reductions.push_back(ReduceNode::make(reduce->combiner,
           factor_exprs, axis, cond, idx));
       }
       return reductions;
index be42513..a53c1ae 100644 (file)
@@ -408,7 +408,7 @@ Stage& Stage::pragma(IterVar var,
   } else {
     UpdateIterVarAttr(
         operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) {
-          n->pragma_keys.push_back(ir::StringImm::make(pragma_type));
+          n->pragma_keys.push_back(ir::StringImmNode::make(pragma_type));
           n->pragma_values.push_back(pragma_value);
         });
   }
index 2d49452..38174df 100644 (file)
@@ -43,28 +43,28 @@ Stmt MakePipeline(const Stage& s,
                   bool debug_keep_trivial_loop) {
   Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
   if (producer.defined()) {
-    producer = ProducerConsumer::make(s->op, true, producer);
+    producer = ProducerConsumerNode::make(s->op, true, producer);
   }
   if (s->double_buffer) {
-    producer = AttrStmt::make(
+    producer = AttrStmtNode::make(
         s->op, ir::attr::double_buffer_scope, 1, producer);
   }
   Stmt pipeline = producer;
 
   if (consumer.defined() && !is_no_op(consumer)) {
-    consumer = ProducerConsumer::make(s->op, false, consumer);
+    consumer = ProducerConsumerNode::make(s->op, false, consumer);
     pipeline = SeqStmt({producer, consumer});
   }
   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
-  pipeline = AttrStmt::make(
+  pipeline = AttrStmtNode::make(
       s->op, ir::attr::realize_scope,
-      StringImm::make(s->scope),
+      StringImmNode::make(s->scope),
       pipeline);
 
   if (s->is_opengl) {
-    pipeline = AttrStmt::make(
-        s->op, ir::attr::opengl_stage_scope, StringImm::make(""), pipeline);
+    pipeline = AttrStmtNode::make(
+        s->op, ir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline);
   }
   return pipeline;
 }
@@ -82,7 +82,7 @@ class InjectAttach : public StmtMutator {
   Stmt VisitStmt(const Stmt& input_stmt) final {
     CHECK(input_stmt.defined());
     auto stmt = StmtMutator::VisitStmt(input_stmt);
-    const AttrStmt* op = stmt.as<AttrStmt>();
+    const AttrStmtNode* op = stmt.as<AttrStmtNode>();
     if (op != nullptr &&
         op->attr_key == attr::loop_scope) {
       if (attach_spec_->attach_type == kScope &&
@@ -91,7 +91,7 @@ class InjectAttach : public StmtMutator {
             << "Find IterVar" << attach_spec_->attach_ivar
             << " in multiple places in the IR";
         found_attach = true;
-        stmt = AttrStmt::make(
+        stmt = AttrStmtNode::make(
             op->node, op->attr_key, op->value,
             MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
       }
@@ -128,13 +128,13 @@ class InjectScanStep : public StmtMutator {
     CHECK(input_stmt.defined());
     auto stmt = StmtMutator::VisitStmt(input_stmt);
     // update
-    const AttrStmt* op = stmt.as<AttrStmt>();
+    const AttrStmtNode* op = stmt.as<AttrStmtNode>();
     if (op != nullptr &&
         ((op->attr_key == attr::scan_update_scope && !is_init_) ||
          (op->attr_key == attr::scan_init_scope && is_init_))) {
       if (op->node.same_as(scan_op_)) {
         found_attach = true;
-        stmt = AttrStmt::make(
+        stmt = AttrStmtNode::make(
             op->node, op->attr_key, op->value,
             MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
       }
@@ -162,12 +162,12 @@ class InjectScanStep : public StmtMutator {
 // Replace the init and update's expression by scan's buffer.
 class SchedulePostProc : public StmtExprMutator {
  public:
-  Stmt VisitStmt_(const ProducerConsumer* op) final {
+  Stmt VisitStmt_(const ProducerConsumerNode* op) final {
     auto it = replace_op_.find(op->func.get());
     if (it != replace_op_.end()) {
       Stmt body = this->VisitStmt(op->body);
       if (it->second.defined()) {
-        return ProducerConsumer::make(
+        return ProducerConsumerNode::make(
             it->second, op->is_producer, body);
       } else {
         return body;
@@ -176,7 +176,7 @@ class SchedulePostProc : public StmtExprMutator {
       return StmtExprMutator::VisitStmt_(op);
     }
   }
-  Stmt VisitStmt_(const LetStmt* op) final {
+  Stmt VisitStmt_(const LetStmtNode* op) final {
     if (!HasSideEffect(op->value)) {
       var_value_[op->var.get()] = this->VisitExpr(op->value);
       return this->VisitStmt(op->body);
@@ -185,7 +185,7 @@ class SchedulePostProc : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const AttrStmt* op) final {
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::loop_scope ||
         op->attr_key == attr::scan_init_scope) {
       return this->VisitStmt(op->body);
@@ -211,7 +211,7 @@ class SchedulePostProc : public StmtExprMutator {
       auto it = replace_op_.find(op->node.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
-          Stmt ret = AttrStmt::make(
+          Stmt ret = AttrStmtNode::make(
               it->second, op->attr_key, op->value, op->body);
           return this->VisitStmt(ret);
         } else {
@@ -224,7 +224,7 @@ class SchedulePostProc : public StmtExprMutator {
       auto it = replace_op_.find(tensor->op.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
-          return AttrStmt::make(
+          return AttrStmtNode::make(
               Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
               op->attr_key, op->value, this->VisitStmt(op->body));
         } else {
@@ -236,7 +236,7 @@ class SchedulePostProc : public StmtExprMutator {
       auto it = replace_op_.find(tensor->op.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
-          return AttrStmt::make(
+          return AttrStmtNode::make(
               it->second.output(tensor->value_index),
               op->attr_key, op->value, this->VisitStmt(op->body));
         } else {
@@ -247,12 +247,12 @@ class SchedulePostProc : public StmtExprMutator {
     return StmtExprMutator::VisitStmt_(op);
   }
 
-  Stmt VisitStmt_(const Realize* op) final {
+  Stmt VisitStmt_(const RealizeNode* op) final {
     TensorKey key{op->func, op->value_index};
     auto it = replace_realize_.find(key);
     if (it != replace_realize_.end()) {
       if (it->second.defined()) {
-        Stmt ret = Realize::make(
+        Stmt ret = RealizeNode::make(
             it->second->op, it->second->value_index,
             op->dtype, op->bounds, op->condition, op->body);
         return this->VisitStmt(ret);
@@ -264,12 +264,12 @@ class SchedulePostProc : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const Provide* op) final {
+  Stmt VisitStmt_(const ProvideNode* op) final {
     TensorKey key{op->func, op->value_index};
     auto it = replace_buffer_.find(key);
     if (it != replace_buffer_.end()) {
       const Tensor& dst = it->second;
-      Stmt ret = Provide::make(
+      Stmt ret = ProvideNode::make(
           dst->op, dst->value_index, op->value, op->args);
       return this->VisitStmt(ret);
     } else {
@@ -277,13 +277,13 @@ class SchedulePostProc : public StmtExprMutator {
     }
   }
 
-  Expr VisitExpr_(const Call* op) final {
-    if (op->call_type == Call::Halide) {
+  Expr VisitExpr_(const CallNode* op) final {
+    if (op->call_type == CallNode::Halide) {
       TensorKey key{op->func, op->value_index};
       auto it = replace_buffer_.find(key);
       if (it != replace_buffer_.end()) {
         const Tensor& dst = it->second;
-        Expr ret = Call::make(
+        Expr ret = CallNode::make(
             op->dtype, dst->op->name, op->args,
             op->call_type, dst->op, dst->value_index);
         return this->VisitExpr(ret);
@@ -292,7 +292,7 @@ class SchedulePostProc : public StmtExprMutator {
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Expr VisitExpr_(const Variable* op) final {
+  Expr VisitExpr_(const VarNode* op) final {
     auto it = var_value_.find(op);
     if (it != var_value_.end()) {
       return it->second;
@@ -345,7 +345,7 @@ class SchedulePostProc : public StmtExprMutator {
   // The thread extent scope.
   std::unordered_map<const Object*, Expr> thread_extent_scope_;
   // The scan value
-  std::unordered_map<const Variable*, Expr> var_value_;
+  std::unordered_map<const VarNode*, Expr> var_value_;
   // buffer replacement
   std::unordered_map<TensorKey, Tensor> replace_buffer_;
   // buffere realization to be replaced
index cb0233f..a6010c3 100644 (file)
@@ -79,7 +79,7 @@ TEST(Attrs, Basic) {
   n->InitBySeq("name", "xxx", "expr", 128);
   CHECK_EQ(n->name, "xxx");
   CHECK_EQ(n->axis, 10);
-  CHECK_EQ(n->expr.as<tvm::ir::IntImm>()->value, 128);
+  CHECK_EQ(n->expr.as<tvm::ir::IntImmNode>()->value, 128);
   // Check docstring
   std::ostringstream os;
   n->PrintDocString(os);
index 7aab3ed..3d7c355 100644 (file)
@@ -169,7 +169,7 @@ TEST(Array, Iterator) {
   using namespace tvm;
   Array<Expr> array{1, 2, 3};
   std::vector<Expr> vector(array.begin(), array.end());
-  CHECK(vector[1].as<IntImm>()->value == 2);
+  CHECK(vector[1].as<IntImmNode>()->value == 2);
 }
 
 TEST(Map, Expr) {
@@ -222,7 +222,7 @@ TEST(Map, Iterator) {
   Map<Expr, Expr> map1{{a, b}};
   std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>
       map2(map1.begin(), map1.end());
-  CHECK(map2[a].as<IntImm>()->value == 2);
+  CHECK(map2[a].as<IntImmNode>()->value == 2);
 }
 
 int main(int argc, char** argv) {
index debfb36..4b6915f 100644 (file)
@@ -38,7 +38,7 @@ TEST(ExprNodeRef, Basic) {
   using namespace tvm;
   Var x("x");
   Expr z = max(x + 1 + 2, 100);
-  const ir::Max* op = z.as<ir::Max>();
+  const ir::MaxNode* op = z.as<ir::MaxNode>();
   CHECK(GetRef<ObjectRef>(op).same_as(z));
 }
 
index a37f6f9..23a81b9 100644 (file)
@@ -31,10 +31,10 @@ TEST(IRF, Basic) {
   auto z = x + 1;
 
   NodeFunctor<int(const ObjectRef& n, int b)> f;
-  f.set_dispatch<Variable>([](const ObjectRef& n, int b) {
+  f.set_dispatch<VarNode>([](const ObjectRef& n, int b) {
       return b;
     });
-  f.set_dispatch<Add>([](const ObjectRef& n, int b) {
+  f.set_dispatch<AddNode>([](const ObjectRef& n, int b) {
       return b + 2;
     });
   CHECK_EQ(f(x, 2),  2);
@@ -48,7 +48,7 @@ TEST(IRF, CountVar) {
 
   auto z = x + 1 + y + y;
   ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) {
-      if (n.as<Variable>()) ++n_var;
+      if (n.as<VarNode>()) ++n_var;
     });
   CHECK_EQ(n_var, 2);
 }
@@ -63,13 +63,13 @@ TEST(IRF, ExprTransform) {
   class MyExprFunctor
       : public ir::ExprFunctor<int(const Expr&, int)> {
    public:
-    int VisitExpr_(const Variable* op, int b) final {
+    int VisitExpr_(const VarNode* op, int b) final {
       return b;
     }
-    int VisitExpr_(const IntImm* op, int b) final {
+    int VisitExpr_(const IntImmNode* op, int b) final {
       return op->value;
     }
-    int VisitExpr_(const Add* op, int b) final {
+    int VisitExpr_(const AddNode* op, int b) final {
       return VisitExpr(op->a, b) + VisitExpr(op->b, b);
     }
   };
@@ -95,21 +95,21 @@ TEST(IRF, ExprVisit) {
    public:
     int count = 0;
     // implementation
-    void VisitExpr_(const Variable* op) final {
+    void VisitExpr_(const VarNode* op) final {
       ++count;
     }
-    void VisitExpr_(const IntImm* op) final {
+    void VisitExpr_(const IntImmNode* op) final {
     }
-    void VisitExpr_(const Add* op) final {
+    void VisitExpr_(const AddNode* op) final {
       VisitExpr(op->a);
       VisitExpr(op->b);
     }
-    void VisitStmt_(const Evaluate* op) final {
+    void VisitStmt_(const EvaluateNode* op) final {
       VisitExpr(op->value);
     }
   };
   MyVisitor v;
-  v.VisitStmt(Evaluate::make(z));
+  v.VisitStmt(EvaluateNode::make(z));
   CHECK_EQ(v.count, 1);
 }
 
@@ -123,16 +123,16 @@ TEST(IRF, StmtVisitor) {
    public:
     int count = 0;
     // implementation
-    void VisitExpr_(const Variable* op) final {
+    void VisitExpr_(const VarNode* op) final {
       ++count;
     }
   };
   MyVisitor v;
   auto fmaketest = [&]() {
     auto z = x + 1;
-    Stmt body = Evaluate::make(z);
+    Stmt body = EvaluateNode::make(z);
     Var buffer("b", DataType::Handle());
-    return Allocate::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
+    return AllocateNode::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
   };
   v(fmaketest());
   CHECK_EQ(v.count, 3);
@@ -152,7 +152,7 @@ TEST(IRF, StmtMutator) {
 
    protected:
     // implementation
-    Expr VisitExpr_(const Add* op) final {
+    Expr VisitExpr_(const AddNode* op) final {
       return op->a;
     }
     Stmt VisitStmt_(const SeqStmtNode* op) final {
@@ -164,34 +164,34 @@ TEST(IRF, StmtMutator) {
   };
   auto fmakealloc = [&]() {
     auto z = x + 1;
-    Stmt body = Evaluate::make(z);
+    Stmt body = EvaluateNode::make(z);
     Var buffer("b", DataType::Handle());
-    return Allocate::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
+    return AllocateNode::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
   };
 
   auto fmakeif = [&]() {
     auto z = x + 1;
-    Stmt body = Evaluate::make(z);
-    return IfThenElse::make(x < 0, Evaluate::make(0), body);
+    Stmt body = EvaluateNode::make(z);
+    return IfThenElseNode::make(x, EvaluateNode::make(0), body);
   };
 
   MyVisitor v;
   {
     auto body = fmakealloc();
-    Stmt body2 = Evaluate::make(1);
-    Stmt bref = body.as<Allocate>()->body;
-    auto* extentptr = body.as<Allocate>()->extents.get();
+    Stmt body2 = EvaluateNode::make(1);
+    Stmt bref = body.as<AllocateNode>()->body;
+    auto* extentptr = body.as<AllocateNode>()->extents.get();
     Array<Stmt> arr{std::move(body), body2, body2};
     auto* arrptr = arr.get();
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     CHECK(arr.get() == arrptr);
     // inplace update body
-    CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
-    CHECK(arr[0].as<Allocate>()->extents.get() == extentptr);
+    CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
+    CHECK(arr[0].as<AllocateNode>()->extents.get() == extentptr);
     // copy because there is additional refs
-    CHECK(!arr[0].as<Allocate>()->body.same_as(bref));
-    CHECK(arr[0].as<Allocate>()->body.as<Evaluate>()->value.same_as(x));
-    CHECK(bref.as<Evaluate>()->value.as<Add>());
+    CHECK(!arr[0].as<AllocateNode>()->body.same_as(bref));
+    CHECK(arr[0].as<AllocateNode>()->body.as<EvaluateNode>()->value.same_as(x));
+    CHECK(bref.as<EvaluateNode>()->value.as<AddNode>());
   }
   {
     Array<Stmt> arr{fmakealloc()};
@@ -200,8 +200,8 @@ TEST(IRF, StmtMutator) {
     auto* arrptr = arr.get();
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     CHECK(arr.get() != arrptr);
-    CHECK(arr[0].as<Allocate>()->extents[1].same_as(x));
-    CHECK(!arr2[0].as<Allocate>()->extents[1].same_as(x));
+    CHECK(arr[0].as<AllocateNode>()->extents[1].same_as(x));
+    CHECK(!arr2[0].as<AllocateNode>()->extents[1].same_as(x));
     // mutate but no content change.
     arr2 = arr;
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
@@ -210,7 +210,7 @@ TEST(IRF, StmtMutator) {
   {
     Array<Stmt> arr{fmakeif()};
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
-    CHECK(arr[0].as<IfThenElse>()->else_case.as<Evaluate>()->value.same_as(x));
+    CHECK(arr[0].as<IfThenElseNode>()->else_case.as<EvaluateNode>()->value.same_as(x));
     // mutate but no content change.
     auto arr2 = arr;
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
@@ -218,15 +218,15 @@ TEST(IRF, StmtMutator) {
   }
 
   {
-    auto body = Evaluate::make(Call::make(DataType::Int(32), "xyz", {x + 1}, Call::Extern));
+    auto body = EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
     auto res = v(std::move(body));
-    CHECK(res.as<Evaluate>()->value.as<Call>()->args[0].same_as(x));
+    CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[0].same_as(x));
   }
   {
     auto body = fmakealloc();
-    Stmt body2 = Evaluate::make(1);
+    Stmt body2 = EvaluateNode::make(1);
     auto* ref2 = body2.get();
-    auto* extentptr = body.as<Allocate>()->extents.get();
+    auto* extentptr = body.as<AllocateNode>()->extents.get();
     // construct a recursive SeqStmt.
     body = SeqStmt({body});
     body = SeqStmt({body, body2});
@@ -234,22 +234,22 @@ TEST(IRF, StmtMutator) {
     body = v(std::move(body));
     // the seq get flattened
     CHECK(body.as<SeqStmtNode>()->size() == 3);
-    CHECK(body.as<SeqStmtNode>()->seq[0].as<Allocate>()->extents.get() == extentptr);
+    CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() == extentptr);
     CHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
   }
 
   {
     // Cannot cow because of bref
     auto body = fmakealloc();
-    Stmt body2 = Evaluate::make(1);
-    auto* extentptr = body.as<Allocate>()->extents.get();
+    Stmt body2 = EvaluateNode::make(1);
+    auto* extentptr = body.as<AllocateNode>()->extents.get();
     // construct a recursive SeqStmt.
     body = SeqStmt({body});
     auto bref = body;
     body = SeqStmt({body, body2});
     body = v(std::move(body));
     // the seq get flattened
-    CHECK(body.as<SeqStmtNode>()->seq[0].as<Allocate>()->extents.get() != extentptr);
+    CHECK(body.as<SeqStmtNode>()->seq[0].as<AllocateNode>()->extents.get() != extentptr);
   }
 }
 
index 57d7d50..6b694ef 100644 (file)
@@ -46,7 +46,7 @@ TEST(IRSIMPLIFY, Mod) {
   // Mod::make is used instead of % to avoid constant folding during
   // calling operator%(x,y). Mod::make doesn't try constant folding,
   // and therefore, the constant folding will be attempted in CanonicalSimplify
-  auto mod = tvm::ir::CanonicalSimplify(tvm::ir::Mod::make(x, y));
+  auto mod = tvm::ir::CanonicalSimplify(tvm::ir::ModNode::make(x, y));
   auto es = tvm::ir::CanonicalSimplify(mod - x);
   CHECK(is_zero(es));
 }
index dd9ef3b..47cd000 100644 (file)
@@ -26,9 +26,9 @@ TEST(IRSSA, Convert) {
   using namespace tvm;
   using namespace tvm::ir;
   Var x("x"), y;
-  Expr let = Let::make(x, 1, x + 1);
+  Expr let = LetNode::make(x, 1, x + 1);
 
-  auto z = Evaluate::make(let + let);
+  auto z = EvaluateNode::make(let + let);
   CHECK(!ir::VerifySSA(z));
   auto z_ssa = ir::ConvertSSA(z);
   CHECK(ir::VerifySSA(z_ssa));
@@ -38,7 +38,7 @@ TEST(IRSSA, Basic) {
   using namespace tvm::ir;
   using namespace tvm;
   Var x("x"), y;
-  auto z = Evaluate::make(x + y);
+  auto z = EvaluateNode::make(x + y);
   CHECK(ir::VerifySSA(z));
 }
 
index fba8e10..24ed6d8 100644 (file)
@@ -131,7 +131,7 @@ TEST(PackedFunc, Expr) {
   // automatic conversion of int to expr
   PackedFunc addone([](TVMArgs args, TVMRetValue* rv) {
       Expr x = args[0];
-      *rv = x.as<tvm::ir::IntImm>()->value + 1;
+      *rv = x.as<tvm::ir::IntImmNode>()->value + 1;
   });
   int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
       PackedFunc f = args[0];
index 9710428..2b03454 100644 (file)
@@ -62,7 +62,7 @@ TEST(Pattern, Basic) {
   CHECK((!(px > py || px != py)).Match(!(x > y || x != y)));
   {
     CHECK(select(px >= pz, py, py + pz).Match(
-        ir::Select::make((x + 1) >= 1, y, y + 1)));
+        ir::SelectNode::make((x + 1) >= 1, y, y + 1)));
     CHECK(ir::Equal(px.Eval(), x + 1));
   }
   // bit intrinsics
@@ -79,16 +79,16 @@ TEST(Pattern, Basic) {
   // select
   {
     CHECK(select(px > pz, py, py + pz).Match(
-      ir::Select::make(x > 1, y, y + 1)));
+      ir::SelectNode::make(x > 1, y, y + 1)));
     CHECK(is_const_int(pz.Eval(), 1));
   }
   CHECK(!select(px > pz, py, py + pz).Match(
-      ir::Select::make(x > 2, y, y + 1)));
+      ir::SelectNode::make(x > 2, y, y + 1)));
   CHECK(!select(px > pz, py, py).Match(
-      ir::Select::make(x > 2, y, y + 1)));
+      ir::SelectNode::make(x > 2, y, y + 1)));
   {
     CHECK(select(px, py, pz).Match(
-        ir::Select::make(x > 2, y, y + 1)));
+        ir::SelectNode::make(x > 2, y, y + 1)));
     CHECK(ir::Equal(pz.Eval(), y + 1));
   }
   // if_then_else
@@ -100,30 +100,30 @@ TEST(Pattern, Basic) {
   // cast pattern
   {
     CHECK(!cast(PConst<DataType>(
-        DataType::Int(32)), px).Match(ir::Cast::make(DataType::Float(64), x)));
-    CHECK(cast(pt, px).Match(ir::Cast::make(DataType::Float(64), x)));
+        DataType::Int(32)), px).Match(ir::CastNode::make(DataType::Float(64), x)));
+    CHECK(cast(pt, px).Match(ir::CastNode::make(DataType::Float(64), x)));
     CHECK(pt.Eval() == DataType::Float(64));
     auto zz = cast(pt, px).Eval();
     CHECK((cast(pt, px) - cast(pt, py)).Match(
-        ir::Cast::make(DataType::Float(64), x) - ir::Cast::make(DataType::Int(64), x)));
-    auto expr = ir::Cast::make(DataType::Int(32), ir::Cast::make(DataType::Float(64), x));
+        ir::CastNode::make(DataType::Float(64), x) - ir::CastNode::make(DataType::Int(64), x)));
+    auto expr = ir::CastNode::make(DataType::Int(32), ir::CastNode::make(DataType::Float(64), x));
     CHECK(!(cast(pt, cast(pt, px))).Match(expr));
   }
   // ramp pattern
   {
     CHECK(ramp(px, PConst<Expr>(1), planes).Match(
-        ir::Ramp::make(x, 1, 10)));
+        ir::RampNode::make(x, 1, 10)));
     CHECK(planes.Eval() == 10);
     CHECK(!ramp(px, PConst<Expr>(1), planes).Match(
-        ir::Ramp::make(x, 2, 10)));
+        ir::RampNode::make(x, 2, 10)));
   }
   // broadcast pattern
   {
     CHECK(broadcast(px, planes).Match(
-        ir::Broadcast::make(x, 10)));
+        ir::BroadcastNode::make(x, 10)));
     CHECK(planes.Eval() == 10);
     CHECK(broadcast(px * py , planes).Match(
-        ir::Broadcast::make(x * 10, 10)));
+        ir::BroadcastNode::make(x * 10, 10)));
   }
 }
 
index 4fdd186..8c5068a 100644 (file)
@@ -52,8 +52,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
   int i;
   for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
     // TODO(@icemelon9): Need to revisit this part
-    const Variable* var1 = shape1[s1_size - i].as<Variable>();
-    const Variable* var2 = shape2[s2_size - i].as<Variable>();
+    const VarNode* var1 = shape1[s1_size - i].as<VarNode>();
+    const VarNode* var2 = shape2[s2_size - i].as<VarNode>();
     bh.all_vars.push_front(tvm::Var());
     if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
       bh.common_shape.push_front(shape1[s1_size - i]);
index 6d93f9d..00db1fc 100644 (file)
@@ -43,8 +43,8 @@ using namespace tvm;
  */
 inline bool IsConstInt(Expr expr) {
   return
-    expr->IsInstance<tvm::ir::IntImm>() ||
-    expr->IsInstance<tvm::ir::UIntImm>();
+    expr->IsInstance<tvm::ir::IntImmNode>() ||
+    expr->IsInstance<tvm::ir::UIntImmNode>();
 }
 
 /*!
@@ -56,11 +56,11 @@ inline bool IsConstInt(Expr expr) {
  * \return The integer value.
  */
 inline int64_t GetConstInt(Expr expr) {
-  if (expr->IsInstance<tvm::ir::IntImm>()) {
-    return expr.as<tvm::ir::IntImm>()->value;
+  if (expr->IsInstance<tvm::ir::IntImmNode>()) {
+    return expr.as<tvm::ir::IntImmNode>()->value;
   }
-  if (expr->IsInstance<tvm::ir::UIntImm>()) {
-    return expr.as<tvm::ir::UIntImm>()->value;
+  if (expr->IsInstance<tvm::ir::UIntImmNode>()) {
+    return expr.as<tvm::ir::UIntImmNode>()->value;
   }
   LOG(ERROR) << "expr must be a constant integer";
   return -1;
index fa184bf..643b44b 100644 (file)
@@ -95,7 +95,7 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
   }
 
   auto body = fextern(input_placeholders, output_placeholders);
-  auto body_stmt = tvm::ir::Evaluate::make(body);
+  auto body_stmt = tvm::ir::EvaluateNode::make(body);
 
   auto op = ExternOpNode::make(
       name, tag, attrs, inputs,
@@ -118,12 +118,12 @@ inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
  */
 inline Expr pack_buffer(Buffer buf) {
   CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
-  auto shape = tvm::ir::Call::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
-                                   buf->shape, tvm::ir::Call::CallType::Intrinsic);
+  auto shape = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
+                                   buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
   Expr strides;
   if (buf->strides.size() > 0) {
-    strides = tvm::ir::Call::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
-                                  buf->shape, tvm::ir::Call::CallType::Intrinsic);
+    strides = tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
+                                  buf->shape, tvm::ir::CallNode::CallType::Intrinsic);
   } else {
     strides = 0;
   }
@@ -135,8 +135,8 @@ inline Expr pack_buffer(Buffer buf) {
     make_const(buf->dtype, 0),
     buf->elem_offset
   };
-  return tvm::ir::Call::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
-                             pack_args, tvm::ir::Call::CallType::Intrinsic);
+  return tvm::ir::CallNode::make(DataType::Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
+                             pack_args, tvm::ir::CallNode::CallType::Intrinsic);
 }
 
 /*!
@@ -149,8 +149,8 @@ inline Expr pack_buffer(Buffer buf) {
  * \return An expression representing the invocation
  */
 inline Expr call_packed(Array<Expr> args) {
-  return tvm::ir::Call::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
-                             args, tvm::ir::Call::CallType::Intrinsic);
+  return tvm::ir::CallNode::make(DataType::Int(32), tvm::ir::intrinsic::tvm_call_packed,
+                             args, tvm::ir::CallNode::CallType::Intrinsic);
 }
 
 }  // namespace detail
index ede3f88..fe23836 100644 (file)
@@ -39,7 +39,7 @@ using namespace tvm;
 inline bool is_empty_shape(const Array<Expr>& x) {
   bool is_empty = false;
   for (const auto& dim : x) {
-    if (auto int_dim = dim.as<IntImm>()) {
+    if (auto int_dim = dim.as<IntImmNode>()) {
       if (int_dim->value == 0) {
         is_empty = true;
         break;
index 15b9454..dec94f3 100644 (file)
@@ -194,8 +194,8 @@ inline Tensor sign(const Tensor& x,
     Expr zero = make_zero(x->dtype);
     Expr one = make_const(x->dtype, 1);
     Expr minus_one = make_const(x->dtype, -1);
-    auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero);
-    auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1);
+    auto s1 = tvm::ir::SelectNode::make((x(i) < zero), minus_one, zero);
+    auto s2 = tvm::ir::SelectNode::make((x(i) > zero), one, s1);
     return s2;
   }, name, tag);
 }
@@ -264,7 +264,7 @@ inline Tensor cast(const Tensor& x,
       if (expr.dtype().lanes() == type.lanes()) {
         return expr;
       } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
-        return tvm::ir::Broadcast::make(expr, type.lanes());
+        return tvm::ir::BroadcastNode::make(expr, type.lanes());
       }
     }
 
@@ -286,8 +286,8 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te
                           std::string tag = kElementWise) {
   return compute(x->shape,
                  [&](const Array<Var>& i) {
-                   return tvm::ir::Call::make(type, "reinterpret", {x(i)},
-                                              tvm::ir::Call::PureIntrinsic);
+                   return tvm::ir::CallNode::make(type, "reinterpret", {x(i)},
+                                              tvm::ir::CallNode::PureIntrinsic);
                  },
                  name, tag);
 }
index 2235fba..5920c0b 100644 (file)
@@ -94,7 +94,7 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
     [&](const tvm::Array<tvm::Var>& i) {
       auto value = t(i);
       auto calpha = tvm::make_const(value.dtype(), alpha);
-      return tvm::ir::Select::make(value > 0, value, value * calpha);
+      return tvm::ir::SelectNode::make(value > 0, value, value * calpha);
     },
     name,
     tag);
@@ -125,7 +125,7 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
   return tvm::compute(x->shape,
                      [&](const tvm::Array<tvm::Var> &indices) {
                         auto xval = x(indices);
-                        return tvm::ir::Select::make(
+                        return tvm::ir::SelectNode::make(
                             xval > 0,
                             xval,
                             xval * slope(indices[axis]));
@@ -243,10 +243,10 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
     if (sel.size() != 0) {
       if (pad_mode == "constant") {
         return tvm::if_then_else(
-            detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
+            detail::Map(sel, tvm::ir::AndNode::make), t(indices), pad_value);
       } else if (pad_mode == "edge" || pad_mode == "reflect") {
         return tvm::if_then_else(
-            detail::Map(sel, tvm::ir::And::make), t(indices), t(pad_idx));
+            detail::Map(sel, tvm::ir::AndNode::make), t(indices), t(pad_idx));
       }
     }
     return t(indices);
index c4cda6a..a6a5252 100644 (file)
@@ -155,11 +155,11 @@ inline Tensor pool_impl(const Tensor& x,
       } else {
         Expr h_start = output[height_axis] * stride_height - pad_top;
         Expr w_start = output[width_axis] * stride_width - pad_left;
-        Expr h_end = ir::Min::make(h_start + kernel_height, height);
-        Expr w_end = ir::Min::make(w_start + kernel_width, width);
-        h_start = ir::Max::make(h_start, make_const(DataType::DataType::Int(32), 0));
-        w_start = ir::Max::make(w_start, make_const(DataType::DataType::Int(32), 0));
-        Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start),
+        Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
+        Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+        h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
+        w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
+        Expr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
                                            make_const(DataType::DataType::Int(32), 1));
         return div(pool_sum(indices), divide_factor);
       }
@@ -265,16 +265,16 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
           out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
           out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
 
-          Expr out_idx_lower_h = ir::Select::make(
+          Expr out_idx_lower_h = ir::SelectNode::make(
               pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
               (pad_inds[height_axis] - kernel_height) / stride_height + 1);
-          Expr out_idx_lower_w = ir::Select::make(
+          Expr out_idx_lower_w = ir::SelectNode::make(
               pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
               (pad_inds[width_axis] - kernel_width) / stride_width + 1);
 
           return tvm::sum(
-              tvm::if_then_else(ir::And::make(
-                  ir::And::make(out_idx[height_axis] >= out_idx_lower_h,
+              tvm::if_then_else(ir::AndNode::make(
+                  ir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
                                 out_idx[width_axis] >= out_idx_lower_w),
                   mp_inds(out_idx) == idx),
                   out_grad(out_idx), make_const(x->dtype, 0)),
@@ -295,10 +295,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
           out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
           out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
 
-          Expr out_idx_lower_h = ir::Select::make(
+          Expr out_idx_lower_h = ir::SelectNode::make(
               pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
               (pad_h_idx - kernel_height) / stride_height + 1);
-          Expr out_idx_lower_w = ir::Select::make(
+          Expr out_idx_lower_w = ir::SelectNode::make(
               pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
               (pad_w_idx - kernel_width) / stride_width + 1);
 
@@ -308,19 +308,19 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
           } else {
             Expr h_start = out_idx[height_axis] * stride_height - pad_top;
             Expr w_start = out_idx[width_axis] * stride_width - pad_left;
-            Expr h_end = ir::Min::make(h_start + kernel_height, height);
-            Expr w_end = ir::Min::make(w_start + kernel_width, width);
-            h_start = ir::Max::make(h_start, make_const(DataType::Int(32), 0));
-            w_start = ir::Max::make(w_start, make_const(DataType::Int(32), 0));
+            Expr h_end = ir::MinNode::make(h_start + kernel_height, height);
+            Expr w_end = ir::MinNode::make(w_start + kernel_width, width);
+            h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0));
+            w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0));
             divide_factor =
-                ir::Max::make((h_end - h_start) * (w_end - w_start),
+                ir::MaxNode::make((h_end - h_start) * (w_end - w_start),
                               make_const(DataType::Int(32), 1));
           }
           return tvm::sum(tvm::if_then_else(
-              ir::And::make(
-                ir::And::make(out_idx[height_axis] >= out_idx_lower_h,
+              ir::AndNode::make(
+                ir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h,
                               out_idx[height_axis] < out_height),
-                ir::And::make(out_idx[width_axis] >= out_idx_lower_w,
+                ir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w,
                               out_idx[width_axis] < out_width)),
               out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
               {windowh, windoww});
@@ -467,7 +467,7 @@ inline Expr end_index(const Var& out_index,
                       const Expr& odim,
                       const Expr& idim) {
   Expr tmp = indexdiv((out_index + 1) * idim, odim);
-  return tvm::ir::Select::make(indexmod((out_index + 1) * idim, odim) == 0,
+  return tvm::ir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0,
                                tmp, tmp + 1);
 }
 
@@ -729,12 +729,12 @@ inline Tensor pool_impl_nd(const Tensor& x,
         for (int i = 0; i < k_size; i++) {
           int ii = axis[i];
           start[i] = output[ii] * stride[i] - pad_head[i];
-          end[i] = ir::Min::make(start[i] + kernel[i], x->shape[ii]);
-          start[i] = ir::Max::make(start[i], make_const(DataType::Int(32), 0));
+          end[i] = ir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
+          start[i] = ir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
           kernel_size *= (end[i] - start[i]);
         }
 
-        Expr divide_factor = ir::Max::make(kernel_size, make_const(DataType::Int(32), 1));
+        Expr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
         return div(pool_sum(indices), divide_factor);
       }
     }, "tensor", kElementWise);
index 0ffc3e0..2d3d7d3 100644 (file)
@@ -299,7 +299,8 @@ inline FCommReduce MakeCommReducer(FCombine fcombine,
     auto combiner = tvm::ir::CommReducerNode::make(lhs, rhs, result, id_elem);
     Array<Expr> outputs;
     for (size_t i = 0; i < exprs.size(); ++i) {
-      outputs.push_back(tvm::ir::Reduce::make(combiner, exprs, axis, cond, static_cast<int>(i)));
+      outputs.push_back(
+        tvm::ir::ReduceNode::make(combiner, exprs, axis, cond, static_cast<int>(i)));
     }
     return outputs;
   };
@@ -472,8 +473,8 @@ inline Tensor argmin(const Tensor& data,
                      bool atleast1d = false) {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
     Array<Expr> result;
-    result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::ir::Select::make(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
+    result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::ir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1]));  // val
     return result;
   };
   auto fidentity = [](std::vector<DataType> types) {
@@ -489,8 +490,8 @@ inline Tensor argmin(const Tensor& data,
 inline FCommReduce MakeArgmaxReducer() {
   auto fcombine = [](Array<Var> lhs, Array<Var> rhs) {
     Array<Expr> result;
-    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
-    result.push_back(tvm::ir::Select::make(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
+    result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0]));  // idx
+    result.push_back(tvm::ir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1]));  // val
     return result;
   };
   auto fidentity = [](std::vector<DataType> types) {
index 9a66280..00106c1 100644 (file)
@@ -211,7 +211,7 @@ inline Tensor reshape(const Tensor& x,
   Array<Expr> target_shape;
 
   for (const auto &ele : newshape) {
-    if (ele.as<IntImm>()) {
+    if (ele.as<IntImmNode>()) {
       target_shape.push_back(cast(DataType::Int(32), ele));
     } else {
       target_shape.push_back(ele);
@@ -840,7 +840,7 @@ inline Tensor where(const Tensor& condition,
       << condition->shape.size() << " vs " << x->shape.size();
     out = compute(
       oshape, [&](const Array<Var>& indices) {
-        return tvm::ir::Select::make(condition(indices) != 0, x(indices), y(indices));
+        return tvm::ir::SelectNode::make(condition(indices) != 0, x(indices), y(indices));
       }, name, tag);
   } else {
     CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
@@ -849,7 +849,7 @@ inline Tensor where(const Tensor& condition,
     out = compute(
       oshape, [&](const Array<Var>& indices) {
         Array<Expr> condition_idx{indices[0]};
-        return tvm::ir::Select::make(condition(condition_idx) != 0,
+        return tvm::ir::SelectNode::make(condition(condition_idx) != 0,
                                      x(indices), y(indices));
       }, name, tag);
   }
@@ -1316,7 +1316,7 @@ inline Tensor one_hot(const Tensor& indices,
     }
 
     auto idx = iter_vars[true_axis];
-    return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
+    return ir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
   }, name, tag);
 }