[TIR][REFACTOR][API-Change] Migrate tir/stmt.h to use constructor. (#5778)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 11 Jun 2020 23:35:43 +0000 (16:35 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Jun 2020 23:35:43 +0000 (16:35 -0700)
This PR migrate tvm/tir/stmt.h to the new constructor style that is
consistent with the rest of the codebase and changes the affected files accordingly.

52 files changed:
include/tvm/tir/stmt.h
include/tvm/tir/stmt_functor.h
src/arith/ir_mutator_with_analyzer.cc
src/target/llvm/codegen_cpu.cc
src/te/operation/compute_op.cc
src/te/operation/cross_thread_reduction.cc
src/te/operation/extern_op.cc
src/te/operation/hybrid_op.cc
src/te/operation/op_util.cc
src/te/operation/scan_op.cc
src/te/operation/tensor_compute_op.cc
src/te/operation/tensorize.cc
src/te/schedule/operation_inline.cc
src/te/schedule/schedule_dataflow_rewrite.cc
src/te/schedule/schedule_ops.cc
src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
src/te/schedule/schedule_postproc_to_primfunc.cc
src/tir/ir/buffer.cc
src/tir/ir/expr.cc
src/tir/ir/stmt.cc
src/tir/ir/stmt_functor.cc
src/tir/pass/hoist_if_then_else.cc
src/tir/transforms/arg_binder.cc
src/tir/transforms/bound_checker.cc
src/tir/transforms/combine_context_call.cc
src/tir/transforms/coproc_sync.cc
src/tir/transforms/decorate_device_scope.cc
src/tir/transforms/inject_double_buffer.cc
src/tir/transforms/inject_virtual_thread.cc
src/tir/transforms/ir_util.cc
src/tir/transforms/ir_util.h
src/tir/transforms/lift_attr_scope.cc
src/tir/transforms/loop_partition.cc
src/tir/transforms/lower_custom_datatypes.cc
src/tir/transforms/lower_device_storage_access_info.cc
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/lower_tvm_builtin.cc
src/tir/transforms/lower_warp_memory.cc
src/tir/transforms/make_packed_api.cc
src/tir/transforms/narrow_datatype.cc
src/tir/transforms/remap_thread_axis.cc
src/tir/transforms/remove_no_op.cc
src/tir/transforms/simplify.cc
src/tir/transforms/split_host_device.cc
src/tir/transforms/storage_flatten.cc
src/tir/transforms/storage_rewrite.cc
src/tir/transforms/tensorcore_infer_fragment.cc
src/tir/transforms/thread_storage_sync.cc
src/tir/transforms/unroll_loop.cc
src/tir/transforms/vectorize_loop.cc
tests/cpp/ir_functor_test.cc
topi/include/topi/detail/extern.h

index d4c813e..2aaf795 100644 (file)
@@ -79,13 +79,22 @@ class LetStmtNode : public StmtNode {
     hash_reduce(body);
   }
 
-  TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
-
   static constexpr const char* _type_key = "LetStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to LetStmtNode.
+ * \sa LetStmtNode
+ */
+class LetStmt : public Stmt {
+ public:
+  TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
+};
+
+/*!
  * \brief Define certain auxiliary attribute for the body to be a symbolic value.
  *  This provide auxiliary information for IR passes that transforms body.
  *
@@ -125,13 +134,22 @@ class AttrStmtNode : public StmtNode {
     hash_reduce(body);
   }
 
-  TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);
-
   static constexpr const char* _type_key = "AttrStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to AttrStmtNode.
+ * \sa AttrStmtNode
+ */
+class AttrStmt : public Stmt {
+ public:
+  TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
+};
+
+/*!
  * \brief Assert condition, if an error occurs, return the error message.
  */
 class AssertStmtNode : public StmtNode {
@@ -163,13 +181,22 @@ class AssertStmtNode : public StmtNode {
     hash_reduce(body);
   }
 
-  TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
-
   static constexpr const char* _type_key = "AssertStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to AssertStmtNode.
+ * \sa AssertStmtNode
+ */
+class AssertStmt : public Stmt {
+ public:
+  TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
+};
+
+/*!
  * \brief Store value to the buffer.
  *
  *  Equivalent to ((DType*)buffer_var)[index] = value.
@@ -217,13 +244,22 @@ class StoreNode : public StmtNode {
     hash_reduce(predicate);
   }
 
-  TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);
-
   static constexpr const char* _type_key = "Store";
   TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to StoreNode.
+ * \sa StoreNode
+ */
+class Store : public Stmt {
+ public:
+  TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
+};
+
+/*!
  * \brief Store value to the high dimension buffer.
  *
  * \code
@@ -270,6 +306,7 @@ class BufferStoreNode : public StmtNode {
 class BufferStore : public Stmt {
  public:
   TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+
   TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
 };
 
@@ -369,13 +406,22 @@ class ProducerStoreNode : public StmtNode {
     hash_reduce(indices);
   }
 
-  TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);
-
   static constexpr const char* _type_key = "ProducerStore";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to ProducerStoreNode.
+ * \sa ProducerStoreNode
+ */
+class ProducerStore : public Stmt {
+ public:
+  TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
+};
+
+/*!
  * \brief Annotate the bounds where the data produced by the producer
  *  need to be written and read in body.
  *  We will need to allocate space for the corresponding regions.
@@ -404,8 +450,6 @@ class ProducerRealizeNode : public StmtNode {
     v->Visit("body", &body);
   }
 
-  TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);
-
   bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
     return equal(producer, other->producer) && equal(bounds, other->bounds) &&
            equal(condition, other->condition) && equal(body, other->body);
@@ -423,6 +467,17 @@ class ProducerRealizeNode : public StmtNode {
 };
 
 /*!
+ * \brief Managed reference to ProducerRealizeNode.
+ * \sa ProducerRealizeNode
+ */
+class ProducerRealize : public Stmt {
+ public:
+  TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
+};
+
+/*!
  * \brief Allocate a buffer that can be used in body.
  */
 class AllocateNode : public StmtNode {
@@ -460,9 +515,6 @@ class AllocateNode : public StmtNode {
     hash_reduce(body);
   }
 
-  TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
-                           PrimExpr condition, Stmt body);
-
   /*!
    * \brief If the buffer size is constant, return the size.
    *        Otherwise return 0.
@@ -481,6 +533,18 @@ class AllocateNode : public StmtNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
 };
 
+/*!
+ * \brief Managed reference to AllocateNode.
+ * \sa AllocateNode
+ */
+class Allocate : public Stmt {
+ public:
+  TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
+                   Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
+};
+
 /*! \brief Free the resources in the buffer before the scope ends. */
 class FreeNode : public StmtNode {
  public:
@@ -495,13 +559,22 @@ class FreeNode : public StmtNode {
 
   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
 
-  TVM_DLL static Stmt make(Var buffer_var);
-
   static constexpr const char* _type_key = "Free";
   TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to FreeNode.
+ * \sa FreeNode
+ */
+class Free : public Stmt {
+ public:
+  TVM_DLL Free(Var buffer_var);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode);
+};
+
+/*!
  * \brief The container of seq statement.
  *        Represent a sequence of statements.
  */
@@ -624,13 +697,22 @@ class IfThenElseNode : public StmtNode {
     hash_reduce(else_case);
   }
 
-  TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
-
   static constexpr const char* _type_key = "IfThenElse";
   TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
 };
 
 /*!
+ * \brief Managed reference to IfThenElseNode.
+ * \sa IfThenElseNode
+ */
+class IfThenElse : public Stmt {
+ public:
+  TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
+};
+
+/*!
  * \brief Evaluates an expression.
  *  This is mostly used for putting a Call node into Stmt.
  *
@@ -649,12 +731,23 @@ class EvaluateNode : public StmtNode {
 
   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
 
-  TVM_DLL static Stmt make(PrimExpr v);
-
   static constexpr const char* _type_key = "Evaluate";
   TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
 };
 
+/*!
+ * \brief Managed reference to EvaluateNode.
+ * \sa EvaluateNode
+ */
+class Evaluate : public Stmt {
+ public:
+  TVM_DLL explicit Evaluate(PrimExpr value);
+
+  explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {}
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
+};
+
 /*! \brief Additional annotation of for loop. */
 enum class ForType : int {
   /*! \brief serial execution. */
@@ -700,9 +793,6 @@ class ForNode : public StmtNode {
   /*! \brief The body of the for loop. */
   Stmt body;
 
-  TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
-                           DeviceAPI device_api, Stmt body);
-
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("loop_var", &loop_var);
     v->Visit("min", &min);
@@ -732,6 +822,18 @@ class ForNode : public StmtNode {
 };
 
 /*!
+ * \brief Managed reference to ForNode.
+ * \sa ForNode
+ */
+class For : public Stmt {
+ public:
+  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
+              Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
+};
+
+/*!
  * \brief A prefetch hint for abuffer
  */
 class PrefetchNode : public StmtNode {
@@ -773,7 +875,6 @@ class Prefetch : public Stmt {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
 };
 
-
 /*! \brief namespace of possible attribute sin AttrStmt.attr_key */
 namespace attr {
 // The above attr does not pass to ir stage.
index 9a85b38..f037de7 100644 (file)
@@ -352,7 +352,7 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(cons
  * \tparam T the input type, can be PrimExpr or Stmt.
  */
 template <typename T>
-inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
+inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
   auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
     auto it = value_map.find(var);
     if (it != value_map.end()) return (*it).second;
index f4bb9c2..84e2093 100644 (file)
@@ -76,7 +76,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
     if (else_case.defined()) {
       return else_case;
     }
-    return EvaluateNode::make(0);
+    return Evaluate(0);
   }
 
   if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
index 05c2ef2..9113c98 100644 (file)
@@ -901,8 +901,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
   } else if (op->for_type == ForType::Parallel) {
     if (parallel_env_.penv == nullptr) {
       CreateParallelLaunch(
-          ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body),
-          0);
+          For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0);
     } else {
       // already in parallel env.
       CHECK(parallel_env_.task_id.defined());
index 66f0820..7f957b5 100644 (file)
@@ -264,7 +264,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
   Stmt realize = body;
   for (int i = this->num_outputs(); i > 0; --i) {
     Tensor t = stage->op.output(i - 1);
-    realize = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize);
+    realize = tir::ProducerRealize(t, bounds, const_true(), realize);
     // alignment requirement, only useful for compute
     for (size_t i = 0; i < num_schedulable_dims(); ++i) {
       auto it = stage->iter_var_attrs.find(this->axis[i]);
@@ -273,7 +273,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
         if (attr->dim_align_factor != 0) {
           Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
                                    attr->dim_align_offset};
-          realize = tir::AttrStmtNode::make(
+          realize = tir::AttrStmt(
               t, tir::attr::buffer_dim_align,
               Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
               realize);
@@ -308,13 +308,13 @@ void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt*
   Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
   for (size_t i = 0; i < size; ++i) {
     Tensor t = tensors[i];
-    inits.emplace_back(ProducerStoreNode::make(t, init_value[i], args));
-    provides.emplace_back(ProducerStoreNode::make(t, update_value[i], args));
+    inits.emplace_back(ProducerStore(t, init_value[i], args));
+    provides.emplace_back(ProducerStore(t, update_value[i], args));
   }
   *init = SeqStmt::Flatten(inits);
   *provide = SeqStmt::Flatten(provides);
   if (!is_one(reduce->condition)) {
-    *provide = IfThenElseNode::make(reduce->condition, *provide);
+    *provide = IfThenElse(reduce->condition, *provide);
   }
 }
 
@@ -324,7 +324,7 @@ Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) {
   for (IterVar iv : op->axis) {
     args.push_back(iv->var);
   }
-  return ProducerStoreNode::make(t, op->body[t->value_index], args);
+  return ProducerStore(t, op->body[t->value_index], args);
 }
 
 Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
@@ -587,7 +587,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
   }
 
   auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
-  return IfThenElseNode::make(cond, update, body);
+  return IfThenElse(cond, update, body);
 }
 
 }  // namespace te
index cd76910..e834ff2 100644 (file)
@@ -149,9 +149,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
     for (size_t i = 0; i < size; ++i) {
       DataType t = reduces[i]->dtype;
       normal_init.emplace_back(
-          StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
+          Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
       normal_update.emplace_back(
-          StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+          Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
     }
   }
 
@@ -194,10 +194,10 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   // Apply the existing input predicate if any.
   output_preds.push_back(input_pred);
 
-  Stmt reduce_body = EvaluateNode::make(Call(
-      DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic));
-  reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope,
-                                   make_zero(DataType::Handle()), reduce_body);
+  Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce,
+                                   freduce_args, CallNode::Intrinsic));
+  reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
+                         make_zero(DataType::Handle()), reduce_body);
 
   if (!normal_red.empty()) {
     Stmt init_body = SeqStmt::Flatten(normal_init);
@@ -210,22 +210,20 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   std::vector<Stmt> assigns(size);
   for (size_t idx = 0; idx < size; ++idx) {
     DataType t = reduces[idx]->dtype;
-    assigns[idx] = ProducerStoreNode::make(
-        stage->op.output(idx), Load(t, res_handles[idx], 0, const_true(t.lanes())), args);
+    assigns[idx] = ProducerStore(stage->op.output(idx),
+                                 Load(t, res_handles[idx], 0, const_true(t.lanes())), args);
   }
   Stmt assign_body = SeqStmt::Flatten(assigns);
   assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
   Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
   for (size_t idx = size; idx != 0; --idx) {
-    body =
-        AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
-    body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"),
-                              body);
+    body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+    body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body);
     if (!normal_red.empty()) {
-      body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1},
-                                const_true(), body);
-      body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope,
-                                StringImm("local"), body);
+      body =
+          Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+      body =
+          AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body);
     }
   }
   body = Substitute(body, value_map);
index 75181b8..0933e30 100644 (file)
@@ -128,7 +128,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage,
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body);
+    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
   }
   return realize_body;
 }
@@ -137,8 +137,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& dom_map,
                                 bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret =
-      AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
+  Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
   auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
     Array<ObjectRef> bind_spec;
     Array<PrimExpr> tuple;
@@ -148,9 +147,8 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage,
       tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
       tuple.push_back(buffer->shape[k]);
     }
-    ret = AttrStmtNode::make(
-        bind_spec, tir::attr::buffer_bind_scope,
-        Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
+    ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+                   Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
   };
   for (size_t i = output_placeholders.size(); i != 0; --i) {
     f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
index c927f80..9b3a79f 100644 (file)
@@ -152,7 +152,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage,
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body);
+    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
   }
   return realize_body;
 }
@@ -161,8 +161,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& dom_map,
                                 bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret =
-      AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
+  Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
   std::unordered_map<Tensor, Tensor> rmap;
   for (int i = 0; i < this->num_outputs(); ++i) {
     rmap[outputs[i]] = stage->op.output(i);
@@ -231,11 +230,11 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range
         rmap[op->loop_var.get()] = inner + outer * factor;
         Stmt ret = tir::Substitute(op->body, rmap);
         PrimExpr cond = likely(outer * factor < (op->extent - inner));
-        ret = IfThenElseNode::make(cond, ret);
-        ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
-                            IterVarTypeToForType(inner->iter_type), op->device_api, ret);
-        ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
-                            IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+        ret = IfThenElse(cond, ret);
+        ret = For(inner->var, PrimExpr(0), inner->dom->extent,
+                  IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+        ret = For(outer->var, PrimExpr(0), outer->dom->extent,
+                  IterVarTypeToForType(outer->iter_type), op->device_api, ret);
         splitted = true;
         return ret;
       }
@@ -276,8 +275,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range
         rmap[op->loop_var.get()] = indexdiv(parent, extent);
         body = tir::Substitute(body, rmap);
         under_outer = false;
-        return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type,
-                             op->device_api, body);
+        return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api,
+                   body);
       } else if (under_outer) {
         Stmt body = this->VisitStmt(op->body);
         std::unordered_map<const VarNode*, PrimExpr> rmap;
@@ -328,10 +327,10 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar,
           std::unordered_map<const VarNode*, PrimExpr> rmap;
           rmap[op->loop_var.get()] = iter_var;
           Stmt body = tir::Substitute(op->body, rmap);
-          return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
+          return AttrStmt(iter_var, "thread_extent", op->extent, body);
         } else {
-          return ForNode::make(op->loop_var, op->min, op->extent,
-                               IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
+          return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type),
+                     op->device_api, op->body);
         }
       }
       return StmtMutator::VisitStmt_(op);
@@ -413,7 +412,7 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>
         for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
       }
       const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
-      return ForNode::make(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
+      return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
     }
   };
 
@@ -463,7 +462,7 @@ class ProviderReplacer : public tir::StmtMutator {
     Tensor t = Downcast<Tensor>(op->producer);
     auto it = vmap_.find(t);
     if (it != vmap_.end()) {
-      Stmt ret = tir::ProducerStoreNode::make(it->second, op->value, op->indices);
+      Stmt ret = tir::ProducerStore(it->second, op->value, op->indices);
       found = true;
       return this->VisitStmt(ret);
     }
index 936781d..f1b0527 100644 (file)
@@ -45,7 +45,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
                                              std::unordered_map<IterVar, PrimExpr>* p_value_map,
                                              bool debug_keep_trivial_loop) {
   auto leaf_iter_vars = stage->leaf_iter_vars;
-  Stmt no_op = EvaluateNode::make(0);
+  Stmt no_op = Evaluate(0);
   // create the loop nest
   std::vector<std::vector<Stmt> > nest;
   nest.resize(leaf_iter_vars.size() + 1);
@@ -108,31 +108,28 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
             pvalue = make_const(DataType::Int(32), 1);
           }
           nest[i + 1].emplace_back(
-              AttrStmtNode::make(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
+              AttrStmt(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
         }
       }
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-        nest[i + 1].emplace_back(LetStmtNode::make(var, dom->min, no_op));
+        nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
         value_map[iv] = dom->min;
       } else if (is_zero(dom->min)) {
-        nest[i + 1].emplace_back(
-            ForNode::make(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
+        nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
         value_map[iv] = var;
       } else {
         Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
-        nest[i + 1].emplace_back(
-            ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
+        nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
         PrimExpr new_value = dom->min + idx;
         value_map[iv] = new_value;
-        nest[i + 1].emplace_back(LetStmtNode::make(var, new_value, no_op));
+        nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
       }
       if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
         CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1";
         CHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size());
         for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
-          nest[i + 1].emplace_back(AttrStmtNode::make(it_attr->prefetch_data[j],
-                                                      tir::attr::prefetch_scope,
-                                                      it_attr->prefetch_offset[j], no_op));
+          nest[i + 1].emplace_back(AttrStmt(it_attr->prefetch_data[j], tir::attr::prefetch_scope,
+                                            it_attr->prefetch_offset[j], no_op));
         }
       }
     } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") {
@@ -141,8 +138,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
       CHECK(is_zero(dom->min));
       CHECK(is_positive_const(dom->extent));
       // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(
-          AttrStmtNode::make(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
+      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
       value_map[iv] = var;
     } else if (bind_iv->thread_tag == "pipeline") {
       // pipeline marker.
@@ -150,14 +146,13 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
       CHECK(is_one(dom->extent));
       // annotate the extent of the IterVar
       nest[i + 1].emplace_back(
-          AttrStmtNode::make(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
+          AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
       value_map[iv] = dom->min;
     } else {
       // Always restrict threaded IterVar to starts from 0.
       CHECK(is_zero(dom->min));
       // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(
-          AttrStmtNode::make(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
+      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else {
@@ -184,7 +179,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
     }
     // annotate the extent of the IterVar
     if (!new_loop_var) {
-      nest[i + 1].emplace_back(AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op));
+      nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op));
     }
   }
   // message passing to get offset of root iter vars.
@@ -193,10 +188,10 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
 }
 
 std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
-  Stmt no_op = EvaluateNode::make(0);
+  Stmt no_op = Evaluate(0);
   std::vector<Stmt> nest;
   for (const PrimExpr& cond : predicates) {
-    nest.emplace_back(IfThenElseNode::make(cond, no_op));
+    nest.emplace_back(IfThenElse(cond, no_op));
   }
   return nest;
 }
index 675954a..45e86e2 100644 (file)
@@ -246,7 +246,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterV
       IterVar sp_ax = this->spatial_axis_[sp_idx];
       bounds.push_back(dom_map.at(sp_ax));
     }
-    ret = tir::ProducerRealizeNode::make(t, bounds, const_true(), ret);
+    ret = tir::ProducerRealize(t, bounds, const_true(), ret);
   }
   return ret;
 }
@@ -254,9 +254,9 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterV
 Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                               bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt provide = AttrStmtNode::make(stage->op, tir::attr::scan_update_scope, this->scan_axis->var,
-                                    EvaluateNode::make(0));
-  Stmt init = AttrStmtNode::make(stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0));
+  Stmt provide =
+      AttrStmt(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, Evaluate(0));
+  Stmt init = AttrStmt(stage->op, tir::attr::scan_init_scope, 0, Evaluate(0));
   size_t begin_scan = 0;
   for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
     if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
index f9e0c8d..c8dfce8 100644 (file)
@@ -127,7 +127,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
   CHECK_EQ(stage->op.operator->(), this);
 
   // Start bind data.
-  Stmt nop = EvaluateNode::make(0);
+  Stmt nop = Evaluate(0);
   std::vector<Stmt> input_bind_nest, output_bind_nest;
   Array<Tensor> inputs = this->InputTensors();
 
@@ -144,7 +144,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
       tuple.push_back(region[i]->min);
       tuple.push_back(region[i]->extent);
     }
-    input_bind_nest.emplace_back(AttrStmtNode::make(
+    input_bind_nest.emplace_back(AttrStmt(
         bind_spec, tir::attr::buffer_bind_scope,
         Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
   }
@@ -168,7 +168,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
       }
     }
 
-    output_bind_nest.emplace_back(AttrStmtNode::make(
+    output_bind_nest.emplace_back(AttrStmt(
         bind_spec, tir::attr::buffer_bind_scope,
         Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
   }
index 224907d..af4b08e 100644 (file)
@@ -351,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
   VerifyTensorizeLoopNest(self, stage, n, tloc);
   VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
   // Start bind data.
-  Stmt nop = EvaluateNode::make(0);
+  Stmt nop = Evaluate(0);
   std::vector<Stmt> input_bind_nest, output_bind_nest;
   Array<Tensor> inputs = self->InputTensors();
   CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch ";
@@ -368,7 +368,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
       tuple.push_back(r->min);
       tuple.push_back(r->extent);
     }
-    input_bind_nest.emplace_back(AttrStmtNode::make(
+    input_bind_nest.emplace_back(AttrStmt(
         bind_spec, tir::attr::buffer_bind_scope,
         Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
   }
@@ -388,7 +388,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
     Tensor tensor = stage->op.output(i - intrin->inputs.size());
     Buffer buffer = intrin->buffers[i];
     Array<ObjectRef> bind_spec{buffer, tensor};
-    output_bind_nest.emplace_back(AttrStmtNode::make(
+    output_bind_nest.emplace_back(AttrStmt(
         bind_spec, tir::attr::buffer_bind_scope,
         Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
   }
index 13b601a..fd613f4 100644 (file)
@@ -65,7 +65,7 @@ class OperationInliner final : public StmtExprMutator {
         for (size_t i = 0; i < args_.size(); ++i) {
           vmap.Set(args_[i], op->indices[i]);
         }
-        expr = Substitute(EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
+        expr = Substitute(Evaluate(expr), vmap).as<EvaluateNode>()->value;
       }
       return expr;
     } else {
index 009d74f..c360513 100644 (file)
@@ -533,10 +533,9 @@ void InjectInline(ScheduleNode* sch) {
               CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
                                                   << "have the same attribute except value_index";
             }
-            PrimExpr new_value =
-                Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body)
-                    .as<tir::EvaluateNode>()
-                    ->value;
+            PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body)
+                                     .as<tir::EvaluateNode>()
+                                     ->value;
             if (!new_value.same_as(new_body[j][0])) {
               changed[j] = true;
               const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
@@ -551,10 +550,9 @@ void InjectInline(ScheduleNode* sch) {
             }
           } else {
             for (size_t k = 0; k < new_body[j].size(); ++k) {
-              PrimExpr new_value =
-                  Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body)
-                      .as<tir::EvaluateNode>()
-                      ->value;
+              PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body)
+                                       .as<tir::EvaluateNode>()
+                                       ->value;
               if (!new_value.same_as(new_body[j][k])) {
                 new_body[j].Set(k, new_value);
                 changed[j] = true;
index f5ba43c..f2955f3 100644 (file)
@@ -44,7 +44,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_
                   bool debug_keep_trivial_loop) {
   Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
   if (s->double_buffer) {
-    producer = AttrStmtNode::make(s->op, tir::attr::double_buffer_scope, 1, producer);
+    producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer);
   }
   Stmt pipeline = producer;
 
@@ -53,7 +53,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_
   }
   pipeline = s->op->BuildRealize(s, dom_map, pipeline);
   // use attribute to mark scope of the operation.
-  pipeline = AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
+  pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
 
   return pipeline;
 }
@@ -77,9 +77,8 @@ class InjectAttach : public StmtMutator {
         CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar
                              << " in multiple places in the IR";
         found_attach = true;
-        stmt =
-            AttrStmtNode::make(op->node, op->attr_key, op->value,
-                               MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+        stmt = AttrStmt(op->node, op->attr_key, op->value,
+                        MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
       }
     }
     return stmt;
@@ -120,9 +119,8 @@ class InjectScanStep : public StmtMutator {
                           (op->attr_key == tir::attr::scan_init_scope && is_init_))) {
       if (op->node.same_as(scan_op_)) {
         found_attach = true;
-        stmt =
-            AttrStmtNode::make(op->node, op->attr_key, op->value,
-                               MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+        stmt = AttrStmt(op->node, op->attr_key, op->value,
+                        MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
       }
     }
     return stmt;
@@ -182,7 +180,7 @@ class SchedulePostProc : public StmtExprMutator {
       auto it = replace_op_.find(op->node.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
-          Stmt ret = AttrStmtNode::make(it->second, op->attr_key, op->value, op->body);
+          Stmt ret = AttrStmt(it->second, op->attr_key, op->value, op->body);
           return this->VisitStmt(ret);
         } else {
           return this->VisitStmt(op->body);
@@ -194,9 +192,8 @@ class SchedulePostProc : public StmtExprMutator {
       auto it = replace_op_.find(tensor->op.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
-          return AttrStmtNode::make(
-              Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)}, op->attr_key,
-              op->value, this->VisitStmt(op->body));
+          return AttrStmt(Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
+                          op->attr_key, op->value, this->VisitStmt(op->body));
         } else {
           return this->VisitStmt(op->body);
         }
@@ -206,8 +203,8 @@ class SchedulePostProc : public StmtExprMutator {
       auto it = replace_op_.find(tensor->op.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {
-          return AttrStmtNode::make(it->second.output(tensor->value_index), op->attr_key, op->value,
-                                    this->VisitStmt(op->body));
+          return AttrStmt(it->second.output(tensor->value_index), op->attr_key, op->value,
+                          this->VisitStmt(op->body));
         } else {
           return this->VisitStmt(op->body);
         }
@@ -221,7 +218,7 @@ class SchedulePostProc : public StmtExprMutator {
     auto it = replace_realize_.find(key);
     if (it != replace_realize_.end()) {
       if (it->second.defined()) {
-        Stmt ret = ProducerRealizeNode::make(it->second, op->bounds, op->condition, op->body);
+        Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body);
         return this->VisitStmt(ret);
       } else {
         return this->VisitStmt(op->body);
@@ -236,7 +233,7 @@ class SchedulePostProc : public StmtExprMutator {
     auto it = replace_buffer_.find(key);
     if (it != replace_buffer_.end()) {
       const Tensor& dst = it->second;
-      Stmt ret = ProducerStoreNode::make(dst, op->value, op->indices);
+      Stmt ret = ProducerStore(dst, op->value, op->indices);
       return this->VisitStmt(ret);
     } else {
       return StmtExprMutator::VisitStmt_(op);
index e81ad2c..1ff569f 100644 (file)
@@ -803,7 +803,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       new_bounds.push_back(
           Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
 
-      return ProducerRealizeNode::make(op->producer, new_bounds, op->condition, op->body);
+      return ProducerRealize(op->producer, new_bounds, op->condition, op->body);
     }
     return stmt;
   }
@@ -821,7 +821,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
         CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name;
         auto matrix_abc = tvm::tir::StringImm("wmma." + it->second);
         Stmt body = this->VisitStmt(op->body);
-        return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body);
+        return AttrStmt(op->node, op->attr_key, matrix_abc, body);
       }
     }
     return stmt;
@@ -847,13 +847,13 @@ class TensorCoreIRMutator : public StmtExprMutator {
         Buffer buffer_a(buffer_node_a);
         Buffer buffer_b(buffer_node_b);
         if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
-          return EvaluateNode::make(
+          return Evaluate(
               Call(DataType::Handle(), intrinsic::tvm_bmma_sync,
                    {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
                     buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
                    CallNode::Intrinsic));
         } else {
-          return EvaluateNode::make(
+          return Evaluate(
               Call(DataType::Handle(), intrinsic::tvm_mma_sync,
                    {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
                     buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
@@ -879,10 +879,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
         auto pload = dst.as<ProducerLoadNode>();
 
         auto fill_fragment_call = [this, &op](const Buffer& buffer) {
-          return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_fill_fragment,
-                                         {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
-                                          buffer->elem_offset, op->value},
-                                         CallNode::Intrinsic));
+          return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment,
+                               {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+                                buffer->elem_offset, op->value},
+                               CallNode::Intrinsic));
         };
 
         ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -918,10 +918,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
       }
 
       auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
-        return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync,
-                                       {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
-                                        buffer->elem_offset, src, stride, matrix_major},
-                                       CallNode::Intrinsic));
+        return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync,
+                             {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+                              buffer->elem_offset, src, stride, matrix_major},
+                             CallNode::Intrinsic));
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -946,10 +946,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
       auto pload = op->value.as<ProducerLoadNode>();
 
       auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
-        return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync,
-                                       {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
-                                        buffer->elem_offset, dst, stride, StringImm("col_major")},
-                                       CallNode::Intrinsic));
+        return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync,
+                             {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+                              buffer->elem_offset, dst, stride, StringImm("col_major")},
+                             CallNode::Intrinsic));
       };
 
       ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
@@ -972,8 +972,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
           scaled_extent_value = ori_extent_value / scale_factor;
         }
         PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
-        stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api,
-                             op->body);
+        stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body);
       }
     }
     return stmt;
@@ -1067,7 +1066,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
     }
     auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic);
     Array<ObjectRef> node = {buffer, tensor};
-    return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer));
+    return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer));
   }
 
   std::unordered_map<std::string, std::string> matrix_abc_;
index 74f4a2c..a86ad76 100644 (file)
@@ -76,19 +76,19 @@ class TensorToBufferMapper : public StmtExprMutator {
       Operation operation = Downcast<Operation>(op->node);
       for (int i = operation->num_outputs(); i != 0; --i) {
         Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
-        body = AttrStmtNode::make(buffer, op->attr_key, op->value, body);
+        body = AttrStmt(buffer, op->attr_key, op->value, body);
       }
       return body;
     } else if (op->attr_key == tir::attr::buffer_bind_scope) {
       Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
       Tensor tensor = Downcast<Tensor>(tuple[1]);
-      return AttrStmtNode::make(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key,
-                                op->value, op->body);
+      return AttrStmt(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value,
+                      op->body);
     } else if (op->attr_key == tir::attr::buffer_dim_align ||
                op->attr_key == tir::attr::prefetch_scope) {
       Tensor tensor = Downcast<Tensor>(op->node);
       Buffer buffer = GetOrAllocBuffer(tensor);
-      return AttrStmtNode::make(buffer, op->attr_key, op->value, op->body);
+      return AttrStmt(buffer, op->attr_key, op->value, op->body);
     } else {
       return ret;
     }
index 4c5b30f..46f4160 100644 (file)
@@ -302,11 +302,10 @@ Stmt Buffer::vstore(Array<PrimExpr> begin, PrimExpr value) const {
   CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0)
       << "Cannot load " << dtype << " from buffer of " << n->dtype;
   if (value.dtype() == DataType::Bool()) {
-    return tir::StoreNode::make(n->data, tir::Cast(DataType::Int(8), value),
-                                BufferOffset(n, begin, DataType::Int(8)), const_true());
+    return tir::Store(n->data, tir::Cast(DataType::Int(8), value),
+                      BufferOffset(n, begin, DataType::Int(8)), const_true());
   } else {
-    return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
-                                const_true(dtype.lanes()));
+    return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes()));
   }
 }
 
index 12df05e..7959eba 100644 (file)
@@ -687,6 +687,8 @@ TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimEx
   return Let(var, value, body);
 });
 
+TVM_REGISTER_NODE_TYPE(LetNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const LetNode*>(node.get());
@@ -816,6 +818,26 @@ TVM_REGISTER_GLOBAL("tir.Shuffle")
 
 TVM_REGISTER_NODE_TYPE(ShuffleNode);
 
+template <typename T>
+void PrintList(const Array<T>& exprs, ReprPrinter* p) {
+  for (size_t i = 0; i < exprs.size(); ++i) {
+    p->Print(exprs[i]);
+    if (i < exprs.size() - 1) {
+      p->stream << ", ";
+    }
+  }
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const ShuffleNode*>(node.get());
+      p->stream << "shuffle(";
+      PrintList(op->vectors, p);
+      p->stream << ", ";
+      PrintList(op->indices, p);
+      p->stream << ")";
+    });
+
 // CommReducer
 CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
                          Array<PrimExpr> identity_element) {
index 46c4b09..9bb1de4 100644 (file)
@@ -27,7 +27,8 @@
 namespace tvm {
 namespace tir {
 
-Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
+// LetStmt
+LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) {
   CHECK(value.defined());
   CHECK(body.defined());
   CHECK_EQ(value.dtype(), var.dtype());
@@ -36,23 +37,56 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
   node->var = std::move(var);
   node->value = std::move(value);
   node->body = std::move(body);
-  return Stmt(node);
+  data_ = std::move(node);
 }
 
-TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed(LetStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed([](Var var, PrimExpr value, Stmt body) {
+  return LetStmt(var, value, body);
+});
+
+TVM_REGISTER_NODE_TYPE(LetStmtNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const LetStmtNode*>(node.get());
+      p->PrintIndent();
+      p->stream << "let " << op->var << " = ";
+      p->Print(op->value);
+      p->stream << '\n';
+      p->Print(op->body);
+    });
 
-Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
+// AttrStmt
+AttrStmt::AttrStmt(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
   auto n = make_object<AttrStmtNode>();
   n->node = node;
   n->attr_key = std::move(attr_key);
   n->value = std::move(value);
   n->body = std::move(body);
-  return Stmt(n);
+  data_ = std::move(n);
 }
 
-TVM_REGISTER_GLOBAL("tir.AttrStmt").set_body_typed(AttrStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.AttrStmt")
+    .set_body_typed([](ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
+      return AttrStmt(node, attr_key, value, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(AttrStmtNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const AttrStmtNode*>(node.get());
+      p->PrintIndent();
+      p->stream << "// attr [";
+      p->Print(op->node);
+      p->stream << "] " << op->attr_key << " = ";
+      p->Print(op->value);
+      p->stream << '\n';
+      p->Print(op->body);
+    });
 
-Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
+// AssertStmt
+AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>())
       << "TypeError: AssertStmt message must be an int or string:" << message << "\n";
@@ -61,21 +95,36 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   node->condition = std::move(condition);
   node->message = std::move(message);
   node->body = std::move(body);
-  return Stmt(node);
+  data_ = std::move(node);
 }
 
+TVM_REGISTER_NODE_TYPE(AssertStmtNode);
+
 TVM_REGISTER_GLOBAL("tir.AssertStmt")
     .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
       if (const auto* str = message.as<StringObj>()) {
         auto msg = StringImm(str->data);
-        return AssertStmtNode::make(condition, msg, body);
+        return AssertStmt(condition, msg, body);
       } else {
-        return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
+        return AssertStmt(condition, Downcast<PrimExpr>(message), body);
       }
     });
 
-Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
-                   DeviceAPI device_api, Stmt body) {
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const AssertStmtNode*>(node.get());
+      p->PrintIndent();
+      p->stream << "assert(";
+      p->Print(op->condition);
+      p->stream << ", ";
+      p->Print(op->message);
+      p->stream << ")\n";
+      p->Print(op->body);
+    });
+
+// For
+For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
+         Stmt body) {
   CHECK(min.defined());
   CHECK(extent.defined());
   CHECK(min.dtype().is_scalar());
@@ -90,224 +139,16 @@ Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type
   node->for_type = for_type;
   node->device_api = device_api;
   node->body = std::move(body);
-  return Stmt(node);
+  data_ = std::move(node);
 }
 
 TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent,
                                                  int for_type, int device_api, Stmt body) {
-  return ForNode::make(loop_var, min, extent, static_cast<ForType>(for_type),
-                       static_cast<DeviceAPI>(device_api), body);
-});
-
-Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
-  CHECK(value.defined());
-  CHECK(index.defined());
-  CHECK(predicate.defined());
-  CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
-  CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
-
-  ObjectPtr<StoreNode> node = make_object<StoreNode>();
-  node->buffer_var = std::move(buffer_var);
-  node->value = std::move(value);
-  node->index = std::move(index);
-  node->predicate = std::move(predicate);
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
-  PrimExpr value = args[1];
-  if (args.size() == 3) {
-    *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
-  } else {
-    *ret = StoreNode::make(args[0], value, args[2], args[3]);
-  }
+  return For(loop_var, min, extent, static_cast<ForType>(for_type),
+             static_cast<DeviceAPI>(device_api), body);
 });
 
-Stmt ProducerStoreNode::make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices) {
-  ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>();
-  node->producer = std::move(producer);
-  node->value = std::move(value);
-  node->indices = std::move(indices);
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.ProducerStore").set_body_typed(ProducerStoreNode::make);
-
-Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
-                        Stmt body) {
-  for (size_t i = 0; i < extents.size(); ++i) {
-    CHECK(extents[i].defined());
-    CHECK(extents[i].dtype().is_scalar());
-  }
-  CHECK(body.defined());
-  CHECK(condition.defined());
-  CHECK(condition.dtype().is_bool());
-
-  ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
-  node->buffer_var = std::move(buffer_var);
-  node->dtype = dtype;
-  node->extents = std::move(extents);
-  node->condition = std::move(condition);
-  node->body = std::move(body);
-  return Stmt(node);
-}
-
-Stmt ProducerRealizeNode::make(DataProducer producer, Region bounds, PrimExpr condition,
-                               Stmt body) {
-  for (size_t i = 0; i < bounds.size(); ++i) {
-    CHECK(bounds[i]->min.defined());
-    CHECK(bounds[i]->extent.defined());
-    CHECK(bounds[i]->min.dtype().is_scalar());
-    CHECK(bounds[i]->extent.dtype().is_scalar());
-  }
-  CHECK(body.defined());
-  CHECK(condition.defined());
-  CHECK(condition.dtype().is_bool());
-
-  ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>();
-  node->producer = std::move(producer);
-  node->bounds = std::move(bounds);
-  node->condition = std::move(condition);
-  node->body = std::move(body);
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.ProducerRealize").set_body_typed(ProducerRealizeNode::make);
-
-// overloaded, needs special handling
-// has default args
-TVM_REGISTER_GLOBAL("tir.Allocate")
-    .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
-                       Stmt body) {
-      return AllocateNode::make(buffer_var, type, extents, condition, body);
-    });
-
-int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
-  int64_t result = 1;
-  for (size_t i = 0; i < extents.size(); ++i) {
-    if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
-      result *= int_size->value;
-      if (result > std::numeric_limits<int32_t>::max()) {
-        return 0;
-      }
-    } else {
-      return 0;
-    }
-  }
-  return static_cast<int32_t>(result);
-}
-
-Stmt FreeNode::make(Var buffer_var) {
-  ObjectPtr<FreeNode> node = make_object<FreeNode>();
-  node->buffer_var = buffer_var;
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make);
-
-Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
-  data_ = make_object<PrefetchNode>(buffer, bounds);
-}
-
-TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array<Range> bounds) {
-  return Prefetch(buffer, bounds);
-});
-
-SeqStmt::SeqStmt(Array<Stmt> seq) {
-  auto node = make_object<SeqStmtNode>();
-  node->seq = std::move(seq);
-  data_ = std::move(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq) {
-  return SeqStmt(std::move(seq));
-});
-
-Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
-  CHECK(condition.defined());
-  CHECK(then_case.defined());
-  // else_case may be null.
-
-  ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
-  node->condition = std::move(condition);
-  node->then_case = std::move(then_case);
-  node->else_case = std::move(else_case);
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.IfThenElse").set_body_typed(IfThenElseNode::make);
-
-Stmt EvaluateNode::make(PrimExpr value) {
-  CHECK(value.defined());
-
-  ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
-  node->value = std::move(value);
-  return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed(EvaluateNode::make);
-
-BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
-  ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
-  node->buffer = std::move(buffer);
-  node->value = std::move(value);
-  node->indices = std::move(indices);
-  data_ = std::move(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.BufferStore")
-    .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
-      return BufferStore(buffer, value, indices);
-    });
-
-TVM_REGISTER_NODE_TYPE(BufferStoreNode);
-
-BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
-  data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body);
-}
-
-TVM_REGISTER_GLOBAL("tir.BufferRealize")
-    .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
-      return BufferRealize(buffer, bounds, condition, body);
-    });
-
-TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
-
-// Printers
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const LetStmtNode*>(node.get());
-      p->PrintIndent();
-      p->stream << "let " << op->var << " = ";
-      p->Print(op->value);
-      p->stream << '\n';
-      p->Print(op->body);
-    });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const AttrStmtNode*>(node.get());
-      p->PrintIndent();
-      p->stream << "// attr [";
-      p->Print(op->node);
-      p->stream << "] " << op->attr_key << " = ";
-      p->Print(op->value);
-      p->stream << '\n';
-      p->Print(op->body);
-    });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const AssertStmtNode*>(node.get());
-      p->PrintIndent();
-      p->stream << "assert(";
-      p->Print(op->condition);
-      p->stream << ", ";
-      p->Print(op->message);
-      p->stream << ")\n";
-      p->Print(op->body);
-    });
+TVM_REGISTER_NODE_TYPE(ForNode);
 
 std::ostream& operator<<(std::ostream& out, ForType type) {  // NOLINT(*)
   switch (type) {
@@ -345,6 +186,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "}\n";
     });
 
+// Store
+Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
+  CHECK(value.defined());
+  CHECK(index.defined());
+  CHECK(predicate.defined());
+  CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
+  CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
+
+  ObjectPtr<StoreNode> node = make_object<StoreNode>();
+  node->buffer_var = std::move(buffer_var);
+  node->value = std::move(value);
+  node->index = std::move(index);
+  node->predicate = std::move(predicate);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
+  PrimExpr value = args[1];
+  if (args.size() == 3) {
+    *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes()));
+  } else {
+    *ret = Store(args[0], value, args[2], args[3]);
+  }
+});
+
+TVM_REGISTER_NODE_TYPE(StoreNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const StoreNode*>(node.get());
@@ -360,6 +228,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << '\n';
     });
 
+// ProducerStore
+ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices) {
+  ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>();
+  node->producer = std::move(producer);
+  node->value = std::move(value);
+  node->indices = std::move(indices);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.ProducerStore")
+    .set_body_typed([](DataProducer producer, PrimExpr value, Array<PrimExpr> indices) {
+      return ProducerStore(producer, value, indices);
+    });
+
+TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ProducerStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const ProducerStoreNode*>(node.get());
@@ -375,20 +259,46 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << '\n';
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const BufferStoreNode*>(node.get());
-      p->PrintIndent();
-      p->stream << op->buffer->name << "[";
-      for (size_t i = 0; i < op->indices.size(); ++i) {
-        p->Print(op->indices[i]);
-        if (i < op->indices.size() - 1) p->stream << ", ";
+// Allocate
+Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
+                   Stmt body) {
+  for (size_t i = 0; i < extents.size(); ++i) {
+    CHECK(extents[i].defined());
+    CHECK(extents[i].dtype().is_scalar());
+  }
+  CHECK(body.defined());
+  CHECK(condition.defined());
+  CHECK(condition.dtype().is_bool());
+
+  ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
+  node->buffer_var = std::move(buffer_var);
+  node->dtype = dtype;
+  node->extents = std::move(extents);
+  node->condition = std::move(condition);
+  node->body = std::move(body);
+  data_ = std::move(node);
+}
+
+int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
+  int64_t result = 1;
+  for (size_t i = 0; i < extents.size(); ++i) {
+    if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
+      result *= int_size->value;
+      if (result > std::numeric_limits<int32_t>::max()) {
+        return 0;
       }
-      p->stream << "]";
-      p->stream << " = ";
-      p->Print(op->value);
-      p->stream << '\n';
-    });
+    } else {
+      return 0;
+    }
+  }
+  return static_cast<int32_t>(result);
+}
+
+TVM_REGISTER_GLOBAL("tir.Allocate")
+    .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
+                       Stmt body) { return Allocate(buffer_var, type, extents, condition, body); });
+
+TVM_REGISTER_NODE_TYPE(AllocateNode);
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
@@ -408,42 +318,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->Print(op->body);
     });
 
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const FreeNode*>(node.get());
-      p->PrintIndent();
-      p->stream << "free " << op->buffer_var;
-      p->stream << '\n';
-    });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const BufferRealizeNode*>(node.get());
-      p->PrintIndent();
-      p->stream << "buffer_realize " << op->buffer->name << "(";
-      for (size_t i = 0; i < op->bounds.size(); ++i) {
-        p->stream << "[";
-        p->Print(op->bounds[i]->min);
-        p->stream << ", ";
-        p->Print(op->bounds[i]->extent);
-        p->stream << "]";
-        if (i < op->bounds.size() - 1) p->stream << ", ";
-      }
-      p->stream << ")";
-      if (!is_one(op->condition)) {
-        p->stream << " if ";
-        p->Print(op->condition);
-      }
-      p->stream << " {\n";
+// ProducerRealize
+ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
+                                 Stmt body) {
+  for (size_t i = 0; i < bounds.size(); ++i) {
+    CHECK(bounds[i]->min.defined());
+    CHECK(bounds[i]->extent.defined());
+    CHECK(bounds[i]->min.dtype().is_scalar());
+    CHECK(bounds[i]->extent.dtype().is_scalar());
+  }
+  CHECK(body.defined());
+  CHECK(condition.defined());
+  CHECK(condition.dtype().is_bool());
 
-      p->indent += 2;
-      p->Print(op->body);
-      p->indent -= 2;
+  ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>();
+  node->producer = std::move(producer);
+  node->bounds = std::move(bounds);
+  node->condition = std::move(condition);
+  node->body = std::move(body);
+  data_ = std::move(node);
+}
 
-      p->PrintIndent();
-      p->stream << "}\n";
+TVM_REGISTER_GLOBAL("tir.ProducerRealize")
+    .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body) {
+      return ProducerRealize(producer, bounds, condition, body);
     });
 
+TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ProducerRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const ProducerRealizeNode*>(node.get());
@@ -472,6 +374,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "}\n";
     });
 
+// Free
+Free::Free(Var buffer_var) {
+  ObjectPtr<FreeNode> node = make_object<FreeNode>();
+  node->buffer_var = buffer_var;
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Free").set_body_typed([](Var buffer_var) { return Free(buffer_var); });
+
+TVM_REGISTER_NODE_TYPE(FreeNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const FreeNode*>(node.get());
+      p->PrintIndent();
+      p->stream << "free " << op->buffer_var;
+      p->stream << '\n';
+    });
+
+// Prefetch
+Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
+  data_ = make_object<PrefetchNode>(buffer, bounds);
+}
+
+TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array<Range> bounds) {
+  return Prefetch(buffer, bounds);
+});
+
+TVM_REGISTER_NODE_TYPE(PrefetchNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const PrefetchNode*>(node.get());
@@ -488,6 +420,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ")";
     });
 
+// SeqStmt
+SeqStmt::SeqStmt(Array<Stmt> seq) {
+  auto node = make_object<SeqStmtNode>();
+  node->seq = std::move(seq);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq) {
+  return SeqStmt(std::move(seq));
+});
+
+TVM_REGISTER_NODE_TYPE(SeqStmtNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const SeqStmtNode*>(node.get());
@@ -496,6 +441,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       }
     });
 
+// IfThenElse
+IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) {
+  CHECK(condition.defined());
+  CHECK(then_case.defined());
+  // else_case may be null.
+  ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
+  node->condition = std::move(condition);
+  node->then_case = std::move(then_case);
+  node->else_case = std::move(else_case);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IfThenElseNode);
+
+TVM_REGISTER_GLOBAL("tir.IfThenElse")
+    .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case) {
+      return IfThenElse(condition, then_case, else_case);
+    });
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const IfThenElseNode*>(node.get());
@@ -527,6 +491,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "}\n";
     });
 
+// Evaluate
+Evaluate::Evaluate(PrimExpr value) {
+  CHECK(value.defined());
+
+  ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
+  node->value = std::move(value);
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value) { return Evaluate(value); });
+
+TVM_REGISTER_NODE_TYPE(EvaluateNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const EvaluateNode*>(node.get());
@@ -535,41 +512,75 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "\n";
     });
 
-template <typename T>
-void PrintList(const Array<T>& exprs, ReprPrinter* p) {
-  for (size_t i = 0; i < exprs.size(); ++i) {
-    p->Print(exprs[i]);
-    if (i < exprs.size() - 1) {
-      p->stream << ", ";
-    }
-  }
+// BufferStore
+BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+  ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
+  node->buffer = std::move(buffer);
+  node->value = std::move(value);
+  node->indices = std::move(indices);
+  data_ = std::move(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.BufferStore")
+    .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+      return BufferStore(buffer, value, indices);
+    });
+
+TVM_REGISTER_NODE_TYPE(BufferStoreNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const ShuffleNode*>(node.get());
-      p->stream << "shuffle(";
-      PrintList(op->vectors, p);
-      p->stream << ", ";
-      PrintList(op->indices, p);
-      p->stream << ")";
+    .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const BufferStoreNode*>(node.get());
+      p->PrintIndent();
+      p->stream << op->buffer->name << "[";
+      for (size_t i = 0; i < op->indices.size(); ++i) {
+        p->Print(op->indices[i]);
+        if (i < op->indices.size() - 1) p->stream << ", ";
+      }
+      p->stream << "]";
+      p->stream << " = ";
+      p->Print(op->value);
+      p->stream << '\n';
     });
 
-TVM_REGISTER_NODE_TYPE(AttrStmtNode);
-TVM_REGISTER_NODE_TYPE(PrefetchNode);
-TVM_REGISTER_NODE_TYPE(CallNode);
-TVM_REGISTER_NODE_TYPE(LetNode);
-TVM_REGISTER_NODE_TYPE(LetStmtNode);
-TVM_REGISTER_NODE_TYPE(AssertStmtNode);
-TVM_REGISTER_NODE_TYPE(ForNode);
-TVM_REGISTER_NODE_TYPE(StoreNode);
-TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
-TVM_REGISTER_NODE_TYPE(AllocateNode);
-TVM_REGISTER_NODE_TYPE(FreeNode);
-TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
-TVM_REGISTER_NODE_TYPE(SeqStmtNode);
-TVM_REGISTER_NODE_TYPE(IfThenElseNode);
-TVM_REGISTER_NODE_TYPE(EvaluateNode);
+// BufferRealize
+BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
+  data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferRealize")
+    .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
+      return BufferRealize(buffer, bounds, condition, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
 
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const BufferRealizeNode*>(node.get());
+      p->PrintIndent();
+      p->stream << "buffer_realize " << op->buffer->name << "(";
+      for (size_t i = 0; i < op->bounds.size(); ++i) {
+        p->stream << "[";
+        p->Print(op->bounds[i]->min);
+        p->stream << ", ";
+        p->Print(op->bounds[i]->extent);
+        p->stream << "]";
+        if (i < op->bounds.size() - 1) p->stream << ", ";
+      }
+      p->stream << ")";
+      if (!is_one(op->condition)) {
+        p->stream << " if ";
+        p->Print(op->condition);
+      }
+      p->stream << " {\n";
+
+      p->indent += 2;
+      p->Print(op->body);
+      p->indent -= 2;
+
+      p->PrintIndent();
+      p->stream << "}\n";
+    });
 }  // namespace tir
 }  // namespace tvm
index 06958a2..67329aa 100644 (file)
@@ -499,8 +499,7 @@ class IRSubstitue : public StmtExprMutator {
     Stmt ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<StoreNode>();
     if (auto mapped_var = vmap_(op->buffer_var)) {
-      return StoreNode::make(Downcast<Var>(mapped_var.value()), op->value, op->index,
-                             op->predicate);
+      return Store(Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
     } else {
       return ret;
     }
index 67a88f5..868845f 100644 (file)
@@ -346,7 +346,7 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) {
 
     const IfThenElseNode* new_if_node = new_if.as<IfThenElseNode>();
     CHECK(new_if_node);
-    new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for);
+    new_if = IfThenElse(new_if_node->condition, then_for, else_for);
     if (i < if2for_map_[if_stmt.get()].size() - 1) {
       const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
       const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx);
@@ -376,20 +376,19 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
       Stmt new_for = Stmt();
       for (size_t i = new_if_list.size() - 1; i > 0; --i) {
         CHECK(current_if_node);
-        const Stmt current_if_stmt = IfThenElseNode::make(
+        const Stmt current_if_stmt = IfThenElse(
             current_if_node->condition, current_if_node->then_case, current_if_node->else_case);
         next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
         CHECK(next_if_node);
-        new_for =
-            IfThenElseNode::make(next_if_node->condition, current_if_stmt, next_if_node->else_case);
+        new_for = IfThenElse(next_if_node->condition, current_if_stmt, next_if_node->else_case);
         current_if_node = new_for.as<IfThenElseNode>();
       }
 
       if (!new_for.get()) {
         const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
         CHECK(first_if_node);
-        new_for = IfThenElseNode::make(first_if_node->condition, first_if_node->then_case,
-                                       first_if_node->else_case);
+        new_for = IfThenElse(first_if_node->condition, first_if_node->then_case,
+                             first_if_node->else_case);
       }
       *ret = new_for;
     }
index 14452a6..ae7065d 100644 (file)
@@ -42,8 +42,7 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg
   if (!is_one(scond)) {
     std::ostringstream os;
     os << "Argument " << arg_name << " has an unsatisfied constraint";
-    asserts->emplace_back(
-        AssertStmtNode::make(scond, tvm::tir::StringImm(os.str()), EvaluateNode::make(0)));
+    asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), Evaluate(0)));
   }
 }
 
@@ -57,7 +56,7 @@ bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::str
       defs_.emplace_back(v_arg);
       if (with_lets) {
         (*def_map_)[v] = arg;
-        init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0)));
+        init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
       } else {
         (*def_map_)[v] = value;
       }
@@ -151,14 +150,14 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                              const std::string& arg_name) {
   const DataType tvm_shape_type = DataType::ShapeIndex();
   const DataType tvm_ndim_type = DataType::Int(32);
-  const Stmt nop = EvaluateNode::make(0);
+  const Stmt nop = Evaluate(0);
   // dimension checks
   PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
   PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
   std::ostringstream ndim_err_msg;
   ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
   auto msg = tvm::tir::StringImm(ndim_err_msg.str());
-  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
+  asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
   // type checks
   DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
@@ -171,8 +170,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
                        IntImm(DataType::UInt(16), dtype.lanes()));
   if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
     auto type_msg = tvm::tir::StringImm(type_err_msg.str());
-    asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
-    asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
+    asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
+    asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
   }
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -180,15 +179,14 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
     Var vptr(buffer->data);
     def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
     // mark alignment of external bufs
-    init_nest_.emplace_back(AttrStmtNode::make(vptr, tir::attr::storage_alignment,
-                                               IntImm(DataType::Int(32), buffer->data_alignment),
-                                               nop));
+    init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
+                                     IntImm(DataType::Int(32), buffer->data_alignment), nop));
   }
 
   Var v_shape(arg_name + ".shape", DataType::Handle());
   def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
-  init_nest_.emplace_back(LetStmtNode::make(
-      v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
+  init_nest_.emplace_back(
+      LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
   for (size_t k = 0; k < buffer->shape.size(); ++k) {
     if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) {
       break;
@@ -203,8 +201,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
   // strides field
   Var v_strides(arg_name + ".strides", DataType::Handle());
   def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
-  init_nest_.emplace_back(LetStmtNode::make(
-      v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop));
+  init_nest_.emplace_back(
+      LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop));
   PrimExpr is_null =
       Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic);
   if (buffer->strides.size() == 0) {
@@ -225,10 +223,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
     if (conds.size() != 0) {
       auto stride_msg = tvm::tir::StringImm(stride_err_msg.str());
       auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
-      Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg,
-                                        EvaluateNode::make(0));
-      check = IfThenElseNode::make(Not(is_null), check, Stmt());
-      asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
+      Stmt check = AssertStmt(foldl(fand, const_true(1), conds), stride_msg, Evaluate(0));
+      check = IfThenElse(Not(is_null), check, Stmt());
+      asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
     }
   } else if (buffer->buffer_type == kAutoBroadcast) {
     DataType stype = buffer->DefaultIndexType();
@@ -249,7 +246,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
     std::ostringstream stride_null_err_msg;
     stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
     asserts_.emplace_back(
-        AssertStmtNode::make(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
+        AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
 
     for (size_t k = 0; k < buffer->strides.size(); ++k) {
       std::ostringstream field_name;
index 55a8131..94464a0 100644 (file)
@@ -85,10 +85,10 @@ class BoundChecker : public StmtExprMutator {
     if (store_scope_bound_collector_.size()) {
       PrimExpr condition = MakeCondition();
       if (!condition.as<StringImmNode>()) {
-        Stmt nop = EvaluateNode::make(1);
-        Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
-        Stmt else_case = AssertStmtNode::make(condition, StringImm(error_message_), nop);
-        Stmt body = IfThenElseNode::make(condition, then_case, else_case);
+        Stmt nop = Evaluate(1);
+        Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate);
+        Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
+        Stmt body = IfThenElse(condition, then_case, else_case);
         return body;
       }
     }
index 9e5e4ae..73bf4c6 100644 (file)
@@ -95,7 +95,7 @@ class ContextCallCombiner final : public StmtExprMutator {
   static Stmt BuildContext(
       const std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual>& cmap, Stmt body) {
     for (const auto& kv : cmap) {
-      body = LetStmtNode::make(kv.second, kv.first, body);
+      body = LetStmt(kv.second, kv.first, body);
     }
     return body;
   }
index 3072c0d..384dbcb 100644 (file)
@@ -195,7 +195,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor {
   }
 
   std::vector<Stmt> GetSync(std::string sync_name) {
-    return {EvaluateNode::make(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))};
+    return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))};
   }
 
   const std::unordered_set<const VarNode*>& touched_;
@@ -331,9 +331,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
     CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
     PrimExpr min = r->min;
     PrimExpr extent = r->extent;
-    return EvaluateNode::make(Call(DataType::Int(32), func,
-                                   {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
-                                   CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), func,
+                         {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
+                         CallNode::Intrinsic));
   }
   // Write barrier name
   bool read_barrier_{false};
@@ -555,16 +555,14 @@ class CoProcInstDepDetector : public StmtVisitor {
   }
 
   Stmt MakePush(int from, int to) {
-    return EvaluateNode::make(
-        Call(DataType::Int(32), sync_push_name_,
-             {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
-             CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), sync_push_name_,
+                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
+                         CallNode::Intrinsic));
   }
   Stmt MakePop(int from, int to) {
-    return EvaluateNode::make(
-        Call(DataType::Int(32), sync_pop_name_,
-             {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
-             CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), sync_pop_name_,
+                         {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
+                         CallNode::Intrinsic));
   }
   // sync states.
   SyncState first_state_, last_state_, curr_state_;
index 0decb94..5034a85 100644 (file)
@@ -29,7 +29,7 @@ namespace tvm {
 namespace tir {
 
 Stmt DecorateDeviceScope(Stmt&& stmt) {
-  Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt);
+  Stmt body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt);
   return body;
 }
 
index 3f53022..9d5ee95 100644 (file)
@@ -125,10 +125,10 @@ class DoubleBufferInjector : public StmtExprMutator {
       }
       CHECK(it->second.loop != nullptr);
       auto& alloc_nest = loop_allocs_[it->second.loop];
-      alloc_nest.emplace_back(AttrStmtNode::make(
-          op->buffer_var, attr::storage_scope, StringImm(it->second.scope), EvaluateNode::make(0)));
-      alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents,
-                                                 op->condition, EvaluateNode::make(0)));
+      alloc_nest.emplace_back(
+          AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0)));
+      alloc_nest.emplace_back(
+          Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
       return op->body;
     } else {
       return StmtExprMutator::VisitStmt_(op);
@@ -158,16 +158,15 @@ class DoubleBufferInjector : public StmtExprMutator {
           vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
           loop_seq.emplace_back(Substitute(old_loop->body, vmap));
         }
-        Stmt loop = ForNode::make(outer_var, zero, outer_ext, old_loop->for_type,
-                                  old_loop->device_api, SeqStmt::Flatten(loop_seq));
+        Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
+                        SeqStmt::Flatten(loop_seq));
         // tail
         std::vector<Stmt> tail_seq;
         Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
         for (int32_t i = 0; i < split_loop_; ++i) {
           PrimExpr idx = tail_base + make_const(tail_base.dtype(), i);
           vmap[old_loop->loop_var.get()] = idx;
-          tail_seq.emplace_back(
-              IfThenElseNode::make(idx < old_loop->extent, Substitute(tail_body, vmap)));
+          tail_seq.emplace_back(IfThenElse(idx < old_loop->extent, Substitute(tail_body, vmap)));
         }
         stmt = SeqStmt::Flatten(loop, tail_seq);
       }
@@ -189,8 +188,8 @@ class DoubleBufferInjector : public StmtExprMutator {
       const StorageEntry& e = it->second;
       CHECK(in_double_buffer_scope_);
       CHECK(e.stride.defined());
-      return StoreNode::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index,
-                             op->predicate);
+      return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index,
+                   op->predicate);
     } else {
       return stmt;
     }
@@ -243,8 +242,8 @@ class DoubleBufferInjector : public StmtExprMutator {
     vmap[e.loop->loop_var.get()] = loop_shift;
     vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
     body = Substitute(body, vmap);
-    body = AttrStmtNode::make(buffer, attr::double_buffer_write, 1, body);
-    body = IfThenElseNode::make(loop_shift < e.loop->extent, body);
+    body = AttrStmt(buffer, attr::double_buffer_write, 1, body);
+    body = IfThenElse(loop_shift < e.loop->extent, body);
     return body;
   }
   // Storage entry for those who need double buffering.
index f9088e3..042ddab 100644 (file)
@@ -252,8 +252,7 @@ class VTInjector : public StmtExprMutator {
     trigger_base_inject_ = !allow_share_;
     auto it = alloc_remap_.find(op->buffer_var.get());
     if (it != alloc_remap_.end()) {
-      return StoreNode::make(op->buffer_var, op->value, RewriteIndex(op->index, it->second),
-                             op->predicate);
+      return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate);
     } else {
       return stmt;
     }
@@ -271,7 +270,7 @@ class VTInjector : public StmtExprMutator {
       if (value.same_as(op->value) && body.same_as(op->body)) {
         return GetRef<Stmt>(op);
       } else {
-        return AttrStmtNode::make(op->node, op->attr_key, value, body);
+        return AttrStmt(op->node, op->attr_key, value, body);
       }
     }
   }
@@ -286,7 +285,7 @@ class VTInjector : public StmtExprMutator {
     if (value.same_as(op->value) && body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return LetStmtNode::make(op->var, value, body);
+      return LetStmt(op->var, value, body);
     }
   }
   // For
@@ -304,7 +303,7 @@ class VTInjector : public StmtExprMutator {
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+      return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
     }
   }
   // IfThenElse
@@ -327,7 +326,7 @@ class VTInjector : public StmtExprMutator {
         else_case.same_as(op->else_case)) {
       return GetRef<Stmt>(op);
     } else {
-      return IfThenElseNode::make(condition, then_case, else_case);
+      return IfThenElse(condition, then_case, else_case);
     }
   }
 
@@ -387,7 +386,7 @@ class VTInjector : public StmtExprMutator {
     if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) {
       return GetRef<Stmt>(op);
     } else {
-      return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
+      return Allocate(op->buffer_var, op->dtype, extents, condition, body);
     }
   }
 
@@ -417,8 +416,8 @@ class VTInjector : public StmtExprMutator {
       Var idx(var_->name_hint + ".s", var_->dtype);
       Map<Var, PrimExpr> values{{var_, idx}};
       stmt = Substitute(stmt, values);
-      return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
-                           ForType::Serial, DeviceAPI::None, stmt);
+      return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
+                 ForType::Serial, DeviceAPI::None, stmt);
     }
   }
 
index 28f347e..4f21f0b 100644 (file)
@@ -122,8 +122,7 @@ class IRConvertSSA final : public StmtExprMutator {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<StoreNode>();
     if (scope_.count(op->buffer_var.get())) {
-      return StoreNode::make(scope_[op->buffer_var.get()].back(), op->value, op->index,
-                             op->predicate);
+      return Store(scope_[op->buffer_var.get()].back(), op->value, op->index, op->predicate);
     } else {
       return stmt;
     }
@@ -136,7 +135,7 @@ class IRConvertSSA final : public StmtExprMutator {
       scope_[v.get()].push_back(new_var);
       Stmt body = this->VisitStmt(op->body);
       scope_[v.get()].pop_back();
-      return LetStmtNode::make(new_var, value, body);
+      return LetStmt(new_var, value, body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
@@ -150,7 +149,7 @@ class IRConvertSSA final : public StmtExprMutator {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
       op = stmt.as<ForNode>();
-      return ForNode::make(new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
+      return For(new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
@@ -164,7 +163,7 @@ class IRConvertSSA final : public StmtExprMutator {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
       op = stmt.as<AllocateNode>();
-      return AllocateNode::make(new_var, op->dtype, op->extents, op->condition, op->body);
+      return Allocate(new_var, op->dtype, op->extents, op->condition, op->body);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
@@ -179,13 +178,13 @@ class IRConvertSSA final : public StmtExprMutator {
           if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
           alloc = new_alloc.as<AllocateNode>();
           CHECK(alloc);
-          return AttrStmtNode::make(alloc->buffer_var, op->attr_key, op->value, new_alloc);
+          return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc);
         }
       }
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AttrStmtNode>();
       if (scope_.count(v) && scope_[v].size() != 0) {
-        return AttrStmtNode::make(scope_[v].back(), op->attr_key, op->value, op->body);
+        return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body);
       } else {
         return stmt;
       }
index 4fbd2a0..6c0eeea 100644 (file)
@@ -129,8 +129,7 @@ inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind ki
                          PrimExpr value) {
   Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
                           make_const(DataType::Int(32), static_cast<int>(kind)), value};
-  return EvaluateNode::make(
-      Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
+  return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
 }
 
 /*!
index bb4e5f7..ca4b39e 100644 (file)
@@ -41,7 +41,7 @@ class AttrScopeLifter : public StmtMutator {
   Stmt Lift(Stmt stmt) {
     stmt = operator()(std::move(stmt));
     if (attr_node_.defined()) {
-      stmt = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, stmt);
+      stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt);
     }
     return stmt;
   }
@@ -51,11 +51,11 @@ class AttrScopeLifter : public StmtMutator {
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<AllocateNode>();
     if (attr_node_.defined()) {
-      Stmt body = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, op->body);
+      Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body);
       // undefine them
       attr_node_ = ObjectRef();
       attr_value_ = PrimExpr();
-      return AllocateNode::make(op->buffer_var, op->dtype, op->extents, op->condition, body);
+      return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body);
     } else {
       return stmt;
     }
@@ -111,7 +111,7 @@ class AttrScopeLifter : public StmtMutator {
       }
       Stmt stmt = SeqStmt::Flatten(seq);
       if (attr_node[begin].defined()) {
-        stmt = AttrStmtNode::make(attr_node[begin], attr_key_, attr_value[begin], stmt);
+        stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt);
       }
       reorg.push_back(stmt);
       begin = end;
@@ -137,14 +137,14 @@ class AttrScopeLifter : public StmtMutator {
       if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
         return GetRef<Stmt>(op);
       } else {
-        return IfThenElseNode::make(op->condition, then_case, else_case);
+        return IfThenElse(op->condition, then_case, else_case);
       }
     } else {
       if (first_node.defined()) {
-        then_case = AttrStmtNode::make(first_node, attr_key_, first_value, then_case);
+        then_case = AttrStmt(first_node, attr_key_, first_value, then_case);
       }
       if (attr_node_.defined()) {
-        else_case = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, else_case);
+        else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case);
         // undefine them
         attr_node_ = ObjectRef();
         attr_value_ = PrimExpr();
@@ -152,7 +152,7 @@ class AttrScopeLifter : public StmtMutator {
       if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
         return GetRef<Stmt>(op);
       } else {
-        return IfThenElseNode::make(op->condition, then_case, else_case);
+        return IfThenElse(op->condition, then_case, else_case);
       }
     }
   }
index b06bb8a..7dbf0fc 100644 (file)
@@ -303,9 +303,9 @@ class ThreadPartitionInserter : public StmtMutator {
       // add branch code inside the innermost thread scope
       if (innermost_thread_scope_) {
         Stmt simplified_body = ConditionEliminator(ps_)(op->body);
-        Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body);
+        Stmt body = IfThenElse(cond_, simplified_body, op->body);
         PrimExpr value = this->VisitExpr(op->value);
-        stmt = AttrStmtNode::make(op->node, op->attr_key, value, body);
+        stmt = AttrStmt(op->node, op->attr_key, value, body);
       }
       innermost_thread_scope_ = false;
       return stmt;
@@ -588,8 +588,8 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
   } else {
-    return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
-                         for_node->for_type, for_node->device_api, body);
+    return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type,
+               for_node->device_api, body);
   }
 }
 
index 4a15501..154023c 100644 (file)
@@ -80,8 +80,8 @@ class CustomDatatypesLowerer : public StmtExprMutator {
 
     if (toBeLowered) {
       auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
-      return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
-                                allocate->condition, allocate->body);
+      return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents,
+                      allocate->condition, allocate->body);
     }
     return stmt;
   }
index a842462..0b87757 100644 (file)
@@ -52,7 +52,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
           << "Double allocation of " << it->second.scope.to_string();
 
       if (info->head_address.defined()) {
-        return LetStmtNode::make(op->buffer_var, info->head_address, op->body);
+        return LetStmt(op->buffer_var, info->head_address, op->body);
       } else {
         return op->body;
       }
index f6daabd..8604017 100644 (file)
@@ -84,15 +84,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
       if (warp_allocs_.count(repl)) {
-        stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition,
-                                  op->body);
-        stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt);
+        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
+        stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt);
       } else {
         // use volatile access to shared buffer.
-        stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body);
-        stmt =
-            AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
-        stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt);
+        stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
+        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
+        stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt);
       }
       return stmt;
     } else {
@@ -214,12 +212,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       for (size_t idx = 0; idx < size; ++idx) {
         shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
         PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred));
+        seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred));
 
         // Uses a local variable to store the shuffled data.
         // Later on, this allocation will be properly attached to this statement.
         Var var("t" + std::to_string(idx), types[idx]);
-        Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, EvaluateNode::make(0));
+        Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0));
         local_vars.push_back(s);
       }
 
@@ -232,11 +230,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         PrimExpr pred = const_true(1);
         PrimExpr mask =
             Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
-        seq.emplace_back(StoreNode::make(mask_var, mask, index, pred));
+        seq.emplace_back(Store(mask_var, mask, index, pred));
         // Push allocation with an empty body. Later this will be fixed
         // when the entire body is ready.
-        auto stmt = AllocateNode::make(mask_var, mask_var->dtype, {PrimExpr(1)}, pred,
-                                       EvaluateNode::make(0));
+        auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0));
         local_vars.push_back(stmt);
       }
 
@@ -266,7 +263,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
           const char* shfl_func = intrinsic::tvm_warp_shuffle_down;
           PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset);
           const AllocateNode* repl = local_vars[i].as<AllocateNode>();
-          Stmt s = StoreNode::make(repl->buffer_var, other, index, pred);
+          Stmt s = Store(repl->buffer_var, other, index, pred);
           seq.push_back(s);
 
           PrimExpr load = Load(types[i], repl->buffer_var, index, pred);
@@ -281,7 +278,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         for (size_t i = 0; i < size; ++i) {
           Var var = shared_bufs[i];
           PrimExpr pred = const_true(types[i].lanes());
-          stores[i] = StoreNode::make(var, ret[i], index, pred);
+          stores[i] = Store(var, ret[i], index, pred);
         }
         seq.push_back(SeqStmt::Flatten(stores));
       }
@@ -296,7 +293,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         const char* shfl_func = intrinsic::tvm_warp_shuffle;
         PrimExpr val = Load(types[i], var, index, pred);
         PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0);
-        seq.push_back(StoreNode::make(var, splat, index, pred));
+        seq.push_back(Store(var, splat, index, pred));
       }
 
       // Update existing allocations.
@@ -306,7 +303,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         Var var = shared_bufs[i];
         load_remap_[buffers[i]] = Load(types[i], var, index, pred);
         Array<PrimExpr> extents{PrimExpr(1)};
-        auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0));
+        auto node = Allocate(var, types[i], extents, pred, Evaluate(0));
         alloc_remap_[buffers[i]] = node;
         warp_allocs_.insert(node.get());
       }
@@ -318,7 +315,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         for (size_t i = 0; i < size; ++i) {
           PrimExpr pred = const_true(types[i].lanes());
           Var buffer_var = Downcast<Var>(call->args[2 + size + i]);
-          stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
+          stores[i] = Store(buffer_var, values[i], 0, pred);
         }
         return SeqStmt::Flatten(stores);
       }
@@ -332,8 +329,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       for (size_t idx = 0; idx < size; ++idx) {
         shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
         PrimExpr pred = const_true(types[idx].lanes());
-        seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx],
-                                         BufIndex(reduce_index, group_index, reduce_extent), pred));
+        seq.emplace_back(Store(shared_bufs[idx], values[idx],
+                               BufIndex(reduce_index, group_index, reduce_extent), pred));
       }
       seq.emplace_back(SyncThread("shared"));
       seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
@@ -344,9 +341,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         load_remap_[buffers[idx]] =
             Load(types[idx], shared_bufs[idx],
                  BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
-        alloc_remap_[buffers[idx]] = AllocateNode::make(
-            shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred,
-            EvaluateNode::make(0));
+        alloc_remap_[buffers[idx]] =
+            Allocate(shared_bufs[idx], types[idx],
+                     {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
       }
     }
 
@@ -355,9 +352,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (auto var : local_vars) {
       const AllocateNode* repl = var.as<AllocateNode>();
       if (repl) {
-        body =
-            AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
-        body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), body);
+        body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
+        body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body);
       }
     }
 
@@ -390,7 +386,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       Array<PrimExpr> ret = (*combiner)(a, b);
       std::vector<Stmt> stores(size);
       for (size_t i = 0; i < size; ++i) {
-        stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true());
+        stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true());
       }
       return SeqStmt::Flatten(stores);
     };
@@ -399,7 +395,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // reduction with the boundary condition
       reduce_align = reduce_align >> 1;
       PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
-      seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
+      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
     CHECK(threadx_extent >= 1 && warp_size_ >= 1);
@@ -407,7 +403,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     while (reduce_align > threadx_extent || reduce_align > warp_size_) {
       reduce_align = reduce_align >> 1;
       PrimExpr cond = reduce_index < reduce_align;
-      seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
+      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
       seq.emplace_back(SyncThread("shared"));
     }
     // in warp synchronization.
@@ -420,7 +416,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
     if (in_warp_seq.size() != 0) {
       Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
-      seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body));
+      seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
       seq.emplace_back(SyncThread("shared"));
     }
     return SeqStmt::Flatten(seq);
@@ -456,8 +452,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
   // sync thread op.
   static Stmt SyncThread(const std::string& sync) {
-    return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
-                                   {StringImm(sync)}, CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)},
+                         CallNode::Intrinsic));
   }
 
   // Emit warp shuffle intrinsic calls.
index 0e52802..7611e0f 100644 (file)
@@ -54,14 +54,14 @@ class BuiltinLower : public StmtExprMutator {
     stack_tcode_ = Var("stack_tcode", DataType::Handle());
     stmt = this->VisitStmt(stmt);
     if (max_shape_stack_ != 0) {
-      stmt = LetStmtNode::make(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
+      stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
     }
     if (max_array_stack_ != 0) {
-      stmt = LetStmtNode::make(stack_array_, StackAlloca("array", max_array_stack_), stmt);
+      stmt = LetStmt(stack_array_, StackAlloca("array", max_array_stack_), stmt);
     }
     if (max_arg_stack_ != 0) {
-      stmt = LetStmtNode::make(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
-      stmt = LetStmtNode::make(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
+      stmt = LetStmt(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
+      stmt = LetStmt(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
     }
     return stmt;
   }
@@ -102,15 +102,15 @@ class BuiltinLower : public StmtExprMutator {
     }
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
-    Stmt throw_last_error = EvaluateNode::make(
-        Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic));
+    Stmt throw_last_error =
+        Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic));
 
-    Stmt body = SeqStmt({IfThenElseNode::make(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null,
-                                                   {op->buffer_var}, CallNode::PureIntrinsic),
-                                              throw_last_error),
+    Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null,
+                                         {op->buffer_var}, CallNode::PureIntrinsic),
+                                    throw_last_error),
                          op->body});
 
-    Stmt alloca = LetStmtNode::make(
+    Stmt alloca = LetStmt(
         op->buffer_var,
         Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace",
              {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
@@ -123,11 +123,10 @@ class BuiltinLower : public StmtExprMutator {
                             {cast(DataType::Int(32), device_type_),
                              cast(DataType::Int(32), device_id_), op->buffer_var},
                             CallNode::Extern);
-    Stmt free_stmt =
-        IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
+    Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
     body = SeqStmt({alloca, free_stmt});
-    body = AttrStmtNode::make(op->buffer_var, attr::storage_alignment,
-                              make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
+    body = AttrStmt(op->buffer_var, attr::storage_alignment,
+                    make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
     return body;
   }
 
@@ -166,8 +165,8 @@ class BuiltinLower : public StmtExprMutator {
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
     for (size_t i = 0; i < op->args.size(); ++i) {
-      prep_seq_.emplace_back(StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
-                                             ConstInt32(stack_begin + i), const_true(1)));
+      prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
+                                   ConstInt32(stack_begin + i), const_true(1)));
     }
     return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
   }
@@ -234,7 +233,7 @@ class BuiltinLower : public StmtExprMutator {
       }
       if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
       prep_seq_.emplace_back(
-          StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
+          Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
     }
     // UPDATE stack value
     max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
@@ -272,7 +271,7 @@ class BuiltinLower : public StmtExprMutator {
       int arg_tcode = api_type.code();
       CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
       prep_seq_.emplace_back(
-          StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
+          Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
     }
     // UPDATE stack value
     max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
index 7294c01..a0ddf26 100644 (file)
@@ -220,9 +220,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
     warp_group_ = (alloc_size + (factor - 1)) / factor;
     alloc_size = warp_group_ * factor;
 
-    return AllocateNode::make(op->buffer_var, op->dtype,
-                              {make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
-                              this->VisitStmt(op->body));
+    return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
+                    op->condition, this->VisitStmt(op->body));
   }
 
  protected:
@@ -235,7 +234,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
-      return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate);
+      return Store(op->buffer_var, op->value, local_index, op->predicate);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -373,7 +372,7 @@ class WarpMemoryRewriter : private StmtMutator {
         warp_buffer_.insert(buf);
         Stmt ret = StmtMutator::VisitStmt_(op);
         op = ret.as<AttrStmtNode>();
-        return AttrStmtNode::make(op->node, op->attr_key, StringImm("local"), op->body);
+        return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body);
       }
     }
     return StmtMutator::VisitStmt_(op);
index 0fdfb85..a91e350 100644 (file)
@@ -41,7 +41,7 @@ namespace tvm {
 namespace tir {
 
 inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
-  return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImm(msg), EvaluateNode::make(0));
+  return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
 }
 
 PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
@@ -55,7 +55,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
   std::string name_hint = global_symbol.value();
 
   auto* func_ptr = func.CopyOnWrite();
-  const Stmt nop = EvaluateNode::make(0);
+  const Stmt nop = Evaluate(0);
   int num_args = static_cast<int>(func_ptr->params.size());
   CHECK_LE(num_unpacked_args, num_args);
 
@@ -122,32 +122,29 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
     }
     if (i < num_packed_args) {
       // Value loads
-      seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop));
+      seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
       // type code checks
       Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
-      seq_init.emplace_back(LetStmtNode::make(tcode,
-                                              Load(DataType::Int(32), v_packed_arg_type_ids,
-                                                   IntImm(DataType::Int(32), i), const_true(1)),
-                                              nop));
+      seq_init.emplace_back(LetStmt(tcode,
+                                    Load(DataType::Int(32), v_packed_arg_type_ids,
+                                         IntImm(DataType::Int(32), i), const_true(1)),
+                                    nop));
       DataType t = v_arg.dtype();
       if (t.is_handle()) {
         std::ostringstream msg;
         msg << name_hint << ": Expect arg[" << i << "] to be pointer";
-        seq_check.emplace_back(
-            AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
-                                     tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
-                                 tvm::tir::StringImm(msg.str()), nop));
+        seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
+                                              tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
+                                          tvm::tir::StringImm(msg.str()), nop));
       } else if (t.is_int() || t.is_uint()) {
         std::ostringstream msg;
         msg << name_hint << ": Expect arg[" << i << "] to be int";
-        seq_check.emplace_back(
-            AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
+        seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
       } else {
         CHECK(t.is_float());
         std::ostringstream msg;
         msg << name_hint << ": Expect arg[" << i << "] to be float";
-        seq_check.emplace_back(
-            AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
+        seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
       }
     } else {
       args.push_back(v_arg);
@@ -182,19 +179,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
     func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
   }
 
-  auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope,
-                                 StringImm(name_hint + "_compute_"), func_ptr->body);
+  Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
+                       StringImm(name_hint + "_compute_"), func_ptr->body);
   // Set device context
   if (vmap.count(device_id.get())) {
     PrimExpr node = StringImm("default");
-    seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop));
-    seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop));
+    seq_check.push_back(AttrStmt(node, attr::device_context_id, device_id, nop));
+    seq_check.push_back(AttrStmt(node, attr::device_context_type, device_type, nop));
 
     if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
-      Stmt set_device = EvaluateNode::make(
-          Call(DataType::Int(32), intrinsic::tvm_call_packed,
-               {StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
-               CallNode::Intrinsic));
+      Stmt set_device =
+          Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed,
+                        {StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
+                        CallNode::Intrinsic));
       body = SeqStmt({set_device, body});
     }
   }
index af2886e..07b0ea2 100644 (file)
@@ -208,7 +208,7 @@ class DataTypeRewriter : public StmtExprMutator {
     is_index_ = true;
     PrimExpr index = this->VisitExpr(op->index);
     is_index_ = false;
-    Stmt s = StoreNode::make(op->buffer_var, op->value, index, op->predicate);
+    Stmt s = Store(op->buffer_var, op->value, index, op->predicate);
     return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
   }
 
@@ -219,8 +219,8 @@ class DataTypeRewriter : public StmtExprMutator {
                          << ", but get " << s->GetTypeKey();
     PrimExpr e = VisitExpr(op->loop_var);
     Var var = Downcast<Var>(e);
-    return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
-                         op->for_type, op->device_api, op->body);
+    return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type,
+               op->device_api, op->body);
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
@@ -237,7 +237,7 @@ class DataTypeRewriter : public StmtExprMutator {
       if (ivmap_.find(iv) == ivmap_.end()) {
         ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag);
       }
-      return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
+      return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
     }
     return StmtExprMutator::VisitStmt_(op);
   }
index efb9e69..017d1b4 100644 (file)
@@ -52,7 +52,7 @@ class ThreadAxisRewriter : private StmtExprMutator {
           CHECK(vmap_[v].same_as(new_iv->var));
         }
         Stmt body = this->VisitStmt(op->body);
-        return AttrStmtNode::make(new_iv, op->attr_key, op->value, body);
+        return AttrStmt(new_iv, op->attr_key, op->value, body);
       }
     }
     return StmtExprMutator::VisitStmt_(op);
index 0463d44..cd3a4b7 100644 (file)
@@ -57,7 +57,7 @@ class NoOpRemover : public StmtMutator {
         if (is_no_op(op->then_case)) {
           return MakeEvaluate(op->condition);
         } else {
-          return IfThenElseNode::make(op->condition, op->then_case);
+          return IfThenElse(op->condition, op->then_case);
         }
       } else {
         return stmt;
@@ -74,7 +74,7 @@ class NoOpRemover : public StmtMutator {
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<ForNode>();
     if (is_zero(op->extent)) {
-      return EvaluateNode::make(0);
+      return Evaluate(0);
     }
     return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
   }
@@ -91,7 +91,7 @@ class NoOpRemover : public StmtMutator {
   }
   Stmt VisitStmt_(const EvaluateNode* op) final {
     if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
-    return EvaluateNode::make(0);
+    return Evaluate(0);
   }
 
   Stmt VisitStmt_(const SeqStmtNode* op) final {
@@ -128,9 +128,9 @@ class NoOpRemover : public StmtMutator {
  private:
   Stmt MakeEvaluate(PrimExpr value) {
     if (HasSideEffect(value)) {
-      return EvaluateNode::make(value);
+      return Evaluate(value);
     } else {
-      return EvaluateNode::make(0);
+      return Evaluate(0);
     }
   }
   Stmt MakeEvaluate(const Array<PrimExpr>& values) {
@@ -138,13 +138,13 @@ class NoOpRemover : public StmtMutator {
     for (PrimExpr e : values) {
       if (HasSideEffect(e)) {
         if (stmt.defined()) {
-          stmt = SeqStmt({stmt, EvaluateNode::make(e)});
+          stmt = SeqStmt({stmt, Evaluate(e)});
         } else {
-          stmt = EvaluateNode::make(e);
+          stmt = Evaluate(e);
         }
       }
     }
-    return stmt.defined() ? stmt : EvaluateNode::make(0);
+    return stmt.defined() ? stmt : Evaluate(0);
   }
 };
 
index 759b320..3be2329 100644 (file)
@@ -80,7 +80,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
     if (const LoadNode* load = op->value.as<LoadNode>()) {
       if (load->buffer_var.same_as(op->buffer_var) &&
           tir::ExprDeepEqual()(load->index, op->index)) {
-        return EvaluateNode::make(0);
+        return Evaluate(0);
       }
     }
     return GetRef<Stmt>(op);
index 1806265..67336d4 100644 (file)
@@ -59,7 +59,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
       if (value.same_as(op->value) && body.same_as(op->body)) {
         return GetRef<Stmt>(op);
       }
-      return AttrStmtNode::make(op->node, op->attr_key, value, body);
+      return AttrStmt(op->node, op->attr_key, value, body);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -76,7 +76,7 @@ class VarUseDefAnalysis : public StmtExprMutator {
       if (body.same_as(op->body) && value.same_as(op->value)) {
         return GetRef<Stmt>(op);
       } else {
-        return LetStmtNode::make(op->var, value, body);
+        return LetStmt(op->var, value, body);
       }
     }
   }
@@ -237,7 +237,7 @@ class HostDeviceSplitter : public StmtMutator {
     for (PrimExpr ext : m.thread_extent_) {
       call_args.push_back(ext);
     }
-    return EvaluateNode::make(
+    return Evaluate(
         Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic));
   }
 
index 21ddaaf..4c3de58 100644 (file)
@@ -71,7 +71,7 @@ class StorageFlattener : public StmtExprMutator {
     if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
       CHECK(it->second.as<VarNode>());
       Var buf_var = Downcast<Var>(it->second);
-      return StoreNode::make(buf_var, op->value, op->index, op->predicate);
+      return Store(buf_var, op->value, op->index, op->predicate);
     } else {
       return stmt;
     }
@@ -87,7 +87,7 @@ class StorageFlattener : public StmtExprMutator {
       Stmt body = this->VisitStmt(op->body);
       auto it = buf_map_.find(buffer);
       CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer;
-      body = AttrStmtNode::make(it->second.buffer->data, op->attr_key, op->value, std::move(body));
+      body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body));
       return body;
     } else if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
@@ -134,8 +134,8 @@ class StorageFlattener : public StmtExprMutator {
     // To create bound attribute collector should has at least one item.
     if (create_bound_attributes_ && shape_collector_.size()) {
       for (size_t i = 0; i < shape_collector_.size(); ++i) {
-        body = AttrStmtNode::make(shape_collector_[i].first, tir::attr::buffer_bound,
-                                  MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
+        body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound,
+                        MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
       }
     }
     return body;
@@ -217,23 +217,22 @@ class StorageFlattener : public StmtExprMutator {
       }
       if (strides.size() != 0) {
         int first_dim = 0;
-        ret = AllocateNode::make(e.buffer->data, storage_type,
-                                 {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
-                                 make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
+        ret = Allocate(e.buffer->data, storage_type,
+                       {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
+                       make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
       } else {
         shape = e.buffer->shape;
         if (shape.size() == 0) {
           shape.push_back(make_const(DataType::Int(32), 1));
         }
-        ret = AllocateNode::make(e.buffer->data, storage_type, shape,
-                                 make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
+        ret = Allocate(e.buffer->data, storage_type, shape,
+                       make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
       }
-      ret =
-          AttrStmtNode::make(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);
+      ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);
 
       if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
-        ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound,
-                                 MakeBound(e.buffer->dtype, e.buffer->shape), ret);
+        ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound,
+                       MakeBound(e.buffer->dtype, e.buffer->shape), ret);
       }
       return ret;
     }
@@ -319,17 +318,16 @@ class StorageFlattener : public StmtExprMutator {
     }
     for (int i = starts; i >= 0; --i) {
       if (i < starts) {
-        stmt = ForNode::make(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None,
-                             stmt);
+        stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
       } else {
         PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
         PrimExpr address =
             Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
         PrimExpr prefetch =
             Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
-        stmt = EvaluateNode::make(prefetch);
+        stmt = Evaluate(prefetch);
         PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
-        stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
+        stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
       }
     }
     return stmt;
index 952d273..2d09e8b 100644 (file)
@@ -350,9 +350,8 @@ class StoragePlanRewriter : public StmtExprMutator {
       for (StorageEntry* e : attach_map_.at(nullptr)) {
         // CHECK_EQ(e->scope.rank, 0);
         if (e->new_alloc.defined()) {
-          nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope,
-                                               StringImm(e->scope.to_string()),
-                                               EvaluateNode::make(0)));
+          nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope,
+                                     StringImm(e->scope.to_string()), Evaluate(0)));
           nest.push_back(e->new_alloc);
         }
       }
@@ -365,8 +364,8 @@ class StoragePlanRewriter : public StmtExprMutator {
     op = stmt.as<StoreNode>();
     auto it = alloc_map_.find(op->buffer_var.get());
     if (it == alloc_map_.end()) return stmt;
-    return StoreNode::make(it->second->alloc_var, op->value,
-                           RemapIndex(op->value.dtype(), op->index, it->second), op->predicate);
+    return Store(it->second->alloc_var, op->value,
+                 RemapIndex(op->value.dtype(), op->index, it->second), op->predicate);
   }
   PrimExpr VisitExpr_(const LoadNode* op) final {
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
@@ -421,7 +420,7 @@ class StoragePlanRewriter : public StmtExprMutator {
         auto& svec = attach_map_[op];
         Stmt stmt = StmtExprMutator::VisitStmt_(op);
         op = stmt.as<AttrStmtNode>();
-        return AttrStmtNode::make(op->node, op->attr_key, op->value, MakeAttach(svec, op->body));
+        return AttrStmt(op->node, op->attr_key, op->value, MakeAttach(svec, op->body));
       } else {
         return StmtExprMutator::VisitStmt_(op);
       }
@@ -430,7 +429,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       op = stmt.as<AttrStmtNode>();
       auto it = alloc_map_.find(op->node.as<VarNode>());
       if (it == alloc_map_.end()) return stmt;
-      return AttrStmtNode::make(it->second->alloc_var, op->attr_key, op->value, op->body);
+      return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -442,8 +441,8 @@ class StoragePlanRewriter : public StmtExprMutator {
       auto& svec = attach_map_[op];
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<ForNode>();
-      return ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
-                           MakeAttach(svec, op->body));
+      return For(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
+                 MakeAttach(svec, op->body));
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
@@ -498,9 +497,8 @@ class StoragePlanRewriter : public StmtExprMutator {
     std::vector<Stmt> nest;
     for (StorageEntry* e : svec) {
       if (e->new_alloc.defined()) {
-        nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope,
-                                             StringImm(e->scope.to_string()),
-                                             EvaluateNode::make(0)));
+        nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope,
+                                   StringImm(e->scope.to_string()), Evaluate(0)));
         nest.push_back(e->new_alloc);
       }
     }
@@ -559,8 +557,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->allocs.size() == 1) {
           // simply use the original allocation.
           PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents);
-          e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition,
-                                            EvaluateNode::make(0));
+          e->new_alloc =
+              Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
           if (e->scope.tag.length() != 0) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -599,8 +597,8 @@ class StoragePlanRewriter : public StmtExprMutator {
             combo_size = combo_size + make_const(DataType::Int(32), 1);
           }
           combo_size = analyzer_.Simplify(combo_size);
-          e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {combo_size}, const_true(),
-                                            EvaluateNode::make(0));
+          e->new_alloc =
+              Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0));
           if (e->scope.tag.length() != 0) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -642,8 +640,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
     PrimExpr alloc_size =
         make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits);
-    e->new_alloc = AllocateNode::make(e->alloc_var, e->elem_type, {alloc_size}, const_true(),
-                                      EvaluateNode::make(0));
+    e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0));
     if (info.defined()) {
       CHECK_LE(total_bits, info->max_num_bits)
           << "Allocation exceed bound of memory tag " << e->scope.to_string();
@@ -935,7 +932,7 @@ class VectorAllocRewriter : public StmtExprMutator {
       if (me->base % factor == 0 && me->coeff % factor == 0) {
         extents.Set(extents.size() - 1,
                     extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
-        return AllocateNode::make(op->buffer_var, tvec[0], extents, op->condition, op->body);
+        return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body);
       }
     }
     return stmt;
index bd66fc0..493aa51 100644 (file)
@@ -188,11 +188,11 @@ class InferFragmenter : public StmtMutator {
       std::string shape =
           std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k);
       PrimExpr shape_expr = StringImm(shape);
-      Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
+      Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
       if (info.layout != "") {
         // Add shape attribute to matrix_a and matrix_b
-        Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout,
-                                              StringImm(info.layout), shape_attr);
+        Stmt layout_attr =
+            AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr);
         return layout_attr;
       } else {
         return shape_attr;
index 266ada0..e5b4bdd 100644 (file)
@@ -209,9 +209,8 @@ class ThreadSyncInserter : public StmtExprMutator {
       if (sync_scope_.rank == StorageRank::kGlobal) {
         barrier = MakeGlobalBarrier();
       } else {
-        barrier =
-            EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
-                                    {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
+        barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
+                                {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
       }
       // Mutate after query, to avoid stmt change.
       auto ret = StmtExprMutator::VisitStmt(stmt);
@@ -299,20 +298,20 @@ class ThreadSyncInserter : public StmtExprMutator {
   Stmt InitGlobalBarrier(const AttrStmtNode* op) {
     CHECK(op != nullptr);
     Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
-    Stmt prep = EvaluateNode::make(
-        Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
+    Stmt prep =
+        Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
     Stmt body = op->body;
     for (const auto& kv : rw_stats_) {
       const auto& e = kv.second;
       if (e.read_count != 0 && e.write_count != 0) {
-        body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body);
+        body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
       }
     }
     rw_stats_.clear();
-    Stmt kinit = EvaluateNode::make(
+    Stmt kinit = Evaluate(
         Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
     body = SeqStmt({kinit, body});
-    body = AttrStmtNode::make(op->node, op->attr_key, op->value, body);
+    body = AttrStmt(op->node, op->attr_key, op->value, body);
     return SeqStmt({prep, body});
   }
   Stmt MakeGlobalBarrier() {
@@ -333,9 +332,9 @@ class ThreadSyncInserter : public StmtExprMutator {
     } else {
       CHECK_EQ(num_work_dim_, thread_extents_.size());
     }
-    return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
-                                   {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
-                                   CallNode::Intrinsic));
+    return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
+                         {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
+                         CallNode::Intrinsic));
   }
   // data structure.
   StorageScope sync_scope_;
index fd1a92a..a151906 100644 (file)
@@ -125,8 +125,8 @@ class LoopUnroller : public StmtExprMutator {
     } else {
       if (auto_unroll) {
         if (op->for_type != ForType::Unrolled) {
-          return ForNode::make(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api,
-                               op->body);
+          return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api,
+                     op->body);
         }
       }
       return stmt;
@@ -164,7 +164,7 @@ class LoopUnroller : public StmtExprMutator {
     int value = GetExtent(op);
     // For loop must have a constant integer extent
     CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
-    if (value == 0) return EvaluateNode::make(0);
+    if (value == 0) return Evaluate(0);
     Stmt body = op->body;
     Map<Var, PrimExpr> vmap;
     Array<Stmt> unrolled;
index 290a3a4..227aea2 100644 (file)
@@ -74,8 +74,7 @@ class VecAllocAccess : public StmtExprMutator {
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     op = stmt.as<StoreNode>();
     if (op->buffer_var.get() == buf_) {
-      return StoreNode::make(op->buffer_var, op->value, op->index * var_lanes_ + var_,
-                             op->predicate);
+      return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate);
     } else {
       return stmt;
     }
@@ -291,8 +290,8 @@ class Vectorizer : public StmtExprMutator {
     } else {
       int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
       lanes = std::max(lanes, pred.dtype().lanes());
-      return StoreNode::make(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
-                             BroadcastTo(pred, lanes));
+      return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
+                   BroadcastTo(pred, lanes));
     }
   }
   // For
@@ -310,7 +309,7 @@ class Vectorizer : public StmtExprMutator {
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+      return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
     }
   }
   // IfThenElse
@@ -329,7 +328,7 @@ class Vectorizer : public StmtExprMutator {
         else_case.same_as(op->else_case)) {
       return GetRef<Stmt>(op);
     } else {
-      return IfThenElseNode::make(condition, then_case, else_case);
+      return IfThenElse(condition, then_case, else_case);
     }
   }
   // LetStmt
@@ -358,14 +357,14 @@ class Vectorizer : public StmtExprMutator {
     // rewrite access to buffer internally.
     Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
     body = this->VisitStmt(body);
-    return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
+    return Allocate(op->buffer_var, op->dtype, extents, condition, body);
   }
   // scalarize the statment
   Stmt Scalarize(Stmt stmt) {
     Var idx(var_->name_hint + ".s", var_->dtype);
     Map<Var, PrimExpr> values{{var_, idx}};
     stmt = Substitute(stmt, values);
-    return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
+    return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
   }
 
  private:
@@ -465,8 +464,7 @@ class VectorizeSkipper : public StmtMutator {
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<ForNode>();
     if (op->for_type == ForType::Vectorized) {
-      return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
-                           op->body);
+      return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body);
     } else {
       return stmt;
     }
index b9f5b9c..8dae799 100644 (file)
@@ -95,7 +95,7 @@ TEST(IRF, ExprVisit) {
     void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); }
   };
   MyVisitor v;
-  v.VisitStmt(EvaluateNode::make(z));
+  v.VisitStmt(Evaluate(z));
   CHECK_EQ(v.count, 1);
 }
 
@@ -112,9 +112,9 @@ TEST(IRF, StmtVisitor) {
   MyVisitor v;
   auto fmaketest = [&]() {
     auto z = x + 1;
-    Stmt body = EvaluateNode::make(z);
+    Stmt body = Evaluate(z);
     Var buffer("b", DataType::Handle());
-    return AllocateNode::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
+    return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body);
   };
   v(fmaketest());
   CHECK_EQ(v.count, 3);
@@ -138,21 +138,21 @@ TEST(IRF, StmtMutator) {
   };
   auto fmakealloc = [&]() {
     auto z = x + 1;
-    Stmt body = EvaluateNode::make(z);
+    Stmt body = Evaluate(z);
     Var buffer("b", DataType::Handle());
-    return AllocateNode::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
+    return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
   };
 
   auto fmakeif = [&]() {
     auto z = x + 1;
-    Stmt body = EvaluateNode::make(z);
-    return IfThenElseNode::make(x, EvaluateNode::make(0), body);
+    Stmt body = Evaluate(z);
+    return IfThenElse(x, Evaluate(0), body);
   };
 
   MyVisitor v;
   {
     auto body = fmakealloc();
-    Stmt body2 = EvaluateNode::make(1);
+    Stmt body2 = Evaluate(1);
     Stmt bref = body.as<AllocateNode>()->body;
     auto* extentptr = body.as<AllocateNode>()->extents.get();
     Array<Stmt> arr{std::move(body), body2, body2};
@@ -192,13 +192,13 @@ TEST(IRF, StmtMutator) {
   }
 
   {
-    auto body = EvaluateNode::make(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
+    auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
     auto res = v(std::move(body));
     CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[0].same_as(x));
   }
   {
-    auto body = fmakealloc();
-    Stmt body2 = EvaluateNode::make(1);
+    Stmt body = fmakealloc();
+    Stmt body2 = Evaluate(1);
     auto* ref2 = body2.get();
     auto* extentptr = body.as<AllocateNode>()->extents.get();
     // construct a recursive SeqStmt.
@@ -214,8 +214,8 @@ TEST(IRF, StmtMutator) {
 
   {
     // Cannot cow because of bref
-    auto body = fmakealloc();
-    Stmt body2 = EvaluateNode::make(1);
+    Stmt body = fmakealloc();
+    Stmt body2 = Evaluate(1);
     auto* extentptr = body.as<AllocateNode>()->extents.get();
     // construct a recursive SeqStmt.
     body = SeqStmt({body});
index f53693b..25b3800 100644 (file)
@@ -91,7 +91,7 @@ inline Array<Tensor> make_extern(const Array<Array<PrimExpr> >& out_shapes,
   }
 
   auto body = fextern(input_placeholders, output_placeholders);
-  auto body_stmt = tvm::tir::EvaluateNode::make(body);
+  auto body_stmt = tvm::tir::Evaluate(body);
 
   auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders,
                                body_stmt);