hash_reduce(body);
}
- TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
-
static constexpr const char* _type_key = "LetStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
};
/*!
+ * \brief Managed reference to LetStmtNode.
+ * \sa LetStmtNode
+ */
+class LetStmt : public Stmt {
+ public:
+ TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
+};
+
+/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This provide auxiliary information for IR passes that transforms body.
*
hash_reduce(body);
}
- TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);
-
static constexpr const char* _type_key = "AttrStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
};
/*!
+ * \brief Managed reference to AttrStmtNode.
+ * \sa AttrStmtNode
+ */
+class AttrStmt : public Stmt {
+ public:
+ TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
+};
+
+/*!
* \brief Assert condition, if an error occurs, return the error message.
*/
class AssertStmtNode : public StmtNode {
hash_reduce(body);
}
- TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
-
static constexpr const char* _type_key = "AssertStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};
/*!
+ * \brief Managed reference to AssertStmtNode.
+ * \sa AssertStmtNode
+ */
+class AssertStmt : public Stmt {
+ public:
+ TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
+};
+
+/*!
* \brief Store value to the buffer.
*
* Equivalent to ((DType*)buffer_var)[index] = value.
hash_reduce(predicate);
}
- TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);
-
static constexpr const char* _type_key = "Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};
/*!
+ * \brief Managed reference to StoreNode.
+ * \sa StoreNode
+ */
+class Store : public Stmt {
+ public:
+ TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
+};
+
+/*!
* \brief Store value to the high dimension buffer.
*
* \code
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);
+
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
};
hash_reduce(indices);
}
- TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);
-
static constexpr const char* _type_key = "ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};
/*!
+ * \brief Managed reference to ProducerStoreNode.
+ * \sa ProducerStoreNode
+ */
+class ProducerStore : public Stmt {
+ public:
+ TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
+};
+
+/*!
* \brief Annotate the bounds where the data produced by the producer
* need to be written and read in body.
* We will need to allocate space for the corresponding regions.
v->Visit("body", &body);
}
- TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);
-
bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
};
/*!
+ * \brief Managed reference to ProducerRealizeNode.
+ * \sa ProducerRealizeNode
+ */
+class ProducerRealize : public Stmt {
+ public:
+ TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
+};
+
+/*!
* \brief Allocate a buffer that can be used in body.
*/
class AllocateNode : public StmtNode {
hash_reduce(body);
}
- TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
- PrimExpr condition, Stmt body);
-
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};
+/*!
+ * \brief Managed reference to AllocateNode.
+ * \sa AllocateNode
+ */
+class Allocate : public Stmt {
+ public:
+ TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
+ Stmt body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
+};
+
/*! \brief Free the resources in the buffer before the scope ends. */
class FreeNode : public StmtNode {
public:
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
- TVM_DLL static Stmt make(Var buffer_var);
-
static constexpr const char* _type_key = "Free";
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};
/*!
+ * \brief Managed reference to FreeNode.
+ * \sa FreeNode
+ */
+class Free : public Stmt {
+ public:
+ TVM_DLL Free(Var buffer_var);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode);
+};
+
+/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
*/
hash_reduce(else_case);
}
- TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
-
static constexpr const char* _type_key = "IfThenElse";
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
};
/*!
+ * \brief Managed reference to IfThenElseNode.
+ * \sa IfThenElseNode
+ */
+class IfThenElse : public Stmt {
+ public:
+ TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
+
+ TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
+};
+
+/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
*
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
- TVM_DLL static Stmt make(PrimExpr v);
-
static constexpr const char* _type_key = "Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};
+/*!
+ * \brief Managed reference to EvaluateNode.
+ * \sa EvaluateNode
+ */
+class Evaluate : public Stmt {
+ public:
+ TVM_DLL explicit Evaluate(PrimExpr value);
+
+ explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {}
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
+};
+
/*! \brief Additional annotation of for loop. */
enum class ForType : int {
/*! \brief serial execution. */
/*! \brief The body of the for loop. */
Stmt body;
- TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
- DeviceAPI device_api, Stmt body);
-
void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
};
/*!
+ * \brief Managed reference to ForNode.
+ * \sa ForNode
+ */
+class For : public Stmt {
+ public:
+ TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
+ Stmt body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
+};
+
+/*!
* \brief A prefetch hint for abuffer
*/
class PrefetchNode : public StmtNode {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};
-
/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template <typename T>
-inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
+inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
if (it != value_map.end()) return (*it).second;
if (else_case.defined()) {
return else_case;
}
- return EvaluateNode::make(0);
+ return Evaluate(0);
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
- ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body),
- 0);
+ For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0);
} else {
// already in parallel env.
CHECK(parallel_env_.task_id.defined());
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i - 1);
- realize = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize);
+ realize = tir::ProducerRealize(t, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
if (attr->dim_align_factor != 0) {
Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
attr->dim_align_offset};
- realize = tir::AttrStmtNode::make(
+ realize = tir::AttrStmt(
t, tir::attr::buffer_dim_align,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
realize);
Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
- inits.emplace_back(ProducerStoreNode::make(t, init_value[i], args));
- provides.emplace_back(ProducerStoreNode::make(t, update_value[i], args));
+ inits.emplace_back(ProducerStore(t, init_value[i], args));
+ provides.emplace_back(ProducerStore(t, update_value[i], args));
}
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
- *provide = IfThenElseNode::make(reduce->condition, *provide);
+ *provide = IfThenElse(reduce->condition, *provide);
}
}
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
- return ProducerStoreNode::make(t, op->body[t->value_index], args);
+ return ProducerStore(t, op->body[t->value_index], args);
}
Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
}
auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
- return IfThenElseNode::make(cond, update, body);
+ return IfThenElse(cond, update, body);
}
} // namespace te
for (size_t i = 0; i < size; ++i) {
DataType t = reduces[i]->dtype;
normal_init.emplace_back(
- StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
+ Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes())));
normal_update.emplace_back(
- StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
+ Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes())));
}
}
// Apply the existing input predicate if any.
output_preds.push_back(input_pred);
- Stmt reduce_body = EvaluateNode::make(Call(
- DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic));
- reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope,
- make_zero(DataType::Handle()), reduce_body);
+ Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce,
+ freduce_args, CallNode::Intrinsic));
+ reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope,
+ make_zero(DataType::Handle()), reduce_body);
if (!normal_red.empty()) {
Stmt init_body = SeqStmt::Flatten(normal_init);
std::vector<Stmt> assigns(size);
for (size_t idx = 0; idx < size; ++idx) {
DataType t = reduces[idx]->dtype;
- assigns[idx] = ProducerStoreNode::make(
- stage->op.output(idx), Load(t, res_handles[idx], 0, const_true(t.lanes())), args);
+ assigns[idx] = ProducerStore(stage->op.output(idx),
+ Load(t, res_handles[idx], 0, const_true(t.lanes())), args);
}
Stmt assign_body = SeqStmt::Flatten(assigns);
assign_body = MergeNest(MakeIfNest(output_preds), assign_body);
Stmt body = SeqStmt::Flatten(reduce_body, assign_body);
for (size_t idx = size; idx != 0; --idx) {
- body =
- AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
- body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"),
- body);
+ body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+ body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body);
if (!normal_red.empty()) {
- body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1},
- const_true(), body);
- body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope,
- StringImm("local"), body);
+ body =
+ Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
+ body =
+ AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body);
}
}
body = Substitute(body, value_map);
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
- realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body);
+ realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
}
return realize_body;
}
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret =
- AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
+ Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<ObjectRef> bind_spec;
Array<PrimExpr> tuple;
tuple.push_back(make_const(buffer->shape[k].dtype(), 0));
tuple.push_back(buffer->shape[k]);
}
- ret = AttrStmtNode::make(
- bind_spec, tir::attr::buffer_bind_scope,
- Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
+ ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope,
+ Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1));
for (size_t i = 0; i < t->shape.size(); ++i) {
bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
}
- realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body);
+ realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
}
return realize_body;
}
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt ret =
- AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
+ Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
rmap[op->loop_var.get()] = inner + outer * factor;
Stmt ret = tir::Substitute(op->body, rmap);
PrimExpr cond = likely(outer * factor < (op->extent - inner));
- ret = IfThenElseNode::make(cond, ret);
- ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent,
- IterVarTypeToForType(inner->iter_type), op->device_api, ret);
- ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent,
- IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+ ret = IfThenElse(cond, ret);
+ ret = For(inner->var, PrimExpr(0), inner->dom->extent,
+ IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+ ret = For(outer->var, PrimExpr(0), outer->dom->extent,
+ IterVarTypeToForType(outer->iter_type), op->device_api, ret);
splitted = true;
return ret;
}
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = tir::Substitute(body, rmap);
under_outer = false;
- return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type,
- op->device_api, body);
+ return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api,
+ body);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const VarNode*, PrimExpr> rmap;
std::unordered_map<const VarNode*, PrimExpr> rmap;
rmap[op->loop_var.get()] = iter_var;
Stmt body = tir::Substitute(op->body, rmap);
- return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body);
+ return AttrStmt(iter_var, "thread_extent", op->extent, body);
} else {
- return ForNode::make(op->loop_var, op->min, op->extent,
- IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
+ return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type),
+ op->device_api, op->body);
}
}
return StmtMutator::VisitStmt_(op);
for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
}
const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
- return ForNode::make(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
+ return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
}
};
Tensor t = Downcast<Tensor>(op->producer);
auto it = vmap_.find(t);
if (it != vmap_.end()) {
- Stmt ret = tir::ProducerStoreNode::make(it->second, op->value, op->indices);
+ Stmt ret = tir::ProducerStore(it->second, op->value, op->indices);
found = true;
return this->VisitStmt(ret);
}
std::unordered_map<IterVar, PrimExpr>* p_value_map,
bool debug_keep_trivial_loop) {
auto leaf_iter_vars = stage->leaf_iter_vars;
- Stmt no_op = EvaluateNode::make(0);
+ Stmt no_op = Evaluate(0);
// create the loop nest
std::vector<std::vector<Stmt> > nest;
nest.resize(leaf_iter_vars.size() + 1);
pvalue = make_const(DataType::Int(32), 1);
}
nest[i + 1].emplace_back(
- AttrStmtNode::make(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
+ AttrStmt(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
}
}
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
- nest[i + 1].emplace_back(LetStmtNode::make(var, dom->min, no_op));
+ nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
value_map[iv] = dom->min;
} else if (is_zero(dom->min)) {
- nest[i + 1].emplace_back(
- ForNode::make(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
+ nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
- nest[i + 1].emplace_back(
- ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
+ nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
PrimExpr new_value = dom->min + idx;
value_map[iv] = new_value;
- nest[i + 1].emplace_back(LetStmtNode::make(var, new_value, no_op));
+ nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
}
if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1";
CHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size());
for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
- nest[i + 1].emplace_back(AttrStmtNode::make(it_attr->prefetch_data[j],
- tir::attr::prefetch_scope,
- it_attr->prefetch_offset[j], no_op));
+ nest[i + 1].emplace_back(AttrStmt(it_attr->prefetch_data[j], tir::attr::prefetch_scope,
+ it_attr->prefetch_offset[j], no_op));
}
}
} else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") {
CHECK(is_zero(dom->min));
CHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
- nest[i + 1].emplace_back(
- AttrStmtNode::make(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
+ nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
value_map[iv] = var;
} else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker.
CHECK(is_one(dom->extent));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
- AttrStmtNode::make(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
+ AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
value_map[iv] = dom->min;
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
- nest[i + 1].emplace_back(
- AttrStmtNode::make(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
+ nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
}
// annotate the extent of the IterVar
if (!new_loop_var) {
- nest[i + 1].emplace_back(AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op));
+ nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
}
std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
- Stmt no_op = EvaluateNode::make(0);
+ Stmt no_op = Evaluate(0);
std::vector<Stmt> nest;
for (const PrimExpr& cond : predicates) {
- nest.emplace_back(IfThenElseNode::make(cond, no_op));
+ nest.emplace_back(IfThenElse(cond, no_op));
}
return nest;
}
IterVar sp_ax = this->spatial_axis_[sp_idx];
bounds.push_back(dom_map.at(sp_ax));
}
- ret = tir::ProducerRealizeNode::make(t, bounds, const_true(), ret);
+ ret = tir::ProducerRealize(t, bounds, const_true(), ret);
}
return ret;
}
Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
- Stmt provide = AttrStmtNode::make(stage->op, tir::attr::scan_update_scope, this->scan_axis->var,
- EvaluateNode::make(0));
- Stmt init = AttrStmtNode::make(stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0));
+ Stmt provide =
+ AttrStmt(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, Evaluate(0));
+ Stmt init = AttrStmt(stage->op, tir::attr::scan_init_scope, 0, Evaluate(0));
size_t begin_scan = 0;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) {
CHECK_EQ(stage->op.operator->(), this);
// Start bind data.
- Stmt nop = EvaluateNode::make(0);
+ Stmt nop = Evaluate(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = this->InputTensors();
tuple.push_back(region[i]->min);
tuple.push_back(region[i]->extent);
}
- input_bind_nest.emplace_back(AttrStmtNode::make(
+ input_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
}
}
}
- output_bind_nest.emplace_back(AttrStmtNode::make(
+ output_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
}
VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
// Start bind data.
- Stmt nop = EvaluateNode::make(0);
+ Stmt nop = Evaluate(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Array<Tensor> inputs = self->InputTensors();
CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch ";
tuple.push_back(r->min);
tuple.push_back(r->extent);
}
- input_bind_nest.emplace_back(AttrStmtNode::make(
+ input_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
}
Tensor tensor = stage->op.output(i - intrin->inputs.size());
Buffer buffer = intrin->buffers[i];
Array<ObjectRef> bind_spec{buffer, tensor};
- output_bind_nest.emplace_back(AttrStmtNode::make(
+ output_bind_nest.emplace_back(AttrStmt(
bind_spec, tir::attr::buffer_bind_scope,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop));
}
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->indices[i]);
}
- expr = Substitute(EvaluateNode::make(expr), vmap).as<EvaluateNode>()->value;
+ expr = Substitute(Evaluate(expr), vmap).as<EvaluateNode>()->value;
}
return expr;
} else {
CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
- PrimExpr new_value =
- Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body)
- .as<tir::EvaluateNode>()
- ->value;
+ PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body)
+ .as<tir::EvaluateNode>()
+ ->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
- PrimExpr new_value =
- Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body)
- .as<tir::EvaluateNode>()
- ->value;
+ PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body)
+ .as<tir::EvaluateNode>()
+ ->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
changed[j] = true;
bool debug_keep_trivial_loop) {
Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop);
if (s->double_buffer) {
- producer = AttrStmtNode::make(s->op, tir::attr::double_buffer_scope, 1, producer);
+ producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer);
}
Stmt pipeline = producer;
}
pipeline = s->op->BuildRealize(s, dom_map, pipeline);
// use attribute to mark scope of the operation.
- pipeline = AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
+ pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
return pipeline;
}
CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar
<< " in multiple places in the IR";
found_attach = true;
- stmt =
- AttrStmtNode::make(op->node, op->attr_key, op->value,
- MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+ stmt = AttrStmt(op->node, op->attr_key, op->value,
+ MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
}
return stmt;
(op->attr_key == tir::attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
- stmt =
- AttrStmtNode::make(op->node, op->attr_key, op->value,
- MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
+ stmt = AttrStmt(op->node, op->attr_key, op->value,
+ MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_));
}
}
return stmt;
auto it = replace_op_.find(op->node.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- Stmt ret = AttrStmtNode::make(it->second, op->attr_key, op->value, op->body);
+ Stmt ret = AttrStmt(it->second, op->attr_key, op->value, op->body);
return this->VisitStmt(ret);
} else {
return this->VisitStmt(op->body);
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- return AttrStmtNode::make(
- Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)}, op->attr_key,
- op->value, this->VisitStmt(op->body));
+ return AttrStmt(Array<ObjectRef>{tuple[0], it->second.output(tensor->value_index)},
+ op->attr_key, op->value, this->VisitStmt(op->body));
} else {
return this->VisitStmt(op->body);
}
auto it = replace_op_.find(tensor->op.get());
if (it != replace_op_.end()) {
if (it->second.defined()) {
- return AttrStmtNode::make(it->second.output(tensor->value_index), op->attr_key, op->value,
- this->VisitStmt(op->body));
+ return AttrStmt(it->second.output(tensor->value_index), op->attr_key, op->value,
+ this->VisitStmt(op->body));
} else {
return this->VisitStmt(op->body);
}
auto it = replace_realize_.find(key);
if (it != replace_realize_.end()) {
if (it->second.defined()) {
- Stmt ret = ProducerRealizeNode::make(it->second, op->bounds, op->condition, op->body);
+ Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body);
return this->VisitStmt(ret);
} else {
return this->VisitStmt(op->body);
auto it = replace_buffer_.find(key);
if (it != replace_buffer_.end()) {
const Tensor& dst = it->second;
- Stmt ret = ProducerStoreNode::make(dst, op->value, op->indices);
+ Stmt ret = ProducerStore(dst, op->value, op->indices);
return this->VisitStmt(ret);
} else {
return StmtExprMutator::VisitStmt_(op);
new_bounds.push_back(
Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
- return ProducerRealizeNode::make(op->producer, new_bounds, op->condition, op->body);
+ return ProducerRealize(op->producer, new_bounds, op->condition, op->body);
}
return stmt;
}
CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name;
auto matrix_abc = tvm::tir::StringImm("wmma." + it->second);
Stmt body = this->VisitStmt(op->body);
- return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body);
+ return AttrStmt(op->node, op->attr_key, matrix_abc, body);
}
}
return stmt;
Buffer buffer_a(buffer_node_a);
Buffer buffer_b(buffer_node_b);
if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
- return EvaluateNode::make(
+ return Evaluate(
Call(DataType::Handle(), intrinsic::tvm_bmma_sync,
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
} else {
- return EvaluateNode::make(
+ return Evaluate(
Call(DataType::Handle(), intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset},
auto pload = dst.as<ProducerLoadNode>();
auto fill_fragment_call = [this, &op](const Buffer& buffer) {
- return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_fill_fragment,
- {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, op->value},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment,
+ {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+ buffer->elem_offset, op->value},
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
}
auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) {
- return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync,
- {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, src, stride, matrix_major},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync,
+ {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+ buffer->elem_offset, src, stride, matrix_major},
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
auto pload = op->value.as<ProducerLoadNode>();
auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) {
- return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync,
- {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
- buffer->elem_offset, dst, stride, StringImm("col_major")},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync,
+ {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k,
+ buffer->elem_offset, dst, stride, StringImm("col_major")},
+ CallNode::Intrinsic));
};
ObjectPtr<BufferNode> buffer_node = make_object<BufferNode>();
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
- stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api,
- op->body);
+ stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body);
}
}
return stmt;
}
auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic);
Array<ObjectRef> node = {buffer, tensor};
- return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer));
+ return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer));
}
std::unordered_map<std::string, std::string> matrix_abc_;
Operation operation = Downcast<Operation>(op->node);
for (int i = operation->num_outputs(); i != 0; --i) {
Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
- body = AttrStmtNode::make(buffer, op->attr_key, op->value, body);
+ body = AttrStmt(buffer, op->attr_key, op->value, body);
}
return body;
} else if (op->attr_key == tir::attr::buffer_bind_scope) {
Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
Tensor tensor = Downcast<Tensor>(tuple[1]);
- return AttrStmtNode::make(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key,
- op->value, op->body);
+ return AttrStmt(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value,
+ op->body);
} else if (op->attr_key == tir::attr::buffer_dim_align ||
op->attr_key == tir::attr::prefetch_scope) {
Tensor tensor = Downcast<Tensor>(op->node);
Buffer buffer = GetOrAllocBuffer(tensor);
- return AttrStmtNode::make(buffer, op->attr_key, op->value, op->body);
+ return AttrStmt(buffer, op->attr_key, op->value, op->body);
} else {
return ret;
}
CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0)
<< "Cannot load " << dtype << " from buffer of " << n->dtype;
if (value.dtype() == DataType::Bool()) {
- return tir::StoreNode::make(n->data, tir::Cast(DataType::Int(8), value),
- BufferOffset(n, begin, DataType::Int(8)), const_true());
+ return tir::Store(n->data, tir::Cast(DataType::Int(8), value),
+ BufferOffset(n, begin, DataType::Int(8)), const_true());
} else {
- return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype),
- const_true(dtype.lanes()));
+ return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes()));
}
}
return Let(var, value, body);
});
+TVM_REGISTER_NODE_TYPE(LetNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetNode*>(node.get());
TVM_REGISTER_NODE_TYPE(ShuffleNode);
+template <typename T>
+void PrintList(const Array<T>& exprs, ReprPrinter* p) {
+ for (size_t i = 0; i < exprs.size(); ++i) {
+ p->Print(exprs[i]);
+ if (i < exprs.size() - 1) {
+ p->stream << ", ";
+ }
+ }
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const ShuffleNode*>(node.get());
+ p->stream << "shuffle(";
+ PrintList(op->vectors, p);
+ p->stream << ", ";
+ PrintList(op->indices, p);
+ p->stream << ")";
+ });
+
// CommReducer
CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element) {
namespace tvm {
namespace tir {
-Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
+// LetStmt
+LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.dtype(), var.dtype());
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
- return Stmt(node);
+ data_ = std::move(node);
}
-TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed(LetStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed([](Var var, PrimExpr value, Stmt body) {
+ return LetStmt(var, value, body);
+});
+
+TVM_REGISTER_NODE_TYPE(LetStmtNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const LetStmtNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "let " << op->var << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ p->Print(op->body);
+ });
-Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
+// AttrStmt
+AttrStmt::AttrStmt(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
auto n = make_object<AttrStmtNode>();
n->node = node;
n->attr_key = std::move(attr_key);
n->value = std::move(value);
n->body = std::move(body);
- return Stmt(n);
+ data_ = std::move(n);
}
-TVM_REGISTER_GLOBAL("tir.AttrStmt").set_body_typed(AttrStmtNode::make);
+TVM_REGISTER_GLOBAL("tir.AttrStmt")
+ .set_body_typed([](ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) {
+ return AttrStmt(node, attr_key, value, body);
+ });
+
+TVM_REGISTER_NODE_TYPE(AttrStmtNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AttrStmtNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "// attr [";
+ p->Print(op->node);
+ p->stream << "] " << op->attr_key << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ p->Print(op->body);
+ });
-Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
+// AssertStmt
+AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>())
<< "TypeError: AssertStmt message must be an int or string:" << message << "\n";
node->condition = std::move(condition);
node->message = std::move(message);
node->body = std::move(body);
- return Stmt(node);
+ data_ = std::move(node);
}
+TVM_REGISTER_NODE_TYPE(AssertStmtNode);
+
TVM_REGISTER_GLOBAL("tir.AssertStmt")
.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
if (const auto* str = message.as<StringObj>()) {
auto msg = StringImm(str->data);
- return AssertStmtNode::make(condition, msg, body);
+ return AssertStmt(condition, msg, body);
} else {
- return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
+ return AssertStmt(condition, Downcast<PrimExpr>(message), body);
}
});
-Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
- DeviceAPI device_api, Stmt body) {
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const AssertStmtNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "assert(";
+ p->Print(op->condition);
+ p->stream << ", ";
+ p->Print(op->message);
+ p->stream << ")\n";
+ p->Print(op->body);
+ });
+
+// For
+For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
+ Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(min.dtype().is_scalar());
node->for_type = for_type;
node->device_api = device_api;
node->body = std::move(body);
- return Stmt(node);
+ data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
- return ForNode::make(loop_var, min, extent, static_cast<ForType>(for_type),
- static_cast<DeviceAPI>(device_api), body);
-});
-
-Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
- CHECK(value.defined());
- CHECK(index.defined());
- CHECK(predicate.defined());
- CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
- CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
-
- ObjectPtr<StoreNode> node = make_object<StoreNode>();
- node->buffer_var = std::move(buffer_var);
- node->value = std::move(value);
- node->index = std::move(index);
- node->predicate = std::move(predicate);
- return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
- PrimExpr value = args[1];
- if (args.size() == 3) {
- *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
- } else {
- *ret = StoreNode::make(args[0], value, args[2], args[3]);
- }
+ return For(loop_var, min, extent, static_cast<ForType>(for_type),
+ static_cast<DeviceAPI>(device_api), body);
});
-Stmt ProducerStoreNode::make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices) {
- ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>();
- node->producer = std::move(producer);
- node->value = std::move(value);
- node->indices = std::move(indices);
- return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.ProducerStore").set_body_typed(ProducerStoreNode::make);
-
-Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
- Stmt body) {
- for (size_t i = 0; i < extents.size(); ++i) {
- CHECK(extents[i].defined());
- CHECK(extents[i].dtype().is_scalar());
- }
- CHECK(body.defined());
- CHECK(condition.defined());
- CHECK(condition.dtype().is_bool());
-
- ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
- node->buffer_var = std::move(buffer_var);
- node->dtype = dtype;
- node->extents = std::move(extents);
- node->condition = std::move(condition);
- node->body = std::move(body);
- return Stmt(node);
-}
-
-Stmt ProducerRealizeNode::make(DataProducer producer, Region bounds, PrimExpr condition,
- Stmt body) {
- for (size_t i = 0; i < bounds.size(); ++i) {
- CHECK(bounds[i]->min.defined());
- CHECK(bounds[i]->extent.defined());
- CHECK(bounds[i]->min.dtype().is_scalar());
- CHECK(bounds[i]->extent.dtype().is_scalar());
- }
- CHECK(body.defined());
- CHECK(condition.defined());
- CHECK(condition.dtype().is_bool());
-
- ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>();
- node->producer = std::move(producer);
- node->bounds = std::move(bounds);
- node->condition = std::move(condition);
- node->body = std::move(body);
- return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.ProducerRealize").set_body_typed(ProducerRealizeNode::make);
-
-// overloaded, needs special handling
-// has default args
-TVM_REGISTER_GLOBAL("tir.Allocate")
- .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
- Stmt body) {
- return AllocateNode::make(buffer_var, type, extents, condition, body);
- });
-
-int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
- int64_t result = 1;
- for (size_t i = 0; i < extents.size(); ++i) {
- if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
- result *= int_size->value;
- if (result > std::numeric_limits<int32_t>::max()) {
- return 0;
- }
- } else {
- return 0;
- }
- }
- return static_cast<int32_t>(result);
-}
-
-Stmt FreeNode::make(Var buffer_var) {
- ObjectPtr<FreeNode> node = make_object<FreeNode>();
- node->buffer_var = buffer_var;
- return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make);
-
-Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
- data_ = make_object<PrefetchNode>(buffer, bounds);
-}
-
-TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array<Range> bounds) {
- return Prefetch(buffer, bounds);
-});
-
-SeqStmt::SeqStmt(Array<Stmt> seq) {
- auto node = make_object<SeqStmtNode>();
- node->seq = std::move(seq);
- data_ = std::move(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq) {
- return SeqStmt(std::move(seq));
-});
-
-Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
- CHECK(condition.defined());
- CHECK(then_case.defined());
- // else_case may be null.
-
- ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
- node->condition = std::move(condition);
- node->then_case = std::move(then_case);
- node->else_case = std::move(else_case);
- return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.IfThenElse").set_body_typed(IfThenElseNode::make);
-
-Stmt EvaluateNode::make(PrimExpr value) {
- CHECK(value.defined());
-
- ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
- node->value = std::move(value);
- return Stmt(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed(EvaluateNode::make);
-
-BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
- ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
- node->buffer = std::move(buffer);
- node->value = std::move(value);
- node->indices = std::move(indices);
- data_ = std::move(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.BufferStore")
- .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
- return BufferStore(buffer, value, indices);
- });
-
-TVM_REGISTER_NODE_TYPE(BufferStoreNode);
-
-BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
- data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body);
-}
-
-TVM_REGISTER_GLOBAL("tir.BufferRealize")
- .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
- return BufferRealize(buffer, bounds, condition, body);
- });
-
-TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
-
-// Printers
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LetStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "let " << op->var << " = ";
- p->Print(op->value);
- p->stream << '\n';
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AttrStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "// attr [";
- p->Print(op->node);
- p->stream << "] " << op->attr_key << " = ";
- p->Print(op->value);
- p->stream << '\n';
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AssertStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "assert(";
- p->Print(op->condition);
- p->stream << ", ";
- p->Print(op->message);
- p->stream << ")\n";
- p->Print(op->body);
- });
+TVM_REGISTER_NODE_TYPE(ForNode);
std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*)
switch (type) {
p->stream << "}\n";
});
+// Store
+Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
+ CHECK(value.defined());
+ CHECK(index.defined());
+ CHECK(predicate.defined());
+ CHECK_EQ(value.dtype().lanes(), index.dtype().lanes());
+ CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes());
+
+ ObjectPtr<StoreNode> node = make_object<StoreNode>();
+ node->buffer_var = std::move(buffer_var);
+ node->value = std::move(value);
+ node->index = std::move(index);
+ node->predicate = std::move(predicate);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
+ PrimExpr value = args[1];
+ if (args.size() == 3) {
+ *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes()));
+ } else {
+ *ret = Store(args[0], value, args[2], args[3]);
+ }
+});
+
+TVM_REGISTER_NODE_TYPE(StoreNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StoreNode*>(node.get());
p->stream << '\n';
});
+// ProducerStore
+ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices) {
+ ObjectPtr<ProducerStoreNode> node = make_object<ProducerStoreNode>();
+ node->producer = std::move(producer);
+ node->value = std::move(value);
+ node->indices = std::move(indices);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.ProducerStore")
+ .set_body_typed([](DataProducer producer, PrimExpr value, Array<PrimExpr> indices) {
+ return ProducerStore(producer, value, indices);
+ });
+
+TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProducerStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProducerStoreNode*>(node.get());
p->stream << '\n';
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferStoreNode*>(node.get());
- p->PrintIndent();
- p->stream << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) p->stream << ", ";
+// Allocate
+Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
+ Stmt body) {
+ for (size_t i = 0; i < extents.size(); ++i) {
+ CHECK(extents[i].defined());
+ CHECK(extents[i].dtype().is_scalar());
+ }
+ CHECK(body.defined());
+ CHECK(condition.defined());
+ CHECK(condition.dtype().is_bool());
+
+ ObjectPtr<AllocateNode> node = make_object<AllocateNode>();
+ node->buffer_var = std::move(buffer_var);
+ node->dtype = dtype;
+ node->extents = std::move(extents);
+ node->condition = std::move(condition);
+ node->body = std::move(body);
+ data_ = std::move(node);
+}
+
+int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
+ int64_t result = 1;
+ for (size_t i = 0; i < extents.size(); ++i) {
+ if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
+ result *= int_size->value;
+ if (result > std::numeric_limits<int32_t>::max()) {
+ return 0;
}
- p->stream << "]";
- p->stream << " = ";
- p->Print(op->value);
- p->stream << '\n';
- });
+ } else {
+ return 0;
+ }
+ }
+ return static_cast<int32_t>(result);
+}
+
+TVM_REGISTER_GLOBAL("tir.Allocate")
+ .set_body_typed([](Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition,
+ Stmt body) { return Allocate(buffer_var, type, extents, condition, body); });
+
+TVM_REGISTER_NODE_TYPE(AllocateNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
p->Print(op->body);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FreeNode*>(node.get());
- p->PrintIndent();
- p->stream << "free " << op->buffer_var;
- p->stream << '\n';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferRealizeNode*>(node.get());
- p->PrintIndent();
- p->stream << "buffer_realize " << op->buffer->name << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
+// ProducerRealize
+ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
+ Stmt body) {
+ for (size_t i = 0; i < bounds.size(); ++i) {
+ CHECK(bounds[i]->min.defined());
+ CHECK(bounds[i]->extent.defined());
+ CHECK(bounds[i]->min.dtype().is_scalar());
+ CHECK(bounds[i]->extent.dtype().is_scalar());
+ }
+ CHECK(body.defined());
+ CHECK(condition.defined());
+ CHECK(condition.dtype().is_bool());
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
+ ObjectPtr<ProducerRealizeNode> node = make_object<ProducerRealizeNode>();
+ node->producer = std::move(producer);
+ node->bounds = std::move(bounds);
+ node->condition = std::move(condition);
+ node->body = std::move(body);
+ data_ = std::move(node);
+}
- p->PrintIndent();
- p->stream << "}\n";
+TVM_REGISTER_GLOBAL("tir.ProducerRealize")
+ .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body) {
+ return ProducerRealize(producer, bounds, condition, body);
});
+TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProducerRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProducerRealizeNode*>(node.get());
p->stream << "}\n";
});
+// Free
+Free::Free(Var buffer_var) {
+ ObjectPtr<FreeNode> node = make_object<FreeNode>();
+ node->buffer_var = buffer_var;
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Free").set_body_typed([](Var buffer_var) { return Free(buffer_var); });
+
+TVM_REGISTER_NODE_TYPE(FreeNode);
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const FreeNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "free " << op->buffer_var;
+ p->stream << '\n';
+ });
+
+// Prefetch
+Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
+ data_ = make_object<PrefetchNode>(buffer, bounds);
+}
+
+TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array<Range> bounds) {
+ return Prefetch(buffer, bounds);
+});
+
+TVM_REGISTER_NODE_TYPE(PrefetchNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PrefetchNode*>(node.get());
p->stream << ")";
});
+// SeqStmt
+SeqStmt::SeqStmt(Array<Stmt> seq) {
+ auto node = make_object<SeqStmtNode>();
+ node->seq = std::move(seq);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq) {
+ return SeqStmt(std::move(seq));
+});
+
+TVM_REGISTER_NODE_TYPE(SeqStmtNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SeqStmtNode*>(node.get());
}
});
+// IfThenElse
+IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) {
+ CHECK(condition.defined());
+ CHECK(then_case.defined());
+ // else_case may be null.
+ ObjectPtr<IfThenElseNode> node = make_object<IfThenElseNode>();
+ node->condition = std::move(condition);
+ node->then_case = std::move(then_case);
+ node->else_case = std::move(else_case);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(IfThenElseNode);
+
+TVM_REGISTER_GLOBAL("tir.IfThenElse")
+ .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case) {
+ return IfThenElse(condition, then_case, else_case);
+ });
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IfThenElseNode*>(node.get());
p->stream << "}\n";
});
+// Evaluate
+Evaluate::Evaluate(PrimExpr value) {
+ CHECK(value.defined());
+
+ ObjectPtr<EvaluateNode> node = make_object<EvaluateNode>();
+ node->value = std::move(value);
+ data_ = std::move(node);
+}
+
+TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value) { return Evaluate(value); });
+
+TVM_REGISTER_NODE_TYPE(EvaluateNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EvaluateNode*>(node.get());
p->stream << "\n";
});
-template <typename T>
-void PrintList(const Array<T>& exprs, ReprPrinter* p) {
- for (size_t i = 0; i < exprs.size(); ++i) {
- p->Print(exprs[i]);
- if (i < exprs.size() - 1) {
- p->stream << ", ";
- }
- }
+// BufferStore
+BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+ ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
+ node->buffer = std::move(buffer);
+ node->value = std::move(value);
+ node->indices = std::move(indices);
+ data_ = std::move(node);
}
+TVM_REGISTER_GLOBAL("tir.BufferStore")
+ .set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
+ return BufferStore(buffer, value, indices);
+ });
+
+TVM_REGISTER_NODE_TYPE(BufferStoreNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ShuffleNode*>(node.get());
- p->stream << "shuffle(";
- PrintList(op->vectors, p);
- p->stream << ", ";
- PrintList(op->indices, p);
- p->stream << ")";
+ .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferStoreNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer->name << "[";
+ for (size_t i = 0; i < op->indices.size(); ++i) {
+ p->Print(op->indices[i]);
+ if (i < op->indices.size() - 1) p->stream << ", ";
+ }
+ p->stream << "]";
+ p->stream << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
});
-TVM_REGISTER_NODE_TYPE(AttrStmtNode);
-TVM_REGISTER_NODE_TYPE(PrefetchNode);
-TVM_REGISTER_NODE_TYPE(CallNode);
-TVM_REGISTER_NODE_TYPE(LetNode);
-TVM_REGISTER_NODE_TYPE(LetStmtNode);
-TVM_REGISTER_NODE_TYPE(AssertStmtNode);
-TVM_REGISTER_NODE_TYPE(ForNode);
-TVM_REGISTER_NODE_TYPE(StoreNode);
-TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
-TVM_REGISTER_NODE_TYPE(AllocateNode);
-TVM_REGISTER_NODE_TYPE(FreeNode);
-TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
-TVM_REGISTER_NODE_TYPE(SeqStmtNode);
-TVM_REGISTER_NODE_TYPE(IfThenElseNode);
-TVM_REGISTER_NODE_TYPE(EvaluateNode);
+// BufferRealize
+BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
+ data_ = make_object<BufferRealizeNode>(buffer, bounds, condition, body);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferRealize")
+ .set_body_typed([](Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body) {
+ return BufferRealize(buffer, bounds, condition, body);
+ });
+
+TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferRealizeNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "buffer_realize " << op->buffer->name << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";
+
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
+
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
} // namespace tir
} // namespace tvm
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();
if (auto mapped_var = vmap_(op->buffer_var)) {
- return StoreNode::make(Downcast<Var>(mapped_var.value()), op->value, op->index,
- op->predicate);
+ return Store(Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
} else {
return ret;
}
const IfThenElseNode* new_if_node = new_if.as<IfThenElseNode>();
CHECK(new_if_node);
- new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for);
+ new_if = IfThenElse(new_if_node->condition, then_for, else_for);
if (i < if2for_map_[if_stmt.get()].size() - 1) {
const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx);
Stmt new_for = Stmt();
for (size_t i = new_if_list.size() - 1; i > 0; --i) {
CHECK(current_if_node);
- const Stmt current_if_stmt = IfThenElseNode::make(
+ const Stmt current_if_stmt = IfThenElse(
current_if_node->condition, current_if_node->then_case, current_if_node->else_case);
next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
CHECK(next_if_node);
- new_for =
- IfThenElseNode::make(next_if_node->condition, current_if_stmt, next_if_node->else_case);
+ new_for = IfThenElse(next_if_node->condition, current_if_stmt, next_if_node->else_case);
current_if_node = new_for.as<IfThenElseNode>();
}
if (!new_for.get()) {
const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
CHECK(first_if_node);
- new_for = IfThenElseNode::make(first_if_node->condition, first_if_node->then_case,
- first_if_node->else_case);
+ new_for = IfThenElse(first_if_node->condition, first_if_node->then_case,
+ first_if_node->else_case);
}
*ret = new_for;
}
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
- asserts->emplace_back(
- AssertStmtNode::make(scond, tvm::tir::StringImm(os.str()), EvaluateNode::make(0)));
+ asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), Evaluate(0)));
}
}
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
- init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0)));
+ init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
} else {
(*def_map_)[v] = value;
}
const std::string& arg_name) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
- const Stmt nop = EvaluateNode::make(0);
+ const Stmt nop = Evaluate(0);
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size();
auto msg = tvm::tir::StringImm(ndim_err_msg.str());
- asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
+ asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
// type checks
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
IntImm(DataType::UInt(16), dtype.lanes()));
if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
- asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
- asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
+ asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
+ asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
}
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
- init_nest_.emplace_back(AttrStmtNode::make(vptr, tir::attr::storage_alignment,
- IntImm(DataType::Int(32), buffer->data_alignment),
- nop));
+ init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment,
+ IntImm(DataType::Int(32), buffer->data_alignment), nop));
}
Var v_shape(arg_name + ".shape", DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
- init_nest_.emplace_back(LetStmtNode::make(
- v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
+ init_nest_.emplace_back(
+ LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) {
break;
// strides field
Var v_strides(arg_name + ".strides", DataType::Handle());
def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
- init_nest_.emplace_back(LetStmtNode::make(
- v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop));
+ init_nest_.emplace_back(
+ LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop));
PrimExpr is_null =
Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic);
if (buffer->strides.size() == 0) {
if (conds.size() != 0) {
auto stride_msg = tvm::tir::StringImm(stride_err_msg.str());
auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
- Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg,
- EvaluateNode::make(0));
- check = IfThenElseNode::make(Not(is_null), check, Stmt());
- asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
+ Stmt check = AssertStmt(foldl(fand, const_true(1), conds), stride_msg, Evaluate(0));
+ check = IfThenElse(Not(is_null), check, Stmt());
+ asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
DataType stype = buffer->DefaultIndexType();
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
asserts_.emplace_back(
- AssertStmtNode::make(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
+ AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
if (store_scope_bound_collector_.size()) {
PrimExpr condition = MakeCondition();
if (!condition.as<StringImmNode>()) {
- Stmt nop = EvaluateNode::make(1);
- Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate);
- Stmt else_case = AssertStmtNode::make(condition, StringImm(error_message_), nop);
- Stmt body = IfThenElseNode::make(condition, then_case, else_case);
+ Stmt nop = Evaluate(1);
+ Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate);
+ Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
+ Stmt body = IfThenElse(condition, then_case, else_case);
return body;
}
}
static Stmt BuildContext(
const std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual>& cmap, Stmt body) {
for (const auto& kv : cmap) {
- body = LetStmtNode::make(kv.second, kv.first, body);
+ body = LetStmt(kv.second, kv.first, body);
}
return body;
}
}
std::vector<Stmt> GetSync(std::string sync_name) {
- return {EvaluateNode::make(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))};
+ return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))};
}
const std::unordered_set<const VarNode*>& touched_;
CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
PrimExpr min = r->min;
PrimExpr extent = r->extent;
- return EvaluateNode::make(Call(DataType::Int(32), func,
- {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), func,
+ {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent},
+ CallNode::Intrinsic));
}
// Write barrier name
bool read_barrier_{false};
}
Stmt MakePush(int from, int to) {
- return EvaluateNode::make(
- Call(DataType::Int(32), sync_push_name_,
- {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), sync_push_name_,
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
+ CallNode::Intrinsic));
}
Stmt MakePop(int from, int to) {
- return EvaluateNode::make(
- Call(DataType::Int(32), sync_pop_name_,
- {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), sync_pop_name_,
+ {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)},
+ CallNode::Intrinsic));
}
// sync states.
SyncState first_state_, last_state_, curr_state_;
namespace tir {
Stmt DecorateDeviceScope(Stmt&& stmt) {
- Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt);
+ Stmt body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt);
return body;
}
}
CHECK(it->second.loop != nullptr);
auto& alloc_nest = loop_allocs_[it->second.loop];
- alloc_nest.emplace_back(AttrStmtNode::make(
- op->buffer_var, attr::storage_scope, StringImm(it->second.scope), EvaluateNode::make(0)));
- alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents,
- op->condition, EvaluateNode::make(0)));
+ alloc_nest.emplace_back(
+ AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0)));
+ alloc_nest.emplace_back(
+ Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
return op->body;
} else {
return StmtExprMutator::VisitStmt_(op);
vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
- Stmt loop = ForNode::make(outer_var, zero, outer_ext, old_loop->for_type,
- old_loop->device_api, SeqStmt::Flatten(loop_seq));
+ Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
+ SeqStmt::Flatten(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
for (int32_t i = 0; i < split_loop_; ++i) {
PrimExpr idx = tail_base + make_const(tail_base.dtype(), i);
vmap[old_loop->loop_var.get()] = idx;
- tail_seq.emplace_back(
- IfThenElseNode::make(idx < old_loop->extent, Substitute(tail_body, vmap)));
+ tail_seq.emplace_back(IfThenElse(idx < old_loop->extent, Substitute(tail_body, vmap)));
}
stmt = SeqStmt::Flatten(loop, tail_seq);
}
const StorageEntry& e = it->second;
CHECK(in_double_buffer_scope_);
CHECK(e.stride.defined());
- return StoreNode::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index,
- op->predicate);
+ return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index,
+ op->predicate);
} else {
return stmt;
}
vmap[e.loop->loop_var.get()] = loop_shift;
vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
body = Substitute(body, vmap);
- body = AttrStmtNode::make(buffer, attr::double_buffer_write, 1, body);
- body = IfThenElseNode::make(loop_shift < e.loop->extent, body);
+ body = AttrStmt(buffer, attr::double_buffer_write, 1, body);
+ body = IfThenElse(loop_shift < e.loop->extent, body);
return body;
}
// Storage entry for those who need double buffering.
trigger_base_inject_ = !allow_share_;
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
- return StoreNode::make(op->buffer_var, op->value, RewriteIndex(op->index, it->second),
- op->predicate);
+ return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate);
} else {
return stmt;
}
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return AttrStmtNode::make(op->node, op->attr_key, value, body);
+ return AttrStmt(op->node, op->attr_key, value, body);
}
}
}
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return LetStmtNode::make(op->var, value, body);
+ return LetStmt(op->var, value, body);
}
}
// For
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+ return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
}
}
// IfThenElse
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElseNode::make(condition, then_case, else_case);
+ return IfThenElse(condition, then_case, else_case);
}
}
if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) {
return GetRef<Stmt>(op);
} else {
- return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
+ return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
}
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
- ForType::Serial, DeviceAPI::None, stmt);
+ return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
+ ForType::Serial, DeviceAPI::None, stmt);
}
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<StoreNode>();
if (scope_.count(op->buffer_var.get())) {
- return StoreNode::make(scope_[op->buffer_var.get()].back(), op->value, op->index,
- op->predicate);
+ return Store(scope_[op->buffer_var.get()].back(), op->value, op->index, op->predicate);
} else {
return stmt;
}
scope_[v.get()].push_back(new_var);
Stmt body = this->VisitStmt(op->body);
scope_[v.get()].pop_back();
- return LetStmtNode::make(new_var, value, body);
+ return LetStmt(new_var, value, body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<ForNode>();
- return ForNode::make(new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
+ return For(new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<AllocateNode>();
- return AllocateNode::make(new_var, op->dtype, op->extents, op->condition, op->body);
+ return Allocate(new_var, op->dtype, op->extents, op->condition, op->body);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
if (new_alloc.same_as(op->body)) return GetRef<Stmt>(op);
alloc = new_alloc.as<AllocateNode>();
CHECK(alloc);
- return AttrStmtNode::make(alloc->buffer_var, op->attr_key, op->value, new_alloc);
+ return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc);
}
}
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
if (scope_.count(v) && scope_[v].size() != 0) {
- return AttrStmtNode::make(scope_[v].back(), op->attr_key, op->value, op->body);
+ return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body);
} else {
return stmt;
}
PrimExpr value) {
Array<PrimExpr> args = {handle, make_const(DataType::Int(32), index),
make_const(DataType::Int(32), static_cast<int>(kind)), value};
- return EvaluateNode::make(
- Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic));
}
/*!
Stmt Lift(Stmt stmt) {
stmt = operator()(std::move(stmt));
if (attr_node_.defined()) {
- stmt = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, stmt);
+ stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt);
}
return stmt;
}
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
if (attr_node_.defined()) {
- Stmt body = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, op->body);
+ Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
- return AllocateNode::make(op->buffer_var, op->dtype, op->extents, op->condition, body);
+ return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body);
} else {
return stmt;
}
}
Stmt stmt = SeqStmt::Flatten(seq);
if (attr_node[begin].defined()) {
- stmt = AttrStmtNode::make(attr_node[begin], attr_key_, attr_value[begin], stmt);
+ stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt);
}
reorg.push_back(stmt);
begin = end;
if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElseNode::make(op->condition, then_case, else_case);
+ return IfThenElse(op->condition, then_case, else_case);
}
} else {
if (first_node.defined()) {
- then_case = AttrStmtNode::make(first_node, attr_key_, first_value, then_case);
+ then_case = AttrStmt(first_node, attr_key_, first_value, then_case);
}
if (attr_node_.defined()) {
- else_case = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, else_case);
+ else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case);
// undefine them
attr_node_ = ObjectRef();
attr_value_ = PrimExpr();
if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElseNode::make(op->condition, then_case, else_case);
+ return IfThenElse(op->condition, then_case, else_case);
}
}
}
// add branch code inside the innermost thread scope
if (innermost_thread_scope_) {
Stmt simplified_body = ConditionEliminator(ps_)(op->body);
- Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body);
+ Stmt body = IfThenElse(cond_, simplified_body, op->body);
PrimExpr value = this->VisitExpr(op->value);
- stmt = AttrStmtNode::make(op->node, op->attr_key, value, body);
+ stmt = AttrStmt(op->node, op->attr_key, value, body);
}
innermost_thread_scope_ = false;
return stmt;
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
- return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
- for_node->for_type, for_node->device_api, body);
+ return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type,
+ for_node->device_api, body);
}
}
if (toBeLowered) {
auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes());
- return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents,
- allocate->condition, allocate->body);
+ return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents,
+ allocate->condition, allocate->body);
}
return stmt;
}
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
- return LetStmtNode::make(op->buffer_var, info->head_address, op->body);
+ return LetStmt(op->buffer_var, info->head_address, op->body);
} else {
return op->body;
}
if (it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
- stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition,
- op->body);
- stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt);
+ stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
+ stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt);
} else {
// use volatile access to shared buffer.
- stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body);
- stmt =
- AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
- stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt);
+ stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
+ stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
+ stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt);
}
return stmt;
} else {
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
PrimExpr pred = const_true(types[idx].lanes());
- seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred));
+ seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred));
// Uses a local variable to store the shuffled data.
// Later on, this allocation will be properly attached to this statement.
Var var("t" + std::to_string(idx), types[idx]);
- Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, EvaluateNode::make(0));
+ Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0));
local_vars.push_back(s);
}
PrimExpr pred = const_true(1);
PrimExpr mask =
Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
- seq.emplace_back(StoreNode::make(mask_var, mask, index, pred));
+ seq.emplace_back(Store(mask_var, mask, index, pred));
// Push allocation with an empty body. Later this will be fixed
// when the entire body is ready.
- auto stmt = AllocateNode::make(mask_var, mask_var->dtype, {PrimExpr(1)}, pred,
- EvaluateNode::make(0));
+ auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0));
local_vars.push_back(stmt);
}
const char* shfl_func = intrinsic::tvm_warp_shuffle_down;
PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset);
const AllocateNode* repl = local_vars[i].as<AllocateNode>();
- Stmt s = StoreNode::make(repl->buffer_var, other, index, pred);
+ Stmt s = Store(repl->buffer_var, other, index, pred);
seq.push_back(s);
PrimExpr load = Load(types[i], repl->buffer_var, index, pred);
for (size_t i = 0; i < size; ++i) {
Var var = shared_bufs[i];
PrimExpr pred = const_true(types[i].lanes());
- stores[i] = StoreNode::make(var, ret[i], index, pred);
+ stores[i] = Store(var, ret[i], index, pred);
}
seq.push_back(SeqStmt::Flatten(stores));
}
const char* shfl_func = intrinsic::tvm_warp_shuffle;
PrimExpr val = Load(types[i], var, index, pred);
PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0);
- seq.push_back(StoreNode::make(var, splat, index, pred));
+ seq.push_back(Store(var, splat, index, pred));
}
// Update existing allocations.
Var var = shared_bufs[i];
load_remap_[buffers[i]] = Load(types[i], var, index, pred);
Array<PrimExpr> extents{PrimExpr(1)};
- auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0));
+ auto node = Allocate(var, types[i], extents, pred, Evaluate(0));
alloc_remap_[buffers[i]] = node;
warp_allocs_.insert(node.get());
}
for (size_t i = 0; i < size; ++i) {
PrimExpr pred = const_true(types[i].lanes());
Var buffer_var = Downcast<Var>(call->args[2 + size + i]);
- stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
+ stores[i] = Store(buffer_var, values[i], 0, pred);
}
return SeqStmt::Flatten(stores);
}
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle());
PrimExpr pred = const_true(types[idx].lanes());
- seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx],
- BufIndex(reduce_index, group_index, reduce_extent), pred));
+ seq.emplace_back(Store(shared_bufs[idx], values[idx],
+ BufIndex(reduce_index, group_index, reduce_extent), pred));
}
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
load_remap_[buffers[idx]] =
Load(types[idx], shared_bufs[idx],
BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
- alloc_remap_[buffers[idx]] = AllocateNode::make(
- shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred,
- EvaluateNode::make(0));
+ alloc_remap_[buffers[idx]] =
+ Allocate(shared_bufs[idx], types[idx],
+ {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
}
}
for (auto var : local_vars) {
const AllocateNode* repl = var.as<AllocateNode>();
if (repl) {
- body =
- AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
- body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), body);
+ body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
+ body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body);
}
}
Array<PrimExpr> ret = (*combiner)(a, b);
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
- stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true());
+ stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true());
}
return SeqStmt::Flatten(stores);
};
// reduction with the boundary condition
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
- seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
+ seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
CHECK(threadx_extent >= 1 && warp_size_ >= 1);
while (reduce_align > threadx_extent || reduce_align > warp_size_) {
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
- seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align)));
+ seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread("shared"));
}
// in warp synchronization.
}
if (in_warp_seq.size() != 0) {
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
- seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body));
+ seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
seq.emplace_back(SyncThread("shared"));
}
return SeqStmt::Flatten(seq);
}
// sync thread op.
static Stmt SyncThread(const std::string& sync) {
- return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImm(sync)}, CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)},
+ CallNode::Intrinsic));
}
// Emit warp shuffle intrinsic calls.
stack_tcode_ = Var("stack_tcode", DataType::Handle());
stmt = this->VisitStmt(stmt);
if (max_shape_stack_ != 0) {
- stmt = LetStmtNode::make(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
+ stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
}
if (max_array_stack_ != 0) {
- stmt = LetStmtNode::make(stack_array_, StackAlloca("array", max_array_stack_), stmt);
+ stmt = LetStmt(stack_array_, StackAlloca("array", max_array_stack_), stmt);
}
if (max_arg_stack_ != 0) {
- stmt = LetStmtNode::make(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
- stmt = LetStmtNode::make(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
+ stmt = LetStmt(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
+ stmt = LetStmt(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
}
return stmt;
}
}
CHECK(device_type_.defined()) << "Unknown device type in current IR";
CHECK(device_id_.defined()) << "Unknown device id in current IR";
- Stmt throw_last_error = EvaluateNode::make(
- Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic));
+ Stmt throw_last_error =
+ Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic));
- Stmt body = SeqStmt({IfThenElseNode::make(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null,
- {op->buffer_var}, CallNode::PureIntrinsic),
- throw_last_error),
+ Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null,
+ {op->buffer_var}, CallNode::PureIntrinsic),
+ throw_last_error),
op->body});
- Stmt alloca = LetStmtNode::make(
+ Stmt alloca = LetStmt(
op->buffer_var,
Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace",
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
{cast(DataType::Int(32), device_type_),
cast(DataType::Int(32), device_id_), op->buffer_var},
CallNode::Extern);
- Stmt free_stmt =
- IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error);
+ Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
- body = AttrStmtNode::make(op->buffer_var, attr::storage_alignment,
- make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
+ body = AttrStmt(op->buffer_var, attr::storage_alignment,
+ make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
return body;
}
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
for (size_t i = 0; i < op->args.size(); ++i) {
- prep_seq_.emplace_back(StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]),
- ConstInt32(stack_begin + i), const_true(1)));
+ prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]),
+ ConstInt32(stack_begin + i), const_true(1)));
}
return AddressOffset(stack_shape_, DataType::Int(64), stack_begin);
}
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq_.emplace_back(
- StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
+ Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
int arg_tcode = api_type.code();
CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq_.emplace_back(
- StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
+ Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
warp_group_ = (alloc_size + (factor - 1)) / factor;
alloc_size = warp_group_ * factor;
- return AllocateNode::make(op->buffer_var, op->dtype,
- {make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
- this->VisitStmt(op->body));
+ return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
+ op->condition, this->VisitStmt(op->body));
}
protected:
if (op->buffer_var.get() == buffer_) {
PrimExpr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
- return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate);
+ return Store(op->buffer_var, op->value, local_index, op->predicate);
} else {
return StmtExprMutator::VisitStmt_(op);
}
warp_buffer_.insert(buf);
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
- return AttrStmtNode::make(op->node, op->attr_key, StringImm("local"), op->body);
+ return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body);
}
}
return StmtMutator::VisitStmt_(op);
namespace tir {
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
- return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImm(msg), EvaluateNode::make(0));
+ return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
std::string name_hint = global_symbol.value();
auto* func_ptr = func.CopyOnWrite();
- const Stmt nop = EvaluateNode::make(0);
+ const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
CHECK_LE(num_unpacked_args, num_args);
}
if (i < num_packed_args) {
// Value loads
- seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop));
+ seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
- seq_init.emplace_back(LetStmtNode::make(tcode,
- Load(DataType::Int(32), v_packed_arg_type_ids,
- IntImm(DataType::Int(32), i), const_true(1)),
- nop));
+ seq_init.emplace_back(LetStmt(tcode,
+ Load(DataType::Int(32), v_packed_arg_type_ids,
+ IntImm(DataType::Int(32), i), const_true(1)),
+ nop));
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
- seq_check.emplace_back(
- AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
- tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
- tvm::tir::StringImm(msg.str()), nop));
+ seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
+ tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
+ tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
- seq_check.emplace_back(
- AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
+ seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
- seq_check.emplace_back(
- AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
+ seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
}
} else {
args.push_back(v_arg);
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
- auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope,
- StringImm(name_hint + "_compute_"), func_ptr->body);
+ Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
+ StringImm(name_hint + "_compute_"), func_ptr->body);
// Set device context
if (vmap.count(device_id.get())) {
PrimExpr node = StringImm("default");
- seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop));
- seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop));
+ seq_check.push_back(AttrStmt(node, attr::device_context_id, device_id, nop));
+ seq_check.push_back(AttrStmt(node, attr::device_context_type, device_type, nop));
if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
- Stmt set_device = EvaluateNode::make(
- Call(DataType::Int(32), intrinsic::tvm_call_packed,
- {StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
- CallNode::Intrinsic));
+ Stmt set_device =
+ Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed,
+ {StringImm(runtime::symbol::tvm_set_device), device_type, device_id},
+ CallNode::Intrinsic));
body = SeqStmt({set_device, body});
}
}
is_index_ = true;
PrimExpr index = this->VisitExpr(op->index);
is_index_ = false;
- Stmt s = StoreNode::make(op->buffer_var, op->value, index, op->predicate);
+ Stmt s = Store(op->buffer_var, op->value, index, op->predicate);
return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
}
<< ", but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
- return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
- op->for_type, op->device_api, op->body);
+ return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type,
+ op->device_api, op->body);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (ivmap_.find(iv) == ivmap_.end()) {
ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag);
}
- return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
+ return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body);
}
return StmtExprMutator::VisitStmt_(op);
}
CHECK(vmap_[v].same_as(new_iv->var));
}
Stmt body = this->VisitStmt(op->body);
- return AttrStmtNode::make(new_iv, op->attr_key, op->value, body);
+ return AttrStmt(new_iv, op->attr_key, op->value, body);
}
}
return StmtExprMutator::VisitStmt_(op);
if (is_no_op(op->then_case)) {
return MakeEvaluate(op->condition);
} else {
- return IfThenElseNode::make(op->condition, op->then_case);
+ return IfThenElse(op->condition, op->then_case);
}
} else {
return stmt;
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (is_zero(op->extent)) {
- return EvaluateNode::make(0);
+ return Evaluate(0);
}
return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
}
}
Stmt VisitStmt_(const EvaluateNode* op) final {
if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
- return EvaluateNode::make(0);
+ return Evaluate(0);
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
private:
Stmt MakeEvaluate(PrimExpr value) {
if (HasSideEffect(value)) {
- return EvaluateNode::make(value);
+ return Evaluate(value);
} else {
- return EvaluateNode::make(0);
+ return Evaluate(0);
}
}
Stmt MakeEvaluate(const Array<PrimExpr>& values) {
for (PrimExpr e : values) {
if (HasSideEffect(e)) {
if (stmt.defined()) {
- stmt = SeqStmt({stmt, EvaluateNode::make(e)});
+ stmt = SeqStmt({stmt, Evaluate(e)});
} else {
- stmt = EvaluateNode::make(e);
+ stmt = Evaluate(e);
}
}
}
- return stmt.defined() ? stmt : EvaluateNode::make(0);
+ return stmt.defined() ? stmt : Evaluate(0);
}
};
if (const LoadNode* load = op->value.as<LoadNode>()) {
if (load->buffer_var.same_as(op->buffer_var) &&
tir::ExprDeepEqual()(load->index, op->index)) {
- return EvaluateNode::make(0);
+ return Evaluate(0);
}
}
return GetRef<Stmt>(op);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
- return AttrStmtNode::make(op->node, op->attr_key, value, body);
+ return AttrStmt(op->node, op->attr_key, value, body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<Stmt>(op);
} else {
- return LetStmtNode::make(op->var, value, body);
+ return LetStmt(op->var, value, body);
}
}
}
for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
- return EvaluateNode::make(
+ return Evaluate(
Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic));
}
if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) {
CHECK(it->second.as<VarNode>());
Var buf_var = Downcast<Var>(it->second);
- return StoreNode::make(buf_var, op->value, op->index, op->predicate);
+ return Store(buf_var, op->value, op->index, op->predicate);
} else {
return stmt;
}
Stmt body = this->VisitStmt(op->body);
auto it = buf_map_.find(buffer);
CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer;
- body = AttrStmtNode::make(it->second.buffer->data, op->attr_key, op->value, std::move(body));
+ body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body));
return body;
} else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
// To create bound attribute collector should has at least one item.
if (create_bound_attributes_ && shape_collector_.size()) {
for (size_t i = 0; i < shape_collector_.size(); ++i) {
- body = AttrStmtNode::make(shape_collector_[i].first, tir::attr::buffer_bound,
- MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
+ body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound,
+ MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
}
}
return body;
}
if (strides.size() != 0) {
int first_dim = 0;
- ret = AllocateNode::make(e.buffer->data, storage_type,
- {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
- make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
+ ret = Allocate(e.buffer->data, storage_type,
+ {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]},
+ make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
} else {
shape = e.buffer->shape;
if (shape.size() == 0) {
shape.push_back(make_const(DataType::Int(32), 1));
}
- ret = AllocateNode::make(e.buffer->data, storage_type, shape,
- make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
+ ret = Allocate(e.buffer->data, storage_type, shape,
+ make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
- ret =
- AttrStmtNode::make(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);
+ ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
- ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound,
- MakeBound(e.buffer->dtype, e.buffer->shape), ret);
+ ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound,
+ MakeBound(e.buffer->dtype, e.buffer->shape), ret);
}
return ret;
}
}
for (int i = starts; i >= 0; --i) {
if (i < starts) {
- stmt = ForNode::make(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None,
- stmt);
+ stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
} else {
PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
PrimExpr address =
Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic);
PrimExpr prefetch =
Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
- stmt = EvaluateNode::make(prefetch);
+ stmt = Evaluate(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
- stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
+ stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
}
}
return stmt;
for (StorageEntry* e : attach_map_.at(nullptr)) {
// CHECK_EQ(e->scope.rank, 0);
if (e->new_alloc.defined()) {
- nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope,
- StringImm(e->scope.to_string()),
- EvaluateNode::make(0)));
+ nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope,
+ StringImm(e->scope.to_string()), Evaluate(0)));
nest.push_back(e->new_alloc);
}
}
op = stmt.as<StoreNode>();
auto it = alloc_map_.find(op->buffer_var.get());
if (it == alloc_map_.end()) return stmt;
- return StoreNode::make(it->second->alloc_var, op->value,
- RemapIndex(op->value.dtype(), op->index, it->second), op->predicate);
+ return Store(it->second->alloc_var, op->value,
+ RemapIndex(op->value.dtype(), op->index, it->second), op->predicate);
}
PrimExpr VisitExpr_(const LoadNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
- return AttrStmtNode::make(op->node, op->attr_key, op->value, MakeAttach(svec, op->body));
+ return AttrStmt(op->node, op->attr_key, op->value, MakeAttach(svec, op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
op = stmt.as<AttrStmtNode>();
auto it = alloc_map_.find(op->node.as<VarNode>());
if (it == alloc_map_.end()) return stmt;
- return AttrStmtNode::make(it->second->alloc_var, op->attr_key, op->value, op->body);
+ return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
- return ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
- MakeAttach(svec, op->body));
+ return For(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
+ MakeAttach(svec, op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
std::vector<Stmt> nest;
for (StorageEntry* e : svec) {
if (e->new_alloc.defined()) {
- nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope,
- StringImm(e->scope.to_string()),
- EvaluateNode::make(0)));
+ nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope,
+ StringImm(e->scope.to_string()), Evaluate(0)));
nest.push_back(e->new_alloc);
}
}
if (e->allocs.size() == 1) {
// simply use the original allocation.
PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents);
- e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition,
- EvaluateNode::make(0));
+ e->new_alloc =
+ Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
- e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {combo_size}, const_true(),
- EvaluateNode::make(0));
+ e->new_alloc =
+ Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0));
if (e->scope.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
PrimExpr alloc_size =
make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits);
- e->new_alloc = AllocateNode::make(e->alloc_var, e->elem_type, {alloc_size}, const_true(),
- EvaluateNode::make(0));
+ e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0));
if (info.defined()) {
CHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
if (me->base % factor == 0 && me->coeff % factor == 0) {
extents.Set(extents.size() - 1,
extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
- return AllocateNode::make(op->buffer_var, tvec[0], extents, op->condition, op->body);
+ return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body);
}
}
return stmt;
std::string shape =
std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k);
PrimExpr shape_expr = StringImm(shape);
- Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
+ Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
if (info.layout != "") {
// Add shape attribute to matrix_a and matrix_b
- Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout,
- StringImm(info.layout), shape_attr);
+ Stmt layout_attr =
+ AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr);
return layout_attr;
} else {
return shape_attr;
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else {
- barrier =
- EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
+ barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic));
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
Stmt InitGlobalBarrier(const AttrStmtNode* op) {
CHECK(op != nullptr);
Array<PrimExpr> pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)};
- Stmt prep = EvaluateNode::make(
- Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
+ Stmt prep =
+ Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic));
Stmt body = op->body;
for (const auto& kv : rw_stats_) {
const auto& e = kv.second;
if (e.read_count != 0 && e.write_count != 0) {
- body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body);
+ body = AttrStmt(kv.first, attr::volatile_scope, 1, body);
}
}
rw_stats_.clear();
- Stmt kinit = EvaluateNode::make(
+ Stmt kinit = Evaluate(
Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic));
body = SeqStmt({kinit, body});
- body = AttrStmtNode::make(op->node, op->attr_key, op->value, body);
+ body = AttrStmt(op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
}
Stmt MakeGlobalBarrier() {
} else {
CHECK_EQ(num_work_dim_, thread_extents_.size());
}
- return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
- {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
- CallNode::Intrinsic));
+ return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync,
+ {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_},
+ CallNode::Intrinsic));
}
// data structure.
StorageScope sync_scope_;
} else {
if (auto_unroll) {
if (op->for_type != ForType::Unrolled) {
- return ForNode::make(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api,
- op->body);
+ return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api,
+ op->body);
}
}
return stmt;
int value = GetExtent(op);
// For loop must have a constant integer extent
CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
- if (value == 0) return EvaluateNode::make(0);
+ if (value == 0) return Evaluate(0);
Stmt body = op->body;
Map<Var, PrimExpr> vmap;
Array<Stmt> unrolled;
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<StoreNode>();
if (op->buffer_var.get() == buf_) {
- return StoreNode::make(op->buffer_var, op->value, op->index * var_lanes_ + var_,
- op->predicate);
+ return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate);
} else {
return stmt;
}
} else {
int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
lanes = std::max(lanes, pred.dtype().lanes());
- return StoreNode::make(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
- BroadcastTo(pred, lanes));
+ return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
+ BroadcastTo(pred, lanes));
}
}
// For
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+ return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
}
}
// IfThenElse
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
- return IfThenElseNode::make(condition, then_case, else_case);
+ return IfThenElse(condition, then_case, else_case);
}
}
// LetStmt
// rewrite access to buffer internally.
Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
- return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body);
+ return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
// scalarize the statment
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
+ return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
private:
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->for_type == ForType::Vectorized) {
- return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api,
- op->body);
+ return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body);
} else {
return stmt;
}
void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); }
};
MyVisitor v;
- v.VisitStmt(EvaluateNode::make(z));
+ v.VisitStmt(Evaluate(z));
CHECK_EQ(v.count, 1);
}
MyVisitor v;
auto fmaketest = [&]() {
auto z = x + 1;
- Stmt body = EvaluateNode::make(z);
+ Stmt body = Evaluate(z);
Var buffer("b", DataType::Handle());
- return AllocateNode::make(buffer, DataType::Float(32), {z, z}, const_true(), body);
+ return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body);
};
v(fmaketest());
CHECK_EQ(v.count, 3);
};
auto fmakealloc = [&]() {
auto z = x + 1;
- Stmt body = EvaluateNode::make(z);
+ Stmt body = Evaluate(z);
Var buffer("b", DataType::Handle());
- return AllocateNode::make(buffer, DataType::Float(32), {1, z}, const_true(), body);
+ return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body);
};
auto fmakeif = [&]() {
auto z = x + 1;
- Stmt body = EvaluateNode::make(z);
- return IfThenElseNode::make(x, EvaluateNode::make(0), body);
+ Stmt body = Evaluate(z);
+ return IfThenElse(x, Evaluate(0), body);
};
MyVisitor v;
{
auto body = fmakealloc();
- Stmt body2 = EvaluateNode::make(1);
+ Stmt body2 = Evaluate(1);
Stmt bref = body.as<AllocateNode>()->body;
auto* extentptr = body.as<AllocateNode>()->extents.get();
Array<Stmt> arr{std::move(body), body2, body2};
}
{
- auto body = EvaluateNode::make(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
+ auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern));
auto res = v(std::move(body));
CHECK(res.as<EvaluateNode>()->value.as<CallNode>()->args[0].same_as(x));
}
{
- auto body = fmakealloc();
- Stmt body2 = EvaluateNode::make(1);
+ Stmt body = fmakealloc();
+ Stmt body2 = Evaluate(1);
auto* ref2 = body2.get();
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
{
// Cannot cow because of bref
- auto body = fmakealloc();
- Stmt body2 = EvaluateNode::make(1);
+ Stmt body = fmakealloc();
+ Stmt body2 = Evaluate(1);
auto* extentptr = body.as<AllocateNode>()->extents.get();
// construct a recursive SeqStmt.
body = SeqStmt({body});
}
auto body = fextern(input_placeholders, output_placeholders);
- auto body_stmt = tvm::tir::EvaluateNode::make(body);
+ auto body_stmt = tvm::tir::Evaluate(body);
auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders,
body_stmt);