From eafb2aa13d6cd223629f17d5f6aab5a8d4fce7f5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 11 Jun 2020 11:36:01 -0700 Subject: [PATCH] [TIR][REFACTOR][API-Change] Migrate the tvm/tir/expr.h to construct style. (#5773) This PR migrate tvm/tir/expr.h to the new constructor style that is consistent with the rest of the codebase and changes the affected files accordingly. --- docs/dev/relay_add_op.rst | 2 +- docs/dev/relay_add_pass.rst | 2 +- docs/dev/relay_pass_infra.rst | 2 +- include/tvm/relay/type.h | 3 +- include/tvm/runtime/container.h | 4 +- include/tvm/runtime/object.h | 2 + include/tvm/te/tensor_intrin.h | 34 +- include/tvm/tir/expr.h | 383 +++++-- include/tvm/tir/op.h | 8 +- include/tvm/tir/stmt.h | 2 +- include/tvm/tir/var.h | 51 +- src/arith/canonical_simplify.cc | 32 +- src/arith/const_fold.h | 40 +- src/arith/detect_linear_equation.cc | 4 +- src/arith/int_set.cc | 97 +- src/arith/ir_mutator_with_analyzer.cc | 14 +- src/arith/pattern_match.h | 103 +- src/arith/rewrite_simplify.cc | 36 +- src/autotvm/touch_extractor.cc | 8 +- src/ir/expr.cc | 2 +- src/relay/analysis/type_solver.cc | 4 +- src/relay/analysis/util.cc | 2 +- src/relay/ir/dataflow_matcher.cc | 2 +- src/relay/ir/expr.cc | 2 +- src/relay/op/algorithm/topk.cc | 2 +- src/relay/op/nn/upsampling.cc | 10 +- src/relay/op/tensor/transform.cc | 38 +- src/relay/op/tensor/transform.h | 6 +- src/relay/op/type_relations.cc | 4 +- src/target/intrin_rule.h | 2 +- src/target/llvm/codegen_arm.cc | 13 +- src/target/llvm/codegen_cpu.cc | 5 +- src/target/llvm/codegen_x86_64.cc | 17 +- src/target/llvm/intrin_rule_llvm.cc | 25 +- src/target/llvm/intrin_rule_llvm.h | 4 +- src/target/llvm/intrin_rule_nvptx.cc | 2 +- src/target/llvm/intrin_rule_rocm.cc | 18 +- src/target/source/intrin_rule_cuda.cc | 2 +- src/target/source/intrin_rule_opencl.cc | 2 +- src/target/spirv/intrin_rule_spirv.cc | 2 +- src/te/autodiff/ad_util.cc | 7 +- src/te/autodiff/adjoint.cc | 2 +- src/te/autodiff/jacobian.cc | 60 +- src/te/operation/compute_op.cc | 15 +- src/te/operation/cross_thread_reduction.cc | 14 +- src/te/operation/extern_op.cc | 2 +- src/te/operation/hybrid_op.cc | 6 +- src/te/operation/scan_op.cc | 10 +- src/te/operation/tensor_compute_op.cc | 6 +- src/te/operation/tensorize.cc | 8 +- src/te/schedule/operation_inline.cc | 2 +- src/te/schedule/schedule_dataflow_rewrite.cc | 26 +- src/te/schedule/schedule_lang.cc | 14 +- src/te/schedule/schedule_ops.cc | 3 +- .../schedule_postproc_rewrite_for_tensor_core.cc | 72 +- src/te/tensor.cc | 23 +- src/tir/ir/buffer.cc | 17 +- src/tir/ir/data_layout.cc | 14 +- src/tir/ir/expr.cc | 1120 +++++++++++--------- src/tir/ir/expr_functor.cc | 75 +- src/tir/ir/op.cc | 184 ++-- src/tir/ir/stmt.cc | 2 +- src/tir/ir/stmt_functor.cc | 2 +- src/tir/transforms/arg_binder.cc | 42 +- src/tir/transforms/bound_checker.cc | 22 +- src/tir/transforms/coproc_sync.cc | 21 +- src/tir/transforms/inject_copy_intrin.cc | 2 +- src/tir/transforms/inject_double_buffer.cc | 9 +- src/tir/transforms/inject_virtual_thread.cc | 7 +- src/tir/transforms/ir_util.cc | 5 +- src/tir/transforms/ir_util.h | 20 +- src/tir/transforms/loop_partition.cc | 16 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_intrin.cc | 22 +- src/tir/transforms/lower_thread_allreduce.cc | 45 +- src/tir/transforms/lower_tvm_builtin.cc | 48 +- src/tir/transforms/lower_warp_memory.cc | 15 +- src/tir/transforms/make_packed_api.cc | 35 +- src/tir/transforms/narrow_datatype.cc | 6 +- src/tir/transforms/rewrite_unsafe_select.cc | 4 +- src/tir/transforms/split_host_device.cc | 8 +- src/tir/transforms/storage_flatten.cc | 17 +- src/tir/transforms/storage_rewrite.cc | 13 +- src/tir/transforms/tensorcore_infer_fragment.cc | 4 +- src/tir/transforms/thread_storage_sync.cc | 21 +- src/tir/transforms/vectorize_loop.cc | 85 +- tests/cpp/arith_simplify_test.cc | 2 +- tests/cpp/ir_functor_test.cc | 3 +- tests/cpp/pattern_match_test.cc | 29 +- tests/cpp/utvm_runtime_standalone_test.cc | 4 +- topi/include/topi/contrib/cublas.h | 16 +- topi/include/topi/contrib/rocblas.h | 2 +- topi/include/topi/detail/extern.h | 18 +- topi/include/topi/elemwise.h | 11 +- topi/include/topi/nn.h | 22 +- topi/include/topi/nn/pooling.h | 67 +- topi/include/topi/reduction.h | 13 +- topi/include/topi/transform.h | 7 +- 98 files changed, 1808 insertions(+), 1530 deletions(-) diff --git a/docs/dev/relay_add_op.rst b/docs/dev/relay_add_op.rst index f494cc6..7dca251 100644 --- a/docs/dev/relay_add_op.rst +++ b/docs/dev/relay_add_op.rst @@ -99,7 +99,7 @@ the arguments to the call node, as below. TVM_REGISTER_GLOBAL("relay.op._make.add") .set_body_typed([](Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); }); Including a Python API Hook diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst index 3eb9586..2fc4636 100644 --- a/docs/dev/relay_add_pass.rst +++ b/docs/dev/relay_add_pass.rst @@ -261,7 +261,7 @@ the pass. body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(var, value, body); + return Let(var, value, body); } } } diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index b40b06e..6c2b139 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -349,7 +349,7 @@ registration. auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); auto y = relay::VarNode::make("y", tensor_type); - auto call = relay::CallNode::make(f, tvm::Array{ y }); + auto call = relay::Call(f, tvm::Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); // Create a module for optimization. diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 105f74e..a388c82 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -41,7 +41,8 @@ namespace relay { // namespace update for backward compact // will be removed later. -using Any = tvm::tir::AnyNode; +using AnyNode = tvm::tir::AnyNode; +using Any = tvm::tir::Any; using Kind = TypeKind; using Type = tvm::Type; using TypeNode = tvm::TypeNode; diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 2b3eb92..6753ec7 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -1337,9 +1337,6 @@ class String : public ObjectRef { #endif } - /*! \return the internal StringObj pointer */ - const StringObj* get() const { return operator->(); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); private: @@ -1502,6 +1499,7 @@ class Optional : public ObjectRef { * otherwise return the default_value. */ T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } + /*! \return Whether the container is not nullptr.*/ explicit operator bool() const { return *this != nullptr; } // operator overloadings diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 9387ed4..483ad6b 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -700,6 +700,7 @@ struct ObjectPtrEqual { explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ using ContainerType = ObjectName; /* @@ -713,6 +714,7 @@ struct ObjectPtrEqual { explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ static constexpr bool _type_is_nullable = false; \ using ContainerType = ObjectName; diff --git a/include/tvm/te/tensor_intrin.h b/include/tvm/te/tensor_intrin.h index 252c5f5..7e76efe 100644 --- a/include/tvm/te/tensor_intrin.h +++ b/include/tvm/te/tensor_intrin.h @@ -112,24 +112,6 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const { return static_cast(get()); } -// Internal node container of tensor intrinsic calling. -class TensorIntrinCallNode; - -/*! \brief Tensor intrinsic calling node. */ -class TensorIntrinCall : public ObjectRef { - public: - TensorIntrinCall() {} - explicit TensorIntrinCall(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorIntrinCallNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = TensorIntrinCallNode; -}; - class TensorIntrinCallNode : public Object { public: /*! \brief the tensor intrinsic */ @@ -155,16 +137,22 @@ class TensorIntrinCallNode : public Object { v->Visit("reduce_axis", &reduce_axis); v->Visit("scalar_inputs", &scalar_inputs); } - static TensorIntrinCall make(TensorIntrin intrin, Array tensors, Array regions, - Array reduce_axis, Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); }; -inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { - return static_cast(get()); -} +/*! + * \brief Managed reference to TensorIntrinCallNode + * \sa TensorIntrinCallNode + */ +class TensorIntrinCall : public ObjectRef { + public: + TVM_DLL TensorIntrinCall(TensorIntrin intrin, Array tensors, Array regions, + Array reduce_axis, Array scalar_inputs); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrinCall, ObjectRef, TensorIntrinCallNode); +}; } // namespace te } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index d34165e..423f09e 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -64,14 +64,17 @@ class StringImmNode : public PrimExprNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - TVM_DLL PrimExpr static make(std::string value); - static constexpr const char* _type_key = "StringImm"; TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); }; +/*! + * \brief Managed reference to StringImmNode. + * \sa StringImmNode + */ class StringImm : public PrimExpr { public: + TVM_DLL StringImm(std::string value); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); }; @@ -98,13 +101,21 @@ class CastNode : public PrimExprNode { hash_reduce(value); } - TVM_DLL static PrimExpr make(DataType t, PrimExpr v); - static constexpr const char* _type_key = "Cast"; TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); }; /*! + * \brief Managed reference to CastNode + * \sa CastNode + */ +class Cast : public PrimExpr { + public: + TVM_DLL Cast(DataType dtype, PrimExpr value); + TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); +}; + +/*! * \brief Base template to implement binary ops. * \tparam T The type of the child class. */ @@ -132,17 +143,6 @@ class BinaryOpNode : public PrimExprNode { hash_reduce(b); } - static PrimExpr make(PrimExpr a, PrimExpr b) { - CHECK(a.defined()) << "ValueError: a is undefined\n"; - CHECK(b.defined()) << "ValueError: b is undefined\n"; - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - ObjectPtr node = make_object(); - node->dtype = a.dtype(); - node->a = std::move(a); - node->b = std::move(b); - return PrimExpr(node); - } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; @@ -152,12 +152,32 @@ class AddNode : public BinaryOpNode { static constexpr const char* _type_key = "Add"; }; +/*! + * \brief Managed reference to AddNode + * \sa AddNode + */ +class Add : public PrimExpr { + public: + TVM_DLL Add(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); +}; + /*! \brief a - b */ class SubNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Sub"; }; +/*! + * \brief Managed reference to SubNode + * \sa SubNode + */ +class Sub : public PrimExpr { + public: + TVM_DLL Sub(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); +}; + /*! \brief a * b */ class MulNode : public BinaryOpNode { public: @@ -165,6 +185,16 @@ class MulNode : public BinaryOpNode { }; /*! + * \brief Managed reference to MulNode + * \sa MulNode + */ +class Mul : public PrimExpr { + public: + TVM_DLL Mul(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); +}; + +/*! * \brief a / b in the C semnatics. * \note For integer division, C standard uses trunc div. */ @@ -174,6 +204,16 @@ class DivNode : public BinaryOpNode { }; /*! + * \brief Managed reference to DivNode + * \sa DivNode + */ +class Div : public PrimExpr { + public: + TVM_DLL Div(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); +}; + +/*! * \brief a % b in the C semnatics. * \note For integer division, C standard uses trunc div. */ @@ -182,24 +222,64 @@ class ModNode : public BinaryOpNode { static constexpr const char* _type_key = "Mod"; }; +/*! + * \brief Managed reference to ModNode + * \sa ModNode + */ +class Mod : public PrimExpr { + public: + TVM_DLL Mod(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); +}; + /*! \brief Floor division, floor(a/b) */ class FloorDivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "FloorDiv"; }; +/*! + * \brief Managed reference to FloorDivNode + * \sa FloorDivNode + */ +class FloorDiv : public PrimExpr { + public: + TVM_DLL FloorDiv(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); +}; + /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "FloorMod"; }; +/*! + * \brief Managed reference to FloorModNode + * \sa FloorModNode + */ +class FloorMod : public PrimExpr { + public: + TVM_DLL FloorMod(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); +}; + /*! \brief min(a, b) */ class MinNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Min"; }; +/*! + * \brief Managed reference to MinNode + * \sa MinNode + */ +class Min : public PrimExpr { + public: + TVM_DLL Min(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); +}; + /*! \brief max(a, b) */ class MaxNode : public BinaryOpNode { public: @@ -207,6 +287,16 @@ class MaxNode : public BinaryOpNode { }; /*! + * \brief Managed reference to MaxNode + * \sa MaxNode + */ +class Max : public PrimExpr { + public: + TVM_DLL Max(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); +}; + +/*! * \brief Base template to implement comparison ops. * \tparam T The type of the child class. */ @@ -234,17 +324,6 @@ class CmpOpNode : public PrimExprNode { hash_reduce(b); } - static PrimExpr make(PrimExpr a, PrimExpr b) { - CHECK(a.defined()) << "ValueError: a is undefined\n"; - CHECK(b.defined()) << "ValueError: b is undefined\n"; - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); - node->a = std::move(a); - node->b = std::move(b); - return PrimExpr(node); - } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; @@ -254,36 +333,96 @@ class EQNode : public CmpOpNode { static constexpr const char* _type_key = "EQ"; }; +/*! + * \brief Managed reference to EQNode + * \sa EQNode + */ +class EQ : public PrimExpr { + public: + TVM_DLL EQ(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); +}; + /*! \brief a != b */ class NENode : public CmpOpNode { public: static constexpr const char* _type_key = "NE"; }; +/*! + * \brief Managed reference to NENode + * \sa NENode + */ +class NE : public PrimExpr { + public: + TVM_DLL NE(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); +}; + /*! \brief a < b */ class LTNode : public CmpOpNode { public: static constexpr const char* _type_key = "LT"; }; +/*! + * \brief Managed reference to LTNode + * \sa LTNode + */ +class LT : public PrimExpr { + public: + TVM_DLL LT(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); +}; + /*! \brief a <= b */ struct LENode : public CmpOpNode { public: static constexpr const char* _type_key = "LE"; }; +/*! + * \brief Managed reference to LENode + * \sa LENode + */ +class LE : public PrimExpr { + public: + TVM_DLL LE(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); +}; + /*! \brief a > b */ class GTNode : public CmpOpNode { public: static constexpr const char* _type_key = "GT"; }; +/*! + * \brief Managed reference to GTNode + * \sa GTNode + */ +class GT : public PrimExpr { + public: + TVM_DLL GT(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); +}; + /*! \brief a >= b */ class GENode : public CmpOpNode { public: static constexpr const char* _type_key = "GE"; }; +/*! + * \brief Managed reference to GENode + * \sa GENode + */ +class GE : public PrimExpr { + public: + TVM_DLL GE(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); +}; + /*! \brief a && b */ class AndNode : public PrimExprNode { public: @@ -308,12 +447,20 @@ class AndNode : public PrimExprNode { hash_reduce(b); } - TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); - static constexpr const char* _type_key = "And"; TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); }; +/*! + * \brief Managed reference to AndNode + * \sa AndNode + */ +class And : public PrimExpr { + public: + TVM_DLL And(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); +}; + /*! \brief a || b */ class OrNode : public PrimExprNode { public: @@ -338,12 +485,20 @@ class OrNode : public PrimExprNode { hash_reduce(b); } - TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); - static constexpr const char* _type_key = "Or"; TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); }; +/*! + * \brief Managed reference to OrNode + * \sa OrNode + */ +class Or : public PrimExpr { + public: + TVM_DLL Or(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); +}; + /*! \brief !a */ class NotNode : public PrimExprNode { public: @@ -364,13 +519,21 @@ class NotNode : public PrimExprNode { hash_reduce(a); } - TVM_DLL static PrimExpr make(PrimExpr a); - static constexpr const char* _type_key = "Not"; TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); }; /*! + * \brief Managed reference to NotNode + * \sa NotNode + */ +class Not : public PrimExpr { + public: + TVM_DLL Not(PrimExpr a); + TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); +}; + +/*! * \brief return true_value if condition is true, otherwise return false_value. * \note Both true_value and false_value could be evaluated * regardless of the condition value. @@ -405,13 +568,22 @@ class SelectNode : public PrimExprNode { hash_reduce(false_value); } - TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); - static constexpr const char* _type_key = "Select"; TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); }; /*! + * \brief Managed reference to SelectNode + * \sa SelectNode + */ +class Select : public PrimExpr { + public: + TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); + + TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); +}; + +/*! * \brief Load value from the high dimension buffer. * * \code @@ -550,13 +722,21 @@ class LoadNode : public PrimExprNode { hash_reduce(predicate); } - TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); - static constexpr const char* _type_key = "Load"; TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode); }; /*! + * \brief Managed reference to LoadNode + * \sa LoadNode + */ +class Load : public PrimExpr { + public: + TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); + TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode); +}; + +/*! * \brief Construct a vector with lanes elements * where its i-th element equals base + i * stride. * This is useful to construct a index for a continuous vector load. @@ -593,12 +773,20 @@ class RampNode : public PrimExprNode { hash_reduce(lanes); } - TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes); - static constexpr const char* _type_key = "Ramp"; TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); }; +/*! + * \brief Managed reference to RampNode + * \sa RampNode + */ +class Ramp : public PrimExpr { + public: + TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes); + TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); +}; + /*! \brief Create a vector where all the elements are value. */ class BroadcastNode : public PrimExprNode { public: @@ -623,13 +811,21 @@ class BroadcastNode : public PrimExprNode { hash_reduce(lanes); } - TVM_DLL static PrimExpr make(PrimExpr value, int lanes); - static constexpr const char* _type_key = "Broadcast"; TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); }; /*! + * \brief Managed reference to BroadcastNode + * \sa BroadcastNode + */ +class Broadcast : public PrimExpr { + public: + TVM_DLL Broadcast(PrimExpr value, int lanes); + TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); +}; + +/*! * \brief Let binding. Bind var to value then evaluate body. */ class LetNode : public PrimExprNode { @@ -660,12 +856,20 @@ class LetNode : public PrimExprNode { hash_reduce(body); } - TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body); - static constexpr const char* _type_key = "Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); }; +/*! + * \brief Managed reference to LetNode + * \sa LetNode + */ +class Let : public PrimExpr { + public: + TVM_DLL Let(Var var, PrimExpr value, PrimExpr body); + TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); +}; + // Call node, represent a function call or a multi-dimensional array load. // // TODO(tvm-team): @@ -744,9 +948,6 @@ class CallNode : public PrimExprNode { hash_reduce(call_type); } - TVM_DLL static PrimExpr make(DataType dtype, std::string name, Array args, - CallType call_type); - /*! \return Whether call node is pure. */ bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); } @@ -784,6 +985,18 @@ class CallNode : public PrimExprNode { }; /*! + * \brief Managed reference to CallNode + * \sa CallNode + */ +class Call : public PrimExpr { + public: + using CallType = CallNode::CallType; + + TVM_DLL Call(DataType dtype, std::string name, Array args, CallType call_type); + TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); +}; + +/*! * \brief Shuffle instruction. * vec = concat(vectors) * result = (vec[indices[0]], vec[indices[1]] ...) @@ -811,35 +1024,24 @@ class ShuffleNode : public PrimExprNode { hash_reduce(indices); } - TVM_DLL static PrimExpr make(Array vectors, Array indices); - TVM_DLL static PrimExpr make_concat(Array vectors); - TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index); - static constexpr const char* _type_key = "Shuffle"; TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); }; -// Reduce operator -class CommReducerNode; - -class CommReducer : public ObjectRef { +/*! + * \brief Managed reference to ShuffleNode + * \sa ShuffleNode + */ +class Shuffle : public PrimExpr { public: - CommReducer() {} - explicit CommReducer(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const CommReducerNode* get() const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const CommReducerNode* operator->() const; - /*! \brief type indicate the container type */ - using ContainerType = CommReducerNode; + TVM_DLL Shuffle(Array vectors, Array indices); + TVM_DLL static PrimExpr Concat(Array vectors); + TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index); + + TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); }; +// Reduce operator /*! * \brief A commutative reducer node to represent a commutative * binary operator with identity element @@ -860,9 +1062,6 @@ class CommReducerNode : public Object { Array identity_element; /*! \brief Function call operator to combine a and b */ Array operator()(Array a, Array b) const; - /*! \brief construct CommReducer from args, result and identity_element */ - TVM_DLL static CommReducer make(Array lhs, Array rhs, Array result, - Array identity_element); void VisitAttrs(AttrVisitor* v) { v->Visit("lhs", &lhs); @@ -889,10 +1088,17 @@ class CommReducerNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); }; -inline const CommReducerNode* CommReducer::get() const { - return static_cast(data_.get()); -} -inline const CommReducerNode* CommReducer::operator->() const { return get(); } +/*! + * \brief Managed reference to CommReducerNode + * \sa CommReducerNode + */ +class CommReducer : public ObjectRef { + public: + TVM_DLL CommReducer(Array lhs, Array rhs, Array result, + Array identity_element); + + TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode); +}; /*! \brief Reduction operator operator */ class ReduceNode : public PrimExprNode { @@ -911,10 +1117,6 @@ class ReduceNode : public PrimExprNode { /*! \brief the index of this reduce node */ int value_index; - /*! \brief construct expr from op and rdom */ - TVM_DLL static PrimExpr make(CommReducer combiner, Array src, Array rdom, - PrimExpr condition, int value_index); - void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("combiner", &combiner); @@ -944,6 +1146,18 @@ class ReduceNode : public PrimExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); }; +/*! + * \brief Managed reference to ReduceNode + * \sa ReduceNode + */ +class Reduce : public PrimExpr { + public: + TVM_DLL Reduce(CommReducer combiner, Array src, Array rdom, PrimExpr condition, + int value_index); + + TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); +}; + /*! \brief Any shape. */ class AnyNode : public PrimExprNode { public: @@ -956,12 +1170,21 @@ class AnyNode : public PrimExprNode { /*! \brief Convert to var. */ Var ToVar() const { return Var("any_dim", DataType::Int(32)); } - TVM_DLL static PrimExpr make(); - static constexpr const char* _type_key = "Any"; TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; +/*! + * \brief Managed reference to AnyNode + * \sa AnyNode + */ +class Any : public PrimExpr { + public: + TVM_DLL Any(); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode); +}; + /* * \brief Template function to convert Map to unordered_map * Sometimes useful for API gluing when internal uses unordered_map diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 8d2add2..71e9ac4 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -552,9 +552,9 @@ TVM_DLL PrimExpr trunc(PrimExpr x); TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x) { \ - return tir::CallNode::make(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x) { \ + return tir::Call(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ } TVM_DECLARE_INTRIN_UNARY(exp); @@ -768,7 +768,7 @@ inline PrimExpr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { - return tir::BroadcastNode::make(MakeConstScalar(t.element_of(), value), t.lanes()); + return tir::Broadcast(MakeConstScalar(t.element_of(), value), t.lanes()); } } diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 118ec0f..d4c813e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -904,7 +904,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \return Expr a expression with dtype. */ inline PrimExpr TypeAnnotation(DataType dtype) { - return tir::CallNode::make(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); + return tir::Call(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); } // overload printing of for type. diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 4db462d..363bf6b 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -83,7 +83,7 @@ class VarNode : public PrimExprNode { TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); }; -/*! \brief a named variable in TVM */ +/*! \brief a named variable in TIR */ class Var : public PrimExpr { public: explicit Var(ObjectPtr n) : PrimExpr(n) {} @@ -105,6 +105,7 @@ class Var : public PrimExpr { * \return the new Var copy */ TVM_DLL Var copy_with_suffix(const String& suffix) const; + /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. @@ -153,9 +154,6 @@ class SizeVar : public Var { using ContainerType = SizeVarNode; }; -/*! \brief container class of iteration variable. */ -class IterVarNode; - using Region = Array; /*! @@ -228,29 +226,6 @@ enum IterVarType : int { kTensorized = 8 }; -/*! - * \brief Iteration Variable, - * represents an iteration over an integer interval. - */ -class IterVar : public ObjectRef { - public: - // construct a new iter var without a domain - IterVar() {} - // construct from shared ptr. - explicit IterVar(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const IterVarNode* operator->() const; - /*! - * \return the corresponding var in the IterVar. - */ - inline operator PrimExpr() const; - /*! \brief specify container node */ - using ContainerType = IterVarNode; -}; - using Domain = Array; /*! @@ -293,20 +268,28 @@ class IterVarNode : public Object { hash_reduce(thread_tag); } - TVM_DLL static IterVar make(Range dom, Var var, IterVarType iter_type, - std::string thread_tag = ""); - static constexpr const char* _type_key = "IterVar"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); }; -// inline implementations -inline const IterVarNode* IterVar::operator->() const { - return static_cast(data_.get()); -} +/*! + * \brief Iteration Variable, + * represents an iteration over an integer interval. + */ +class IterVar : public ObjectRef { + public: + TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); + /*! + * \return the corresponding var in the IterVar. + */ + inline operator PrimExpr() const; + + TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode); +}; +// inline implementations inline IterVar::operator PrimExpr() const { return (*this)->var; } inline const char* IterVarType2String(IterVarType t) { diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 2738707..b81565f 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -551,7 +551,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -576,7 +576,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -601,7 +601,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // x * c @@ -626,7 +626,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return MulNode::make(a, b); + return Mul(a, b); } } @@ -704,7 +704,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold
(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -750,7 +750,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return DivNode::make(a, b); + return Div(a, b); } } @@ -762,7 +762,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -804,7 +804,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorDivNode::make(a, b); + return FloorDiv(a, b); } } @@ -865,7 +865,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -920,7 +920,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return ModNode::make(a, b); + return Mod(a, b); } } @@ -933,7 +933,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -978,7 +978,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorModNode::make(a, b); + return FloorMod(a, b); } } @@ -1045,8 +1045,8 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) } } - CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return ReduceNode::make(new_combiner, new_source, op->axis, op->condition, new_value_index); + CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity); + return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index); } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { @@ -1060,8 +1060,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { // assumption we would have to perform a single iteration of the loop, i.e. use // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. - return this->VisitExpr(SelectNode::make(op->condition, op->source[op->value_index], - op->combiner->identity_element[op->value_index])); + return this->VisitExpr(Select(op->condition, op->source[op->value_index], + op->combiner->identity_element[op->value_index])); } // combiner simplification. ret = SimplifyReduceCombiner(op); diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index ad6570e..876d336 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -46,9 +46,7 @@ namespace arith { * \return nullptr if constant fold fails, otherwise return folded result. */ template -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - return PrimExpr(); -} +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b); /*! * \brief Try to run unary compute with constant folding. @@ -94,7 +92,7 @@ inline bool IsIndexType(const DataType& type) { // specialization of constant folders. template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value + pb->value); @@ -108,7 +106,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value - pb->value); @@ -120,7 +118,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, pa->value * pb->value); @@ -146,7 +144,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -175,7 +173,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -194,7 +192,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -221,7 +219,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -240,7 +238,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); @@ -251,7 +249,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); @@ -262,7 +260,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); @@ -271,7 +269,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); @@ -280,7 +278,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); @@ -289,7 +287,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); @@ -298,7 +296,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); @@ -307,7 +305,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); @@ -316,7 +314,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; @@ -327,7 +325,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; @@ -338,7 +336,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline PrimExpr TryConstFold(PrimExpr a) { +inline PrimExpr TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { return IntImm(DataType::UInt(1), !(pa->value)); diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 2bc7209..f0634fe 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -213,7 +213,7 @@ bool DetectClipBound(const PrimExpr& cond, if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift if (p.min_value.defined()) { - p.min_value = tir::MaxNode::make(p.min_value, -ret.base); + p.min_value = max(p.min_value, -ret.base); } else { p.min_value = -ret.base; } @@ -222,7 +222,7 @@ bool DetectClipBound(const PrimExpr& cond, if (is_const_int(ret.coeff, -1)) { // -var + shift >=0 -> var <= shift if (p.max_value.defined()) { - p.max_value = tir::MinNode::make(p.max_value, ret.base); + p.max_value = min(p.max_value, ret.base); } else { p.max_value = ret.base; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 7462808..b043b35 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -87,15 +87,15 @@ struct is_logical_op { static const bool value = true; \ }; -TVM_DECLARE_LOGICAL_OP(AndNode); -TVM_DECLARE_LOGICAL_OP(OrNode); -TVM_DECLARE_LOGICAL_OP(EQNode); -TVM_DECLARE_LOGICAL_OP(NENode); -TVM_DECLARE_LOGICAL_OP(GENode); -TVM_DECLARE_LOGICAL_OP(GTNode); -TVM_DECLARE_LOGICAL_OP(LENode); -TVM_DECLARE_LOGICAL_OP(LTNode); -TVM_DECLARE_LOGICAL_OP(NotNode); +TVM_DECLARE_LOGICAL_OP(And); +TVM_DECLARE_LOGICAL_OP(Or); +TVM_DECLARE_LOGICAL_OP(EQ); +TVM_DECLARE_LOGICAL_OP(NE); +TVM_DECLARE_LOGICAL_OP(GE); +TVM_DECLARE_LOGICAL_OP(GT); +TVM_DECLARE_LOGICAL_OP(LE); +TVM_DECLARE_LOGICAL_OP(LT); +TVM_DECLARE_LOGICAL_OP(Not); /*! * \brief Combine two interval set under arithmetic operations. @@ -105,7 +105,7 @@ template inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr res = TryConstFold(a->min_value, b->min_value); - if (!res.defined()) res = Op::make(a->min_value, b->min_value); + if (!res.defined()) res = Op(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { @@ -119,7 +119,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -133,7 +133,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, Inter } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -147,7 +147,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, Inter } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -168,11 +168,11 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Inte PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::SelectNode; + using tir::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; - return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); + return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Mul"; @@ -180,7 +180,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Inte } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -201,11 +201,11 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Inte PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::SelectNode; + using tir::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; - return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); + return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; @@ -213,7 +213,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Inte } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -241,7 +241,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Inte } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -262,11 +262,11 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::SelectNode; + using tir::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); - return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); + return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; @@ -274,7 +274,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -308,7 +308,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -318,7 +318,7 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Inte } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -377,39 +377,39 @@ class IntervalSetEvaluator : public ExprFunctor { } } - IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_
(op); } - IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } + IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); @@ -419,11 +419,11 @@ class IntervalSetEvaluator : public ExprFunctor { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; if (vstride > 0) { - return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); + return Combine(analyzer_, base, + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { - return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); + return Combine(analyzer_, base, + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); @@ -452,14 +452,15 @@ class IntervalSetEvaluator : public ExprFunctor { return set->min_value.same_as(value) && set->max_value.same_as(value); } - template + template inline IntervalSet VisitBinaryExpr_(const T* op) { + static_assert(std::is_same::value, "constraint"); IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(analyzer_, a, b); + return Combine(analyzer_, a, b); } // recursive depth diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index e6f37f4..f4bb9c2 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -68,8 +68,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { - With ctx(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(real_condition))); + With ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition))); else_case = this->VisitStmt(op->else_case); } if (is_one(real_condition)) return then_case; @@ -131,8 +130,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { true_value = this->VisitExpr(op->args[1]); } { - With constraint(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(cond))); + With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); false_value = this->VisitExpr(op->args[2]); } if (is_zero(cond)) { @@ -145,7 +143,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); + return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); @@ -162,7 +160,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } @@ -174,7 +172,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { true_value = VisitExpr(op->true_value); } { - With constraint(analyzer_, analyzer_->rewrite_simplify(NotNode::make(cond))); + With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { @@ -188,7 +186,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { false_value.same_as(op->false_value)) { return GetRef(op); } else { - return SelectNode::make(cond, true_value, false_value); + return Select(cond, true_value, false_value); } } diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 2a02303..ff01941 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -223,12 +223,12 @@ class PConst : public Pattern> { /*! * \brief Pattern binary expression. - * \tparam NodeType The AST node type. + * \tparam OpType The AST noderef type. * \tparam TA The pattern type of the first operand. * \tparam TB The pattern type of the second operand. */ -template -class PBinaryExpr : public Pattern> { +template +class PBinaryExpr : public Pattern> { public: PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} @@ -238,6 +238,7 @@ class PBinaryExpr : public Pattern> { } bool Match_(const ObjectRef& node) const { + using NodeType = typename OpType::ContainerType; if (const NodeType* ptr = node.as()) { if (!a_.Match_(ptr->a)) return false; if (!b_.Match_(ptr->b)) return false; @@ -250,9 +251,9 @@ class PBinaryExpr : public Pattern> { PrimExpr Eval() const { PrimExpr lhs = a_.Eval(); PrimExpr rhs = b_.Eval(); - PrimExpr ret = TryConstFold(lhs, rhs); + PrimExpr ret = TryConstFold(lhs, rhs); if (ret.defined()) return ret; - return NodeType::make(lhs, rhs); + return OpType(lhs, rhs); } private: @@ -304,30 +305,30 @@ class PConstWithTypeLike : public Pattern> { #define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) // raise ambiguity error for operator overload of / and % -TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a)); -TVM_PATTERN_BINARY_OP_EX(operator%, tir::ModNode, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a)); // arithmetic expressions -TVM_PATTERN_BINARY_OP(operator+, tir::AddNode); -TVM_PATTERN_BINARY_OP(operator-, tir::SubNode); -TVM_PATTERN_BINARY_OP(operator*, tir::MulNode); -TVM_PATTERN_BINARY_OP(min, tir::MinNode); -TVM_PATTERN_BINARY_OP(max, tir::MaxNode); -TVM_PATTERN_BINARY_OP(div, tir::DivNode); -TVM_PATTERN_BINARY_OP(truncdiv, tir::DivNode); -TVM_PATTERN_BINARY_OP(truncmod, tir::ModNode); -TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDivNode); -TVM_PATTERN_BINARY_OP(floormod, tir::FloorModNode); +TVM_PATTERN_BINARY_OP(operator+, tir::Add); +TVM_PATTERN_BINARY_OP(operator-, tir::Sub); +TVM_PATTERN_BINARY_OP(operator*, tir::Mul); +TVM_PATTERN_BINARY_OP(min, tir::Min); +TVM_PATTERN_BINARY_OP(max, tir::Max); +TVM_PATTERN_BINARY_OP(div, tir::Div); +TVM_PATTERN_BINARY_OP(truncdiv, tir::Div); +TVM_PATTERN_BINARY_OP(truncmod, tir::Mod); +TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv); +TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod); // logical expressions -TVM_PATTERN_BINARY_OP(operator>, tir::GTNode); -TVM_PATTERN_BINARY_OP(operator>=, tir::GENode); -TVM_PATTERN_BINARY_OP(operator<, tir::LTNode); -TVM_PATTERN_BINARY_OP(operator<=, tir::LENode); -TVM_PATTERN_BINARY_OP(operator==, tir::EQNode); -TVM_PATTERN_BINARY_OP(operator!=, tir::NENode); -TVM_PATTERN_BINARY_OP(operator&&, tir::AndNode); -TVM_PATTERN_BINARY_OP(operator||, tir::OrNode); +TVM_PATTERN_BINARY_OP(operator>, tir::GT); +TVM_PATTERN_BINARY_OP(operator>=, tir::GE); +TVM_PATTERN_BINARY_OP(operator<, tir::LT); +TVM_PATTERN_BINARY_OP(operator<=, tir::LE); +TVM_PATTERN_BINARY_OP(operator==, tir::EQ); +TVM_PATTERN_BINARY_OP(operator!=, tir::NE); +TVM_PATTERN_BINARY_OP(operator&&, tir::And); +TVM_PATTERN_BINARY_OP(operator||, tir::Or); /*! * \brief Pattern not expression. @@ -349,7 +350,7 @@ class PNotExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::NotNode::make(value_.Eval()); } + PrimExpr Eval() const { return tir::Not(value_.Eval()); } private: typename TA::Nested value_; @@ -391,7 +392,7 @@ class PSelectExpr : public Pattern> { } PrimExpr Eval() const { - return tir::SelectNode::make(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); + return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } private: @@ -446,7 +447,7 @@ class PCastExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::CastNode::make(dtype_.Eval(), value_.Eval()); } + PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); } private: typename DType::Nested dtype_; @@ -498,7 +499,7 @@ class PRampExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); } + PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: typename TBase::Nested base_; @@ -558,7 +559,7 @@ class PBroadcastExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); } + PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); } private: typename TA::Nested value_; @@ -674,16 +675,16 @@ class PCallExpr : public Pattern> { }; // arithemetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ - return PCallExpr(a.derived(), b.derived()); \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ + return PCallExpr(a.derived(), b.derived()); \ } TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); @@ -693,16 +694,16 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); @@ -710,7 +711,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return tir::CallNode::make(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); + return tir::Call(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; }; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4149b15..ce3f2a6 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -118,7 +118,7 @@ void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -225,7 +225,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -408,7 +408,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -441,7 +441,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold
(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -615,7 +615,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -696,7 +696,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -813,7 +813,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -878,7 +878,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1050,7 +1050,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1213,7 +1213,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1243,11 +1243,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { - return this->VisitExpr(NotNode::make(op->a == op->b)); + return this->VisitExpr(Not(op->a == op->b)); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { - return this->VisitExpr(NotNode::make(op->b < op->a)); + return this->VisitExpr(Not(op->b < op->a)); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { @@ -1255,13 +1255,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { - return this->VisitExpr(NotNode::make(op->a < op->b)); + return this->VisitExpr(Not(op->a < op->b)); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1392,7 +1392,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a); + PrimExpr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y; @@ -1416,7 +1416,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1455,7 +1455,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1559,7 +1559,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index 02dae64..91e2ee1 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -250,10 +250,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto var : vars) { Array > feature_row; ItervarFeature& fea = touch_analyzer.itervar_map[var]; - feature_row.push_back(Array{tvm::tir::StringImmNode::make("_itervar_"), var}); + feature_row.push_back(Array{tvm::tir::StringImm("_itervar_"), var}); Array attr{ - tvm::tir::StringImmNode::make("_attr_"), + tvm::tir::StringImm("_attr_"), FloatImm(DataType::Float(32), trans(fea.length)), IntImm(DataType::Int(32), fea.nest_level), FloatImm(DataType::Float(32), trans(fea.topdown_product)), @@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > // arithmetic feature_row.push_back(Array{ - tvm::tir::StringImmNode::make("_arith_"), + tvm::tir::StringImm("_arith_"), FloatImm(DataType::Float(32), trans(fea.add_ct)), FloatImm(DataType::Float(32), trans(fea.mul_ct)), FloatImm(DataType::Float(32), trans(fea.div_ct)), @@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto k : bufs) { TouchPattern& v = fea.touch_feature[k]; feature_row.push_back(Array{ - tvm::tir::StringImmNode::make(k), + tvm::tir::StringImm(k), FloatImm(DataType::Float(32), trans(v.stride)), FloatImm(DataType::Float(32), trans(v.mod)), FloatImm(DataType::Float(32), trans(v.count)), diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 97e285c..289477e 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -47,7 +47,7 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { return GetRef(ptr)(); } if (auto* ptr = ref.as()) { - return tir::StringImmNode::make(GetRef(ptr)); + return tir::StringImm(GetRef(ptr)); } CHECK(ObjectTypeChecker::Check(ref.get())) << "Expect type " << ObjectTypeChecker::TypeName() << " but get " diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 05e231a..a192002 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -175,8 +175,8 @@ class TypeSolver::Unifier : public TypeFunctor { if (ulhs.same_as(urhs)) { return ulhs; } - if (ulhs.as() || urhs.as()) { - return Any::make(); + if (ulhs.as() || urhs.as()) { + return Any(); } auto left_index0 = ulhs.as(); diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 0885a35..b681b90 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -420,7 +420,7 @@ struct IsDynamicVisitor : public TypeVisitor { bool is_dyn{false}; void VisitType_(const TensorTypeNode* tt) { for (auto dim : tt->shape) { - if (dim.as()) { + if (dim.as()) { is_dyn = true; break; } diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index d01a1e7..c9bf11e 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -156,7 +156,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this // and replace the whole thing with a Visitor-based approach ReflectionVTable* reflection = ReflectionVTable::Global(); - auto attrs_node = const_cast(op->attrs.get()); + auto attrs_node = const_cast(op->attrs.get()); auto attr_names = reflection->ListAttrNames(attrs_node); for (auto kv : attributes) { std::string attr = kv.first; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index c2f3aef..1d9e3ce 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -246,7 +246,7 @@ TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) return temp->Realize(); }); -TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any::make(); }); +TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any(); }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 3db8eee..b02fe86 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -56,7 +56,7 @@ bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, out_shape.push_back(ck); } } else { - out_shape.push_back(Any::make()); + out_shape.push_back(Any()); } } auto values_ty = TensorType(out_shape, data->dtype); diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 3228b72..cb20881 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -81,8 +81,8 @@ bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); - oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); + oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); + oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); // assign output type reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); @@ -149,9 +149,9 @@ bool UpSampling3DRel(const Array& types, int num_inputs, const Attrs& attr << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); - oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); - oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); + oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); + oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); + oshape.Set(4, tir::Cast(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); // assign output type reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 6544468..9d87610 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -481,7 +481,7 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, // Doesn't support dynamic output rank for (int i = 0; i < newshape->shape[0].as()->value; i++) { - oshape.push_back(Any::make()); + oshape.push_back(Any()); } reporter->Assign(types[2], TensorType(oshape, data->dtype)); @@ -526,8 +526,8 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, used_input_dims.insert(src_idx); IndexExpr d2 = data_shape[src_idx++]; used_output_dims.insert(oshape.size()); - if (d1.as() || d2.as()) { - oshape.push_back(Any::make()); + if (d1.as() || d2.as()) { + oshape.push_back(Any()); } else { oshape.push_back(d1 * d2); } @@ -543,8 +543,8 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, if (d1->value == -1) { CHECK(d2->value != -1) << "Split dims cannot both be -1."; used_output_dims.insert(oshape.size()); - if (d0.as()) { - oshape.push_back(Any::make()); + if (d0.as()) { + oshape.push_back(Any()); } else { oshape.push_back(indexdiv(d0, d2)); } @@ -555,8 +555,8 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.push_back(d1); used_output_dims.insert(oshape.size()); if (d2->value == -1) { - if (d0.as()) { - oshape.push_back(Any::make()); + if (d0.as()) { + oshape.push_back(Any()); } else { oshape.push_back(indexdiv(d0, d1)); } @@ -575,19 +575,19 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, if (used_input_dims.count(i) != 0) { continue; } - if (data_shape[i].as()) { - infer_dim = Any::make(); + if (data_shape[i].as()) { + infer_dim = Any(); break; } infer_dim *= data_shape[i]; } - if (!infer_dim.as()) { + if (!infer_dim.as()) { for (size_t i = 0; i < oshape.size(); ++i) { if (used_output_dims.count(i) != 0) { continue; } - if (oshape[i].as()) { - infer_dim = Any::make(); + if (oshape[i].as()) { + infer_dim = Any(); break; } infer_dim = indexdiv(infer_dim, oshape[i]); @@ -759,7 +759,7 @@ bool ArgWhereRel(const Array& types, int num_inputs, const Attrs& attrs, const auto& input_shape = tt->shape; const auto& input_rank = input_shape.size(); std::vector result_shape; - result_shape.push_back(Any::make()); + result_shape.push_back(Any()); result_shape.push_back(IntImm(DataType::Int(32), input_rank)); reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32))); return true; @@ -960,7 +960,7 @@ bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, } } else { for (int i = 0; i < shape_shape->value; ++i) { - oshape.push_back(Any::make()); + oshape.push_back(Any()); } } reporter->Assign(types[2], TensorType(oshape, out_dtype)); @@ -1016,7 +1016,7 @@ bool InitOpRel(const Array& types, int num_inputs, const Attrs& attrs, } } else { for (int i = 0; i < shape_shape->value; ++i) { - oshape.push_back(Any::make()); + oshape.push_back(Any()); } } reporter->Assign(types[1], TensorType(oshape, out_dtype)); @@ -1136,7 +1136,7 @@ bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); return true; } else { - reporter->Assign(types[3], TensorType({Any::make()}, attrs->dtype)); + reporter->Assign(types[3], TensorType({Any()}, attrs->dtype)); return true; } } @@ -1331,7 +1331,7 @@ bool TileRel(const Array& types, int num_inputs, const Attrs& attrs, for (size_t i = 0; i < tndim; ++i) { // Save Any if it is dynamic shape if (!data_shape[i].as()) { - oshape.emplace_back(Any::make()); + oshape.emplace_back(Any()); } else { oshape.emplace_back(data_shape[i] * reps_shape[i]); } @@ -1641,7 +1641,7 @@ bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs } } else { for (int i = 0; i < shape_shape->value; ++i) { - oshape.push_back(Any::make()); + oshape.push_back(Any()); } } reporter->Assign(types[2], TensorType(oshape, out_dtype)); @@ -1817,7 +1817,7 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr } } else { for (int64_t i = 0; i < num_axis; ++i) { - oshape[i] = Any::make(); + oshape[i] = Any(); } } diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 1f30b68..7149417 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -99,16 +99,16 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs std::vector non_any; for (int j = 0; j < data_length; ++j) { const auto& e = Downcast(tensor_tuple->fields[j]); - if (!e->shape[i].as()) { + if (!e->shape[i].as()) { non_any.push_back(e->shape[i]); // accumulate axis dimension - if (j > 0 && i == axis && !oshape[i].as()) { + if (j > 0 && i == axis && !oshape[i].as()) { oshape[i] += e->shape[i]; } } } int non_any_size = static_cast(non_any.size()); - if (non_any_size != data_length) oshape[i] = Any::make(); + if (non_any_size != data_length) oshape[i] = Any(); if (i != axis) { for (int k = 1; k < non_any_size; k++) { if (reporter->AssertEQ(non_any[0], non_any[k])) continue; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 677683c..46143d1 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -75,10 +75,10 @@ Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType outp oshape.push_back(s2); } else if (EqualConstInt(s2, 1)) { oshape.push_back(s1); - } else if (s1.as()) { + } else if (s1.as()) { // s1 == 1 || s1 == s2 oshape.push_back(s2); - } else if (s2.as()) { + } else if (s2.as()) { // s2 == 1 || s2 == s1 oshape.push_back(s1); } else if (EqualCheck(s1, s2)) { diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 8a5a440..5a23e83 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -60,7 +60,7 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); std::string name = T()(call->dtype, call->name); if (name.length() != 0) { - *rv = CallNode::make(call->dtype, name, call->args, CallNode::PureExtern); + *rv = Call(call->dtype, name, call->args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index ba45115..991d473 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -70,7 +70,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -94,16 +94,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = - tir::CallNode::make(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = tir::Call(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = - tir::CallNode::make(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = tir::Call(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -113,8 +111,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = - tir::CallNode::make(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = tir::Call(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -124,7 +121,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index c70a1ab..05c2ef2 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -918,8 +919,8 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { op->loop_var, op->body); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; - PrimExpr begin = MinNode::make(task_id * step, op->extent); - PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); + PrimExpr begin = min(task_id * step, op->extent); + PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index d0038b8..edffda2 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -89,10 +89,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::CallNode::make(DataType::Int(16, from.lanes()), - tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic)), - MakeValue(tir::BroadcastNode::make(FloatImm(DataType::Float(32), 0), from.lanes())), + MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {op->value}, tir::CallNode::PureIntrinsic)), + MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); @@ -103,11 +102,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); if (from.lanes() >= 8 && has_f16c) { - return CallVectorIntrin(::llvm::Intrinsic::x86_vcvtph2ps_256, 8, - DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, - {op->value}, tir::CallNode::PureIntrinsic))}); + return CallVectorIntrin( + ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, + DTypeToLLVMType(DataType::Float(32, from.lanes())), + {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {op->value}, tir::CallNode::PureIntrinsic))}); } #endif } diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index d0bef46..8804b1e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -48,8 +48,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") CHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = - tir::CallNode::make(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + PrimExpr ret = tir::Call(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); *rv = ret; }); @@ -98,14 +97,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_two = make_const(x.dtype(), -2); - PrimExpr exp_neg2x = - tir::CallNode::make(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = - tir::CallNode::make(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_neg2x = tir::Call(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_pos2x = tir::Call(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = tir::SelectNode::make(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); + *rv = tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") @@ -119,8 +116,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr sin_x = tir::CallNode::make(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); - PrimExpr cos_x = tir::CallNode::make(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); + PrimExpr sin_x = tir::Call(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); + PrimExpr cos_x = tir::Call(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); PrimExpr tan_x = sin_x / cos_x; *rv = tan_x; }); @@ -138,9 +135,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = - tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); PrimExpr ret = (exp_posx + exp_negx) / two; *rv = ret; }); @@ -158,9 +154,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = - tir::CallNode::make(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); PrimExpr ret = (exp_posx - exp_negx) / two; *rv = ret; }); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index 8c5053b..5613621 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -49,7 +49,7 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); } template @@ -64,7 +64,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); + *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index ffe35ca..49c2224 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -39,7 +39,7 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { std::ostringstream intrinsic_name; intrinsic_name << "__nv_" << call->name; if (call->dtype.bits() == 32) intrinsic_name << "f"; - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); + *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 22af9f1..3a2b8ac 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -38,7 +38,7 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); + *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { @@ -53,10 +53,10 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { // get own lane in self (__lane_id) PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); PrimExpr zero = tir::make_zero(DataType::Int(32)); - PrimExpr lo = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, - CallNode::PureExtern); - PrimExpr self = CallNode::make(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, - CallNode::PureExtern); + PrimExpr lo = + Call(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, CallNode::PureExtern); + PrimExpr self = + Call(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, CallNode::PureExtern); // compute lane to get from PrimExpr width = call->args[3]; @@ -67,15 +67,15 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { } else if (call->name == "tvm_warp_shuffle_up") { PrimExpr delta = call->args[2]; index = self - delta; - index = SelectNode::make(index < (self & ~(width - 1)), self, index); + index = Select(index < (self & ~(width - 1)), self, index); } else { CHECK_EQ(call->name, "tvm_warp_shuffle_down"); PrimExpr delta = call->args[2]; index = self + delta; - index = SelectNode::make((self & (width - 1)) + delta >= width, self, index); + index = Select((self & (width - 1)) + delta >= width, self, index); } - PrimExpr res = CallNode::make(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, - CallNode::PureExtern); + PrimExpr res = + Call(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, CallNode::PureExtern); *rv = res; } diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 7ebcfa6..45746b8 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -118,7 +118,7 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; const char* name = T()(call->dtype, call->name); - *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern); + *rv = Call(call->dtype, name, cuda_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 60fbde7..8453b33 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -80,7 +80,7 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; Array opencl_args{{call->args[1], call->args[2]}}; - *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); + *rv = Call(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 6b31bd7..a6b2547 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -43,7 +43,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc index 874a512..89ff96d 100644 --- a/src/te/autodiff/ad_util.cc +++ b/src/te/autodiff/ad_util.cc @@ -35,8 +35,7 @@ std::pair, Map> CloneIterVars(const Array Array new_vars; Map vmap; for (const IterVar& iv : vars) { - IterVar new_v = - IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag); + IterVar new_v = IterVar(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag); new_vars.push_back(new_v); vmap.Set(iv->var, new_v->var); } @@ -54,8 +53,8 @@ PrimExpr CloneReduction(const PrimExpr& expr) { src_with_newaxis.push_back(tir::Substitute(src, vmap)); } - return ReduceNode::make(red->combiner, src_with_newaxis, new_axis, - tir::Substitute(red->condition, vmap), red->value_index); + return Reduce(red->combiner, src_with_newaxis, new_axis, tir::Substitute(red->condition, vmap), + red->value_index); } else { return expr; } diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc index 4afca68..772213d 100644 --- a/src/te/autodiff/adjoint.cc +++ b/src/te/autodiff/adjoint.cc @@ -54,7 +54,7 @@ Tensor Identity(const Tensor& output) { res = res && (PrimExpr(input_indices[i]) == PrimExpr(input_indices[output->shape.size() + i])); } - return CastNode::make(output->dtype, res); + return Cast(output->dtype, res); }; return te::compute(shape, func, "identity"); } diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index ecddf5e..a8a9a0b 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -85,9 +85,9 @@ class JacobianMutator : public ExprMutator { CHECK_EQ(indices_.size(), op->indices.size()); PrimExpr condition = const_true(); for (size_t i = 0; i < input_.ndim(); ++i) { - condition = AndNode::make(condition, EQNode::make(indices_[i], op->indices[i])); + condition = And(condition, EQ(indices_[i], op->indices[i])); } - return CastNode::make(op->dtype, condition); + return Cast(op->dtype, condition); } else { return make_zero(op->dtype); } @@ -98,28 +98,25 @@ class JacobianMutator : public ExprMutator { if (op->call_type == CallNode::CallType::PureIntrinsic) { static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; if (op->name == "exp") { - return MulNode::make(Mutate(op->args[0]), expr); + return Mul(Mutate(op->args[0]), expr); } else if (op->name == "log") { - return DivNode::make(Mutate(op->args[0]), op->args[0]); + return Div(Mutate(op->args[0]), op->args[0]); } else if (op->name == "sigmoid") { - return MulNode::make(Mutate(op->args[0]), - MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0), expr))); + return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); } else if (op->name == "sqrt") { - return DivNode::make(Mutate(op->args[0]), MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); + return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); } else if (op->name == "tanh") { - return MulNode::make(Mutate(op->args[0]), - SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr, expr))); + return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); } else if (op->name == "pow") { auto x = op->args[0], y = op->args[1]; return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); } else if (op->name == "fabs") { auto type = op->args[0].dtype(); - return MulNode::make(Mutate(op->args[0]), - SelectNode::make(GENode::make(op->args[0], make_zero(type)), - FloatImm(type, 1.0), FloatImm(type, -1.0))); + return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), + FloatImm(type, 1.0), FloatImm(type, -1.0))); } else if (op->name == intrinsic::tvm_if_then_else) { Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; - return CallNode::make(op->dtype, op->name, new_args, op->call_type); + return Call(op->dtype, op->name, new_args, op->call_type); } else if (piecewise_const.count(op->name)) { return FloatImm(expr.dtype(), 0.0); } else { @@ -129,36 +126,32 @@ class JacobianMutator : public ExprMutator { NOT_IMPLEMENTED; } - PrimExpr VisitExpr_(const AddNode* op) { return AddNode::make(Mutate(op->a), Mutate(op->b)); } + PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const SubNode* op) { return SubNode::make(Mutate(op->a), Mutate(op->b)); } + PrimExpr VisitExpr_(const SubNode* op) { return Sub(Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MulNode* op) { - return AddNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))); + return Add(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))); } PrimExpr VisitExpr_(const DivNode* op) { - return DivNode::make( - SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))), - MulNode::make(op->b, op->b)); + return Div(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b)); } PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const FloorDivNode* op) { - return FloorDivNode::make( - SubNode::make(MulNode::make(Mutate(op->a), op->b), MulNode::make(op->a, Mutate(op->b))), - MulNode::make(op->b, op->b)); + return FloorDiv(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b)); } PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const MinNode* op) { - return SelectNode::make(LENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); + return Select(LE(op->a, op->b), Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MaxNode* op) { - return SelectNode::make(GENode::make(op->a, op->b), Mutate(op->a), Mutate(op->b)); + return Select(GE(op->a, op->b), Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED; @@ -220,12 +213,12 @@ class JacobianMutator : public ExprMutator { for (size_t i = 0; i < new_op->combiner->lhs.size(); ++i) { PrimExpr res_di = Derivative(res, new_op->combiner->lhs[i]); // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor) - new_res = AddNode::make(new_res, MulNode::make(new_lhs[i], res_di)); + new_res = Add(new_res, Mul(new_lhs[i], res_di)); } for (size_t i = 0; i < new_op->combiner->rhs.size(); ++i) { PrimExpr res_di = Derivative(res, new_op->combiner->rhs[i]); // new_rhs[i] is the derivative of rhs[i] (wrt our input tensor) - new_res = AddNode::make(new_res, MulNode::make(new_rhs[i], res_di)); + new_res = Add(new_res, Mul(new_rhs[i], res_di)); } new_result.push_back(new_res); } @@ -252,16 +245,16 @@ class JacobianMutator : public ExprMutator { new_source.push_back(src); } - CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity); // Also simplify the resulting combiner // (mostly to get rid of unused components, e.g., the original expressions) - return analyzer_.Simplify(ReduceNode::make(new_combiner, new_source, new_op->axis, - new_op->condition, new_op->value_index)); + return analyzer_.Simplify( + Reduce(new_combiner, new_source, new_op->axis, new_op->condition, new_op->value_index)); } PrimExpr VisitExpr_(const CastNode* op) { if (op->dtype.is_float()) { - return CastNode::make(op->dtype, Mutate(op->value)); + return Cast(op->dtype, Mutate(op->value)); } else { return make_zero(op->dtype); } @@ -270,7 +263,7 @@ class JacobianMutator : public ExprMutator { PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const SelectNode* op) { - return SelectNode::make(op->condition, Mutate(op->true_value), Mutate(op->false_value)); + return Select(op->condition, Mutate(op->true_value), Mutate(op->false_value)); } PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED; @@ -321,7 +314,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { size_t i = 0; for (PrimExpr ext : input->shape) { IterVar new_v = - IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar); + IterVar(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar); // Append jacobian iter to new_axis new_axis.push_back(new_v); // Differentiate wrt input[input_indices] @@ -341,8 +334,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { if (const ReduceNode* red = new_body.as()) { value_index = red->value_index; for (size_t idx = 0; idx < red->source.size(); ++idx) { - new_bodies.push_back( - ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); + new_bodies.push_back(Reduce(red->combiner, red->source, red->axis, red->condition, idx)); } } else { new_bodies.push_back(new_body); diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 25715f4..66f0820 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -95,8 +95,7 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std:: for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back( - IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -113,8 +112,7 @@ Array compute(Array shape, FBatchCompute fcompute, std::string for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back( - IterVarNode::make(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } @@ -275,11 +273,10 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, if (attr->dim_align_factor != 0) { Array tuple = {static_cast(i), attr->dim_align_factor, attr->dim_align_offset}; - realize = - tir::AttrStmtNode::make(t, tir::attr::buffer_dim_align, - CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), - realize); + realize = tir::AttrStmtNode::make( + t, tir::attr::buffer_dim_align, + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + realize); } } } diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index e1ef617..cd76910 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -142,7 +142,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle()); - lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes()))); + lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array init_value = combiner->identity_element; Array update_value = (*combiner)(lhs, reduces[0]->source); @@ -160,7 +160,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < size; ++i) { if (!normal_red.empty()) { DataType t = reduces[i]->dtype; - freduce_args.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes()))); + freduce_args.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } else { freduce_args.push_back(reduces[0]->source[i]); } @@ -194,7 +194,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // Apply the existing input predicate if any. output_preds.push_back(input_pred); - Stmt reduce_body = EvaluateNode::make(CallNode::make( + Stmt reduce_body = EvaluateNode::make(Call( DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic)); reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope, make_zero(DataType::Handle()), reduce_body); @@ -211,7 +211,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t idx = 0; idx < size; ++idx) { DataType t = reduces[idx]->dtype; assigns[idx] = ProducerStoreNode::make( - stage->op.output(idx), LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + stage->op.output(idx), Load(t, res_handles[idx], 0, const_true(t.lanes())), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(output_preds), assign_body); @@ -219,13 +219,13 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t idx = size; idx != 0; --idx) { body = AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, - StringImmNode::make("local"), body); + body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), + body); if (!normal_red.empty()) { body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope, - StringImmNode::make("local"), body); + StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 25a596f..75181b8 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -150,7 +150,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, } ret = AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); + Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index d0ffcfc..c927f80 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -221,8 +221,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_mapsecond; CHECK(is_const_int(outer_dom->min, 0)); - inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type); - outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type); + inner = IterVar(inner_dom, inner_->var, inner_->iter_type); + outer = IterVar(outer_dom, outer_->var, outer_->iter_type); } Stmt VisitStmt_(const ForNode* op) final { @@ -447,7 +447,7 @@ std::vector GatherLoopVars(Stmt stmt) { if (const ForNode* op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::make_by_min_extent(op->min, op->extent); - res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type))); + res_.push_back(IterVar(dom, loop_var, ForTypeToIterVarType(op->for_type))); } }); std::reverse(res_.begin(), res_.end()); diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 4e6c824..675954a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -87,8 +87,8 @@ Operation ScanOpNode::make(std::string name, std::string tag, Mapspatial_axis_.push_back(IterVarNode::make( - Range::make_by_min_extent(0, update[i]->shape[k]), Var(spatial_name.str()), kOpaque)); + n->spatial_axis_.push_back(IterVar(Range::make_by_min_extent(0, update[i]->shape[k]), + Var(spatial_name.str()), kOpaque)); } } @@ -112,9 +112,9 @@ TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make); Array scan(Array init, Array update, Array state_placeholder, Array inputs, std::string name, std::string tag, Map attrs) { - IterVar scan_axis = IterVarNode::make( - Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), - Var(name + ".idx"), kOrdered); + IterVar scan_axis = + IterVar(Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), + Var(name + ".idx"), kOrdered); Operation op = ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); Array res; diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 96ddb36..f9e0c8d 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -146,8 +146,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, } input_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), - nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding @@ -171,8 +170,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, output_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), - nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index ddc0595..224907d 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -192,7 +192,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis.push_back(it->second); } } - return ReduceNode::make(op->combiner, op->source, axis, op->condition, op->value_index); + return Reduce(op->combiner, op->source, axis, op->condition, op->value_index); } void Init(const ComputeOpNode* self, const Stage& stage, @@ -370,8 +370,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, } input_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), - nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -391,8 +390,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, Array bind_spec{buffer, tensor}; output_bind_nest.emplace_back(AttrStmtNode::make( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), - nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index 8a130e9..13b601a 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -58,7 +58,7 @@ class OperationInliner final : public StmtExprMutator { } if (has_side_effect) { for (size_t i = 0; i < args_.size(); ++i) { - expr = LetNode::make(args_[i], op->indices[i], expr); + expr = Let(args_[i], op->indices[i], expr); } } else { Map vmap; diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index cfd8b26..009d74f 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -64,7 +64,7 @@ class VarReplacer : public tir::StmtExprMutator { combiner->identity_element.same_as(new_result)) { return combiner; } else { - return tir::CommReducerNode::make(combiner->lhs, combiner->rhs, new_result, new_identity); + return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity); } } @@ -75,8 +75,8 @@ class VarReplacer : public tir::StmtExprMutator { if (op->combiner.same_as(new_combiner)) { return new_e; } else { - return tir::ReduceNode::make(new_combiner, new_reduce->source, new_reduce->axis, - new_reduce->condition, new_reduce->value_index); + return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition, + new_reduce->value_index); } } @@ -96,7 +96,7 @@ PrimExpr InjectPredicate(const Array& predicates, PrimExpr body) { n->condition = foldl(fand, n->condition, predicates); return PrimExpr(n); } - return SelectNode::make(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype())); + return Select(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype())); } // Replace data flow appears in all stages given the tensor change. @@ -204,7 +204,7 @@ void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_setiter_type, kDataPar) << "Can only relayout with in data parallel dimensions"; Range dom = dom_map.at(iv); - IterVar new_iv = IterVarNode::make(dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); if (is_one(dom->min)) { value_map[iv] = dom->min; @@ -300,9 +300,8 @@ Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_a const tir::ReduceNode* reduce_body = body.as(); if (first_reduce != nullptr) { CHECK(ReduceEqual(reduce_body, first_reduce)); - body = - tir::ReduceNode::make(first_reduce->combiner, first_reduce->source, first_reduce->axis, - first_reduce->condition, reduce_body->value_index); + body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis, + first_reduce->condition, reduce_body->value_index); } else { first_reduce = reduce_body; } @@ -362,7 +361,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& te for (int i = tensor_op->schedulable_ndim; i < static_cast(tensor_op->axis.size()); ++i) { IterVar iv = tensor_op->axis[i]; - IterVar new_iv = IterVarNode::make(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); } Array new_regions; @@ -390,7 +389,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& te Array compute_axis = tensor_op->axis; for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { IterVar iv = tensor_op->axis[i]; - IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + IterVar aiv = IterVar(iv->dom, iv->var, kDataPar); compute_axis.Set(i, aiv); } @@ -468,7 +467,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { } if (idx < leaf_vars->size()) { // insert rebase - IterVar rebased = IterVarNode::make(Range(), iv->var.copy_with_suffix(""), iv->iter_type); + IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type); s->relations.push_back(RebaseNode::make(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); @@ -741,8 +740,7 @@ Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { - body.emplace_back( - ReduceNode::make(reduce->combiner, new_source, n->reduce_axis, new_pred, idx)); + body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx)); } n->body = Array(body); // refresh relations, keep the un-touched relations. @@ -806,7 +804,7 @@ Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f Array axis = {repl_red_axis}; PrimExpr cond = const_true(); for (int idx = 0; idx < size; ++idx) { - reductions.push_back(ReduceNode::make(reduce->combiner, factor_exprs, axis, cond, idx)); + reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx)); } return reductions; }, diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 9dc8269..24d9102 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -61,10 +61,8 @@ void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, It CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) << "Cannot split on " << IterVarType2String(parent->iter_type); - IterVar outer = - IterVarNode::make(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); - IterVar inner = - IterVarNode::make(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); + IterVar outer = IterVar(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); + IterVar inner = IterVar(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); *p_outer = outer; *p_inner = inner; // The splits @@ -231,7 +229,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT if (inner->iter_type > iter_type) iter_type = inner->iter_type; std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused"; - IterVar fused = IterVarNode::make(Range(), Var(fused_name, outer->var.dtype()), iter_type); + IterVar fused = IterVar(Range(), Var(fused_name, outer->var.dtype()), iter_type); Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; @@ -263,8 +261,8 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* StageNode* self = operator->(); // special handle fuse empty array. // insert at the outer most loop - IterVar singleton = IterVarNode::make(Range::make_by_min_extent(0, 1), - Var("singleton", DataType::Int(32)), kDataPar); + IterVar singleton = + IterVar(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar); self->relations.push_back(SingletonNode::make(singleton)); Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; @@ -370,7 +368,7 @@ Stage& Stage::pragma(IterVar var, const std::string& pragma_type, this->vectorize(var); } else { UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { - n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); + n->pragma_keys.push_back(tir::StringImm(pragma_type)); n->pragma_values.push_back(pragma_value); }); } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 228ce45..f5ba43c 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -53,8 +53,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. - pipeline = - AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImmNode::make(s->scope), pipeline); + pipeline = AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); return pipeline; } diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index da45e8a..e81ad2c 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -458,7 +458,7 @@ class BufferAnalyser : public StmtExprVisitor { for (size_t i = 1; i < bi.shape.size(); ++i) { PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, bi.shape[j]); + stride = Mul(stride, bi.shape[j]); } strides.push_back(stride); } @@ -560,7 +560,7 @@ class BufferAnalyser : public StmtExprVisitor { for (size_t i = 1; i < bi.shape.size(); ++i) { PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, bi.shape[j]); + stride = Mul(stride, bi.shape[j]); } strides.push_back(stride); } @@ -752,8 +752,8 @@ class ThreadIdxMutator : public StmtExprMutator { return zero; } if (op->name_hint == "threadIdx.y") { - PrimExpr div = DivNode::make(expr, warp_y_); - PrimExpr mul = MulNode::make(div, warp_y_); + PrimExpr div = Div(expr, warp_y_); + PrimExpr mul = Mul(div, warp_y_); return mul; } } @@ -819,7 +819,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto it = matrix_abc_.find(simplify_name(node->name)); CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second); + auto matrix_abc = tvm::tir::StringImm("wmma." + it->second); Stmt body = this->VisitStmt(op->body); return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body); } @@ -847,17 +847,17 @@ class TensorCoreIRMutator : public StmtExprMutator { Buffer buffer_a(buffer_node_a); Buffer buffer_b(buffer_node_b); if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return EvaluateNode::make(CallNode::make( - DataType::Handle(), intrinsic::tvm_bmma_sync, - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); + return EvaluateNode::make( + Call(DataType::Handle(), intrinsic::tvm_bmma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); } else { - return EvaluateNode::make(CallNode::make( - DataType::Handle(), intrinsic::tvm_mma_sync, - {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); + return EvaluateNode::make( + Call(DataType::Handle(), intrinsic::tvm_mma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); } }; @@ -879,10 +879,10 @@ class TensorCoreIRMutator : public StmtExprMutator { auto pload = dst.as(); auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return EvaluateNode::make(CallNode::make(DataType::Handle(), intrinsic::tvm_fill_fragment, - {buffer->data, warp_tile_.m, warp_tile_.n, - warp_tile_.k, buffer->elem_offset, op->value}, - CallNode::Intrinsic)); + return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, op->value}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); @@ -902,7 +902,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - PrimExpr src = CallNode::make(value->dtype, "&", {mutated_value}, CallNode::Extern); + PrimExpr src = Call(value->dtype, "&", {mutated_value}, CallNode::Extern); auto pload = dst.as(); PrimExpr matrix_major; @@ -910,19 +910,18 @@ class TensorCoreIRMutator : public StmtExprMutator { CHECK(iter2 != matrix_major_.end()) << "Can not determine matrix major for " << pload->producer->GetNameHint(); if (iter2->second == "col_major") { - matrix_major = StringImmNode::make("col_major"); + matrix_major = StringImm("col_major"); } else if (iter2->second == "row_major") { - matrix_major = StringImmNode::make("row_major"); + matrix_major = StringImm("row_major"); } else { LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); } auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), intrinsic::tvm_load_matrix_sync, - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major}, - CallNode::Intrinsic)); + return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, src, stride, matrix_major}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); @@ -942,16 +941,15 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = CallNode::make(DataType::Handle(), "&", {dst}, CallNode::Extern); + dst = Call(DataType::Handle(), "&", {dst}, CallNode::Extern); auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), intrinsic::tvm_store_matrix_sync, - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImmNode::make("col_major")}, - CallNode::Intrinsic)); + return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, dst, stride, StringImm("col_major")}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); @@ -1037,7 +1035,7 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 1; i < shape.size(); ++i) { PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, shape[j]); + stride = Mul(stride, shape[j]); } strides.push_back(stride); } @@ -1046,8 +1044,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr elem_offset = IntImm(DataType::Int(32), 0); CHECK_EQ(pload->indices.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = AddNode::make( - elem_offset, MulNode::make(strides[i], SubNode::make(pload->indices[i], min_bound[i]))); + elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i]))); } auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); @@ -1068,8 +1065,7 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(pload->indices[i]); args.push_back(shape[i]); } - auto tuple = - CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); + auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); Array node = {buffer, tensor}; return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer)); } diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 1a31a85..7e7f648 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -31,12 +31,10 @@ namespace tvm { namespace te { IterVar thread_axis(Range dom, std::string tag) { - return IterVarNode::make(dom, Var(tag), kThreadIndex, tag); + return IterVar(dom, Var(tag), kThreadIndex, tag); } -IterVar reduce_axis(Range dom, std::string name) { - return IterVarNode::make(dom, Var(name), kCommReduce); -} +IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name), kCommReduce); } Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } @@ -111,19 +109,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(TensorIntrinNode); // TensorIntrinCall - -TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, - Array regions, Array reduce_axis, - Array scalar_inputs) { +TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array tensors, + Array regions, Array reduce_axis, + Array scalar_inputs) { auto n = make_object(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); n->reduce_axis = std::move(reduce_axis); n->scalar_inputs = std::move(scalar_inputs); - return TensorIntrinCall(n); + data_ = std::move(n); } +TVM_REGISTER_GLOBAL("te.TensorIntrinCall") + .set_body_typed([](TensorIntrin intrin, Array tensors, Array regions, + Array reduce_axis, Array scalar_inputs) { + return TensorIntrinCall(intrin, tensors, regions, reduce_axis, scalar_inputs); + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* n = static_cast(node.get()); @@ -136,8 +139,6 @@ TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make); TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make); -TVM_REGISTER_GLOBAL("te.TensorIntrinCall").set_body_typed(TensorIntrinCallNode::make); - TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 3a60521..4c5b30f 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -275,7 +275,7 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp offset = offset * make_const(offset.dtype(), dtype.lanes()); } if (dtype.lanes() != 1) { - return tir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + return tir::Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } else { return offset; } @@ -287,13 +287,11 @@ PrimExpr Buffer::vload(Array begin, DataType dtype) const { CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { - return tir::CastNode::make( - DataType::Bool(), - tir::LoadNode::make(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), - const_true())); + return tir::Cast(DataType::Bool(), + tir::Load(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), + const_true())); } else { - return tir::LoadNode::make(dtype, n->data, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + return tir::Load(dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } } @@ -304,7 +302,7 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return tir::StoreNode::make(n->data, tir::CastNode::make(DataType::Int(8), value), + return tir::StoreNode::make(n->data, tir::Cast(DataType::Int(8), value), BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), @@ -379,8 +377,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return tir::CallNode::make(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, - tir::CallNode::Intrinsic); + return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); } Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array strides, diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 6c38982..1f17c35 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -106,14 +106,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*) << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); - IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), - tir::kDataPar); + IterVar axis = + IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " for dimension " << c; - IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), - tir::kDataPar); + IterVar axis = + IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { @@ -172,8 +172,8 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { - new_layout.push_back(IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(axis.ToSubordinate().name()), tir::kDataPar)); + new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)), + Var(axis.ToSubordinate().name()), tir::kDataPar)); } if (i == this->ndim()) break; new_layout.push_back(axes[i]); @@ -323,7 +323,7 @@ inline Array TransformShape(const Array& src_shape, result.push_back(axis->dom->extent); } else { if (symbolic_var_set.count(i)) { - result.push_back(tir::AnyNode::make()); + result.push_back(tir::Any()); } else { result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index e1d8b3f..12df05e 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -23,7 +23,6 @@ #include #include #include -#include #include #include @@ -34,6 +33,33 @@ namespace tvm { namespace tir { +#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b) { \ + using T = Name::ContainerType; \ + CHECK(a.defined()) << "ValueError: a is undefined\n"; \ + CHECK(b.defined()) << "ValueError: b is undefined\n"; \ + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = a.dtype(); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + data_ = std::move(node); \ + } + +#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b) { \ + using T = Name::ContainerType; \ + CHECK(a.defined()) << "ValueError: a is undefined\n"; \ + CHECK(b.defined()) << "ValueError: b is undefined\n"; \ + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = DataType::Bool(a.dtype().lanes()); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + data_ = std::move(node); \ + } + +// Var Var::Var(String name_hint, DataType dtype) { auto n = make_object(); n->name_hint = std::move(name_hint); @@ -61,13 +87,6 @@ Var Var::copy_with_suffix(const String& suffix) const { return Var(new_ptr); } -SizeVar::SizeVar(String name_hint, DataType dtype) { - auto n = make_object(); - n->name_hint = std::move(name_hint); - n->dtype = std::move(dtype); - data_ = std::move(n); -} - TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) { if (type.IsObjectRef()) { return Var(name_hint, type.operator Type()); @@ -76,22 +95,49 @@ TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMA } }); +TVM_REGISTER_NODE_TYPE(VarNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + // omit the type + // stream << op->name << "." << op->type; + p->stream << op->name_hint; + }); + +// SizeVar +SizeVar::SizeVar(String name_hint, DataType dtype) { + auto n = make_object(); + n->name_hint = std::move(name_hint); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) { return SizeVar(s, t); }); -IterVar IterVarNode::make(Range dom, Var var, IterVarType t, std::string thread_tag) { +TVM_REGISTER_NODE_TYPE(SizeVarNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; + }); + +// IterVar +IterVar::IterVar(Range dom, Var var, IterVarType t, std::string thread_tag) { ObjectPtr n = make_object(); n->dom = dom; n->var = var; n->iter_type = t; n->thread_tag = thread_tag; - return IterVar(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.IterVar") .set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { - return IterVarNode::make(dom, var, static_cast(iter_type), thread_tag); + return IterVar(dom, var, static_cast(iter_type), thread_tag); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -112,367 +158,173 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(IterVarNode); -PrimExpr StringImmNode::make(std::string value) { +// StringImm +StringImm::StringImm(std::string value) { ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); - return PrimExpr(node); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed(StringImmNode::make); +TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](std::string value) { + return StringImm(value); +}); + +TVM_REGISTER_NODE_TYPE(StringImmNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '\"' << support::StrEscape(op->value) << '\"'; + }); -PrimExpr CastNode::make(DataType t, PrimExpr value) { +// Cast +Cast::Cast(DataType t, PrimExpr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); - return PrimExpr(node); + data_ = std::move(node); } -PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { - CHECK(a.defined()) << "ValueError: a is undefined"; - CHECK(b.defined()) << "ValueError: b is undefined"; - CHECK(a.dtype().is_bool()); - CHECK(b.dtype().is_bool()); - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; +TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value) { + return Cast(dtype, value); +}); - ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); - node->a = std::move(a); - node->b = std::move(b); - return PrimExpr(node); -} +TVM_REGISTER_NODE_TYPE(CastNode); -PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { - CHECK(a.defined()) << "ValueError: a is undefined"; - CHECK(b.defined()) << "ValueError: b is undefined"; - CHECK(a.dtype().is_bool()); - CHECK(b.dtype().is_bool()); - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dtype << '('; + p->Print(op->value); + p->stream << ')'; + }); - ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); - node->a = std::move(a); - node->b = std::move(b); - return PrimExpr(node); -} +// Add +TVM_DEFINE_BINOP_CONSTRUCTOR(Add); -PrimExpr NotNode::make(PrimExpr a) { - CHECK(a.defined()) << "ValueError: a is undefined"; - CHECK(a.dtype().is_bool()); +TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b) { return Add(a, b); }); - ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); - node->a = std::move(a); - return PrimExpr(node); -} +TVM_REGISTER_NODE_TYPE(AddNode); -PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { - CHECK(condition.defined()) << "ValueError: condition is undefined"; - CHECK(true_value.defined()) << "ValueError: true_value is undefined"; - CHECK(false_value.defined()) << "ValueError: true_value is undefined"; - CHECK(condition.dtype().is_bool()); - CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); - CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " + "; + p->Print(op->b); + p->stream << ')'; + }); - ObjectPtr node = make_object(); - node->dtype = true_value.dtype(); - node->condition = std::move(condition); - node->true_value = std::move(true_value); - node->false_value = std::move(false_value); - return PrimExpr(node); -} +// Sub +TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); -PrimExpr LoadNode::make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { - CHECK(buffer_var.defined()); - CHECK(predicate.defined()); - CHECK(index.defined()); - CHECK_EQ(dtype.lanes(), index.dtype().lanes()); - CHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); +TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b) { return Sub(a, b); }); - ObjectPtr node = make_object(); - node->dtype = dtype; - node->buffer_var = std::move(buffer_var); - node->index = std::move(index); - node->predicate = std::move(predicate); +TVM_REGISTER_NODE_TYPE(SubNode); - return PrimExpr(node); -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " - "; + p->Print(op->b); + p->stream << ')'; + }); -PrimExpr RampNode::make(PrimExpr base, PrimExpr stride, int lanes) { - CHECK(base.defined()); - CHECK(stride.defined()); - CHECK(base.dtype().is_scalar()); - CHECK(stride.dtype().is_scalar()); - CHECK_GT(lanes, 1); - CHECK_EQ(stride.dtype(), base.dtype()); +// Mul +TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); - ObjectPtr node = make_object(); - node->dtype = base.dtype().with_lanes(lanes); - node->base = base; - node->stride = stride; - node->lanes = lanes; - return PrimExpr(node); -} +TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b) { return Mul(a, b); }); -PrimExpr BroadcastNode::make(PrimExpr value, int lanes) { - CHECK(value.defined()); - CHECK(value.dtype().is_scalar()); - CHECK_GT(lanes, 1); +TVM_REGISTER_NODE_TYPE(MulNode); - ObjectPtr node = make_object(); - node->dtype = value.dtype().with_lanes(lanes); - node->value = std::move(value); - node->lanes = lanes; - return PrimExpr(node); -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "*"; + p->Print(op->b); + p->stream << ')'; + }); -PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { - CHECK(value.defined()); - CHECK(body.defined()); - CHECK_EQ(value.dtype(), var.dtype()); +// Div +TVM_DEFINE_BINOP_CONSTRUCTOR(Div); - ObjectPtr node = make_object(); - node->dtype = body.dtype(); - node->var = std::move(var); - node->value = std::move(value); - node->body = std::move(body); - return PrimExpr(node); -} +TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b) { return Div(a, b); }); -const char* CallNode::vectorizable_intrinsics[] = {"floor", - "ceil", - "sign", - "trunc", - "fabs", - "round", - "exp", - "tanh", - "sqrt", - "log", - "sin", - "cos", - "pow", - "tan", - tir::CallNode::shift_left, - tir::CallNode::shift_right, - tir::CallNode::likely, - tir::CallNode::popcount}; +TVM_REGISTER_NODE_TYPE(DivNode); -bool CallNode::is_vectorizable() const { - size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); - for (size_t i = 0; i < cnt; ++i) { - if (name == CallNode::vectorizable_intrinsics[i]) { - return true; - } - } - return false; -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "/"; + p->Print(op->b); + p->stream << ')'; + }); -PrimExpr CallNode::make(DataType dtype, std::string name, Array args, - CallType call_type) { - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].defined()); - } +// Mod +TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); - ObjectPtr node = make_object(); - node->dtype = dtype; - node->name = std::move(name); - node->args = std::move(args); - node->call_type = call_type; - return PrimExpr(node); -} +TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b) { return Mod(a, b); }); -PrimExpr ShuffleNode::make(Array vectors, Array indices) { - CHECK_NE(vectors.size(), 0U); - CHECK_NE(indices.size(), 0U); +TVM_REGISTER_NODE_TYPE(ModNode); - DataType base_type = vectors[0].dtype().element_of(); - int total_lanes = 0; +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " % "; + p->Print(op->b); + p->stream << ')'; + }); - for (PrimExpr val : vectors) { - CHECK(val.dtype().element_of() == base_type); - total_lanes += val.dtype().lanes(); - } - CHECK_LE(indices.size(), static_cast(total_lanes)); +// FloorDiv +TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); - ObjectPtr node = make_object(); - node->dtype = base_type.with_lanes(static_cast(indices.size())); - node->vectors = std::move(vectors); - node->indices = std::move(indices); - return PrimExpr(node); -} +TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b) { + return FloorDiv(a, b); +}); -PrimExpr ShuffleNode::make_concat(Array vectors) { - CHECK_NE(vectors.size(), 0); - if (vectors.size() == 1) { - return vectors[0]; - } - Array indices; - int index = 0; - for (const PrimExpr& e : vectors) { - for (int i = 0; i < e.dtype().lanes(); ++i) { - indices.push_back(IntImm(DataType::Int(32), index++)); - } - } - return make(vectors, indices); -} +TVM_REGISTER_NODE_TYPE(FloorDivNode); -PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) { - return make({vector}, {Integer(index)}); -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floordiv(" << op->a << ", " << op->b << ")"; + }); -CommReducer CommReducerNode::make(Array lhs, Array rhs, Array result, - Array identity_element) { - auto node = make_object(); - node->lhs = lhs; - node->rhs = rhs; - node->result = result; - node->identity_element = identity_element; - return CommReducer(node); -} - -Array CommReducerNode::operator()(Array a, Array b) const { - CHECK_EQ(a.size(), b.size()); - CHECK_EQ(lhs.size(), a.size()); - CHECK_EQ(rhs.size(), b.size()); - Map value_map; - for (size_t i = 0; i < a.size(); ++i) { - value_map.Set(lhs[i], a[i]); - value_map.Set(rhs[i], b[i]); - } - auto ret = this->result; - ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); - return ret; -} - -TVM_REGISTER_GLOBAL("tir.CommReducer").set_body_typed(CommReducerNode::make); - -TVM_REGISTER_GLOBAL("tir.CommReducerCombine") - .set_body_method(&tir::CommReducerNode::operator()); - -PrimExpr ReduceNode::make(CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index) { - for (size_t i = 0; i < axis.size(); ++i) { - CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; - } - if (!condition.defined()) { - condition = const_true(); - } - auto n = make_object(); - CHECK(source.defined()); - for (size_t i = 0; i < axis.size(); ++i) { - CHECK(axis[i].defined()); - } - n->dtype = source[value_index].dtype(); - n->combiner = std::move(combiner); - n->source = std::move(source); - n->axis = std::move(axis); - n->condition = condition; - n->value_index = value_index; - return PrimExpr(n); -} - -TVM_REGISTER_GLOBAL("tir.Reduce").set_body_typed(ReduceNode::make); - -PrimExpr AnyNode::make() { - auto n = make_object(); - return PrimExpr(n); -} - -BufferLoad::BufferLoad(Buffer buffer, Array indices) { - ObjectPtr node = make_object(); - node->dtype = buffer->dtype; - node->buffer = std::move(buffer); - node->indices = std::move(indices); - data_ = std::move(node); -} +// FloorMod +TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); -TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { - return BufferLoad(buffer, indices); +TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b) { + return FloorMod(a, b); }); -TVM_REGISTER_NODE_TYPE(BufferLoadNode); - -ProducerLoad::ProducerLoad(DataProducer producer, Array indices) { - ObjectPtr node = make_object(); - node->dtype = producer->GetDataType(); - node->producer = std::move(producer); - node->indices = std::move(indices); - data_ = std::move(node); -} +TVM_REGISTER_NODE_TYPE(FloorModNode); -TVM_REGISTER_GLOBAL("tir.ProducerLoad") - .set_body_typed([](DataProducer producer, Array indices) { - return ProducerLoad(producer, indices); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); -TVM_REGISTER_NODE_TYPE(ProducerLoadNode); +// Min +TVM_DEFINE_BINOP_CONSTRUCTOR(Min); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '\"' << support::StrEscape(op->value) << '\"'; - }); +TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b) { return Min(a, b); }); + +TVM_REGISTER_NODE_TYPE(MinNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dtype << '('; - p->Print(op->value); - p->stream << ')'; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - // omit the type - // stream << op->name << "." << op->type; - p->stream << op->name_hint; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " + "; - p->Print(op->b); - p->stream << ')'; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " - "; - p->Print(op->b); - p->stream << ')'; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "*"; - p->Print(op->b); - p->stream << ')'; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "/"; - p->Print(op->b); - p->stream << ')'; - }) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " % "; - p->Print(op->b); - p->stream << ')'; - }) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "min("; @@ -480,7 +332,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ", "; p->Print(op->b); p->stream << ")"; - }) + }); + +// Max +TVM_DEFINE_BINOP_CONSTRUCTOR(Max); + +TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b) { return Max(a, b); }); + +TVM_REGISTER_NODE_TYPE(MaxNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "max("; @@ -488,7 +349,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ", "; p->Print(op->b); p->stream << ")"; - }) + }); + +// EQ +TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); + +TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b) { return EQ(a, b); }); + +TVM_REGISTER_NODE_TYPE(EQNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; @@ -496,7 +366,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << " == "; p->Print(op->b); p->stream << ')'; - }) + }); + +// NE +TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); + +TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b) { return NE(a, b); }); + +TVM_REGISTER_NODE_TYPE(NENode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; @@ -504,7 +383,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << " != "; p->Print(op->b); p->stream << ')'; - }) + }); + +// LT +TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); + +TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b) { return LT(a, b); }); + +TVM_REGISTER_NODE_TYPE(LTNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; @@ -512,7 +400,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << " < "; p->Print(op->b); p->stream << ')'; - }) + }); + +// LE +TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); + +TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b) { return LE(a, b); }); + +TVM_REGISTER_NODE_TYPE(LENode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; @@ -520,7 +417,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << " <= "; p->Print(op->b); p->stream << ')'; - }) + }); + +// GT +TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); + +TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b) { return GT(a, b); }); + +TVM_REGISTER_NODE_TYPE(GTNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; @@ -528,7 +434,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << " > "; p->Print(op->b); p->stream << ')'; - }) + }); + +// GE +TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); + +TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b) { return GE(a, b); }); + +TVM_REGISTER_NODE_TYPE(GENode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '('; @@ -538,17 +453,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floordiv(" << op->a << ", " << op->b << ")"; - }); +// And +And::And(PrimExpr a, PrimExpr b) { + CHECK(a.defined()) << "ValueError: a is undefined"; + CHECK(b.defined()) << "ValueError: b is undefined"; + CHECK(a.dtype().is_bool()); + CHECK(b.dtype().is_bool()); + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floormod(" << op->a << ", " << op->b << ")"; - }); + ObjectPtr node = make_object(); + node->dtype = DataType::Bool(a.dtype().lanes()); + node->a = std::move(a); + node->b = std::move(b); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b) { return And(a, b); }); + +TVM_REGISTER_NODE_TYPE(AndNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -560,6 +482,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +// Or +Or::Or(PrimExpr a, PrimExpr b) { + CHECK(a.defined()) << "ValueError: a is undefined"; + CHECK(b.defined()) << "ValueError: b is undefined"; + CHECK(a.dtype().is_bool()); + CHECK(b.dtype().is_bool()); + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; + + ObjectPtr node = make_object(); + node->dtype = DataType::Bool(a.dtype().lanes()); + node->a = std::move(a); + node->b = std::move(b); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b) { return Or(a, b); }); + +TVM_REGISTER_NODE_TYPE(OrNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -570,6 +511,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +// Not +Not::Not(PrimExpr a) { + CHECK(a.defined()) << "ValueError: a is undefined"; + CHECK(a.dtype().is_bool()); + + ObjectPtr node = make_object(); + node->dtype = DataType::Bool(a.dtype().lanes()); + node->a = std::move(a); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a) { return Not(a); }); + +TVM_REGISTER_NODE_TYPE(NotNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -577,6 +533,30 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->a); }); +// Select +Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { + CHECK(condition.defined()) << "ValueError: condition is undefined"; + CHECK(true_value.defined()) << "ValueError: true_value is undefined"; + CHECK(false_value.defined()) << "ValueError: true_value is undefined"; + CHECK(condition.dtype().is_bool()); + CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); + CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; + + ObjectPtr node = make_object(); + node->dtype = true_value.dtype(); + node->condition = std::move(condition); + node->true_value = std::move(true_value); + node->false_value = std::move(false_value); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Select") + .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { + return Select(condition, true_value, false_value); + }); + +TVM_REGISTER_NODE_TYPE(SelectNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -589,18 +569,69 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "]"; - if (!is_one(op->predicate)) { - p->stream << " if "; +// Load +Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { + CHECK(buffer_var.defined()); + CHECK(predicate.defined()); + CHECK(index.defined()); + CHECK_EQ(dtype.lanes(), index.dtype().lanes()); + CHECK_EQ(dtype.lanes(), predicate.dtype().lanes()); + + ObjectPtr node = make_object(); + node->dtype = dtype; + node->buffer_var = std::move(buffer_var); + node->index = std::move(index); + node->predicate = std::move(predicate); + + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { + DataType t = args[0]; + if (args.size() == 3) { + *ret = Load(t, args[1], args[2], const_true(t.lanes())); + } else { + *ret = Load(t, args[1], args[2], args[3]); + } +}); + +TVM_REGISTER_NODE_TYPE(LoadNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "]"; + if (!is_one(op->predicate)) { + p->stream << " if "; p->Print(op->predicate); } }); +// Ramp +Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes) { + CHECK(base.defined()); + CHECK(stride.defined()); + CHECK(base.dtype().is_scalar()); + CHECK(stride.dtype().is_scalar()); + CHECK_GT(lanes, 1); + CHECK_EQ(stride.dtype(), base.dtype()); + + ObjectPtr node = make_object(); + node->dtype = base.dtype().with_lanes(lanes); + node->base = base; + node->stride = stride; + node->lanes = lanes; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed([](PrimExpr base, PrimExpr stride, int lanes) { + return Ramp(base, stride, lanes); +}); + +TVM_REGISTER_NODE_TYPE(RampNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -611,6 +642,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ", " << op->lanes << ")"; }); +// Broadcast +Broadcast::Broadcast(PrimExpr value, int lanes) { + CHECK(value.defined()); + CHECK(value.dtype().is_scalar()); + CHECK_GT(lanes, 1); + + ObjectPtr node = make_object(); + node->dtype = value.dtype().with_lanes(lanes); + node->value = std::move(value); + node->lanes = lanes; + data_ = node; +} + +TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes) { + return Broadcast(value, lanes); +}); + +TVM_REGISTER_NODE_TYPE(BroadcastNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -619,44 +669,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->name << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) { - p->stream << ", "; - } - } - p->stream << ")"; - }); +// Let +Let::Let(Var var, PrimExpr value, PrimExpr body) { + CHECK(value.defined()); + CHECK(body.defined()); + CHECK_EQ(value.dtype(), var.dtype()); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - p->stream << ", "; - } - } - p->stream << "]"; - }); + ObjectPtr node = make_object(); + node->dtype = body.dtype(); + node->var = std::move(var); + node->value = std::move(value); + node->body = std::move(body); + data_ = std::move(node); +} -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - p->stream << ", "; - } - } - p->stream << "]"; - }); +TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body) { + return Let(var, value, body); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -668,128 +697,275 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); +// Call +Call::Call(DataType dtype, std::string name, Array args, CallType call_type) { + for (size_t i = 0; i < args.size(); ++i) { + CHECK(args[i].defined()); + } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "reduce(combiner=" << op->combiner; - p->stream << ", source=" << op->source; - p->stream << ", axis=" << op->axis; - p->stream << ", where=" << op->condition; - p->stream << ", value_index=" << op->value_index; - p->stream << ")"; - }); + ObjectPtr node = make_object(); + node->dtype = dtype; + node->name = std::move(name); + node->args = std::move(args); + node->call_type = call_type; + data_ = std::move(node); +} -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs - << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")"; +const char* CallNode::vectorizable_intrinsics[] = {"floor", + "ceil", + "sign", + "trunc", + "fabs", + "round", + "exp", + "tanh", + "sqrt", + "log", + "sin", + "cos", + "pow", + "tan", + tir::CallNode::shift_left, + tir::CallNode::shift_right, + tir::CallNode::likely, + tir::CallNode::popcount}; + +bool CallNode::is_vectorizable() const { + size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); + for (size_t i = 0; i < cnt; ++i) { + if (name == CallNode::vectorizable_intrinsics[i]) { + return true; + } + } + return false; +} + +TVM_REGISTER_GLOBAL("tir.Call") + .set_body_typed([](DataType type, std::string name, Array args, int call_type) { + Array prim_expr_args; + for (const auto& it : args) { + CHECK(it->IsInstance() || it->IsInstance()); + if (const auto* str = it.as()) { + prim_expr_args.push_back(StringImm(str->data)); + } else { + prim_expr_args.push_back(Downcast(it)); + } + } + return Call(type, name, prim_expr_args, static_cast(call_type)); }); -TVM_REGISTER_NODE_TYPE(StringImmNode); -TVM_REGISTER_NODE_TYPE(CastNode); -TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_NODE_TYPE(SizeVarNode); -TVM_REGISTER_NODE_TYPE(AddNode); -TVM_REGISTER_NODE_TYPE(SubNode); -TVM_REGISTER_NODE_TYPE(MulNode); -TVM_REGISTER_NODE_TYPE(DivNode); -TVM_REGISTER_NODE_TYPE(ModNode); -TVM_REGISTER_NODE_TYPE(FloorDivNode); -TVM_REGISTER_NODE_TYPE(FloorModNode); -TVM_REGISTER_NODE_TYPE(MinNode); -TVM_REGISTER_NODE_TYPE(MaxNode); -TVM_REGISTER_NODE_TYPE(EQNode); -TVM_REGISTER_NODE_TYPE(NENode); -TVM_REGISTER_NODE_TYPE(LTNode); -TVM_REGISTER_NODE_TYPE(LENode); -TVM_REGISTER_NODE_TYPE(GTNode); -TVM_REGISTER_NODE_TYPE(GENode); -TVM_REGISTER_NODE_TYPE(AndNode); -TVM_REGISTER_NODE_TYPE(OrNode); -TVM_REGISTER_NODE_TYPE(NotNode); -TVM_REGISTER_NODE_TYPE(SelectNode); -TVM_REGISTER_NODE_TYPE(LoadNode); -TVM_REGISTER_NODE_TYPE(RampNode); -TVM_REGISTER_NODE_TYPE(BroadcastNode); -TVM_REGISTER_NODE_TYPE(ShuffleNode); -TVM_REGISTER_NODE_TYPE(CommReducerNode); -TVM_REGISTER_NODE_TYPE(ReduceNode); -TVM_REGISTER_NODE_TYPE(AnyNode); +TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_GLOBAL("tir.Add").set_body_typed(AddNode::make); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->name << "("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) { + p->stream << ", "; + } + } + p->stream << ")"; + }); -TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed(SubNode::make); +// Shuffle +Shuffle::Shuffle(Array vectors, Array indices) { + CHECK_NE(vectors.size(), 0U); + CHECK_NE(indices.size(), 0U); -TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed(MulNode::make); + DataType base_type = vectors[0].dtype().element_of(); + int total_lanes = 0; -TVM_REGISTER_GLOBAL("tir.Div").set_body_typed(DivNode::make); + for (PrimExpr val : vectors) { + CHECK(val.dtype().element_of() == base_type); + total_lanes += val.dtype().lanes(); + } + CHECK_LE(indices.size(), static_cast(total_lanes)); -TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed(ModNode::make); + ObjectPtr node = make_object(); + node->dtype = base_type.with_lanes(static_cast(indices.size())); + node->vectors = std::move(vectors); + node->indices = std::move(indices); + data_ = node; +} -TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed(FloorDivNode::make); +PrimExpr Shuffle::Concat(Array vectors) { + CHECK_NE(vectors.size(), 0); + if (vectors.size() == 1) { + return vectors[0]; + } + Array indices; + int index = 0; + for (const PrimExpr& e : vectors) { + for (int i = 0; i < e.dtype().lanes(); ++i) { + indices.push_back(IntImm(DataType::Int(32), index++)); + } + } + return Shuffle(vectors, indices); +} -TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed(FloorModNode::make); +PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index) { + return Shuffle({vector}, {Integer(index)}); +} -TVM_REGISTER_GLOBAL("tir.Min").set_body_typed(MinNode::make); +TVM_REGISTER_GLOBAL("tir.Shuffle") + .set_body_typed([](Array vectors, Array indices) { + return Shuffle(vectors, indices); + }); -TVM_REGISTER_GLOBAL("tir.Max").set_body_typed(MaxNode::make); +TVM_REGISTER_NODE_TYPE(ShuffleNode); -TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed(EQNode::make); +// CommReducer +CommReducer::CommReducer(Array lhs, Array rhs, Array result, + Array identity_element) { + auto node = make_object(); + node->lhs = lhs; + node->rhs = rhs; + node->result = result; + node->identity_element = identity_element; + data_ = std::move(node); +} -TVM_REGISTER_GLOBAL("tir.NE").set_body_typed(NENode::make); +Array CommReducerNode::operator()(Array a, Array b) const { + CHECK_EQ(a.size(), b.size()); + CHECK_EQ(lhs.size(), a.size()); + CHECK_EQ(rhs.size(), b.size()); + Map value_map; + for (size_t i = 0; i < a.size(); ++i) { + value_map.Set(lhs[i], a[i]); + value_map.Set(rhs[i], b[i]); + } + auto ret = this->result; + ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); + return ret; +} -TVM_REGISTER_GLOBAL("tir.LT").set_body_typed(LTNode::make); +TVM_REGISTER_GLOBAL("tir.CommReducer") + .set_body_typed([](Array lhs, Array rhs, Array result, + Array identity_element) { + return CommReducer(lhs, rhs, result, identity_element); + }); -TVM_REGISTER_GLOBAL("tir.LE").set_body_typed(LENode::make); +TVM_REGISTER_GLOBAL("tir.CommReducerCombine") + .set_body_method(&tir::CommReducerNode::operator()); -TVM_REGISTER_GLOBAL("tir.GT").set_body_typed(GTNode::make); +TVM_REGISTER_NODE_TYPE(CommReducerNode); -TVM_REGISTER_GLOBAL("tir.GE").set_body_typed(GENode::make); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs + << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")"; + }); -TVM_REGISTER_GLOBAL("tir.And").set_body_typed(AndNode::make); +// Reduce +Reduce::Reduce(CommReducer combiner, Array source, Array axis, + PrimExpr condition, int value_index) { + for (size_t i = 0; i < axis.size(); ++i) { + CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; + } + if (!condition.defined()) { + condition = const_true(); + } + auto n = make_object(); + CHECK(source.defined()); + for (size_t i = 0; i < axis.size(); ++i) { + CHECK(axis[i].defined()); + } + n->dtype = source[value_index].dtype(); + n->combiner = std::move(combiner); + n->source = std::move(source); + n->axis = std::move(axis); + n->condition = condition; + n->value_index = value_index; + data_ = std::move(n); +} -TVM_REGISTER_GLOBAL("tir.Or").set_body_typed(OrNode::make); +TVM_REGISTER_GLOBAL("tir.Reduce") + .set_body_typed([](CommReducer combiner, Array source, Array axis, + PrimExpr condition, int value_index) { + return Reduce(combiner, source, axis, condition, value_index); + }); -TVM_REGISTER_GLOBAL("tir.Not").set_body_typed(NotNode::make); +TVM_REGISTER_NODE_TYPE(ReduceNode); -TVM_REGISTER_GLOBAL("tir.Select").set_body_typed(SelectNode::make); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "reduce(combiner=" << op->combiner; + p->stream << ", source=" << op->source; + p->stream << ", axis=" << op->axis; + p->stream << ", where=" << op->condition; + p->stream << ", value_index=" << op->value_index; + p->stream << ")"; + }); -TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed(RampNode::make); +// Any +Any::Any() { data_ = make_object(); } -TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed(CastNode::make); +TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([]() { return Any(); }); -TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed(BroadcastNode::make); +TVM_REGISTER_NODE_TYPE(AnyNode); -TVM_REGISTER_GLOBAL("tir.Shuffle").set_body_typed(ShuffleNode::make); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); -TVM_REGISTER_GLOBAL("tir.Let").set_body_typed(LetNode::make); +// BufferLoad +BufferLoad::BufferLoad(Buffer buffer, Array indices) { + ObjectPtr node = make_object(); + node->dtype = buffer->dtype; + node->buffer = std::move(buffer); + node->indices = std::move(indices); + data_ = std::move(node); +} -TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { - DataType t = args[0]; - if (args.size() == 3) { - *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); - } else { - *ret = LoadNode::make(t, args[1], args[2], args[3]); - } +TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { + return BufferLoad(buffer, indices); }); -TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, std::string name, Array args, int call_type) { - Array prim_expr_args; - for (const auto& it : args) { - CHECK(it->IsInstance() || it->IsInstance()); - if (const auto* str = it.as()) { - prim_expr_args.push_back(StringImmNode::make(str->data)); - } else { - prim_expr_args.push_back(Downcast(it)); +TVM_REGISTER_NODE_TYPE(BufferLoadNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; } } - return CallNode::make(type, name, prim_expr_args, static_cast(call_type)); + p->stream << "]"; + }); + +// ProducerLoad +ProducerLoad::ProducerLoad(DataProducer producer, Array indices) { + ObjectPtr node = make_object(); + node->dtype = producer->GetDataType(); + node->producer = std::move(producer); + node->indices = std::move(indices); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.ProducerLoad") + .set_body_typed([](DataProducer producer, Array indices) { + return ProducerLoad(producer, indices); }); +TVM_REGISTER_NODE_TYPE(ProducerLoadNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 98d61a0..b92127b 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -125,7 +125,7 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { if (index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { - return LoadNode::make(op->dtype, op->buffer_var, index, predicate); + return Load(op->dtype, op->buffer_var, index, predicate); } } @@ -155,7 +155,7 @@ PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } @@ -166,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, args, op->call_type); + return Call(op->dtype, op->name, args, op->call_type); } } @@ -177,34 +177,34 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) -#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return OP::make(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP##Node* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return OP(a, b); \ + } \ } -DEFINE_BIOP_EXPR_MUTATE_(AddNode); -DEFINE_BIOP_EXPR_MUTATE_(SubNode); -DEFINE_BIOP_EXPR_MUTATE_(MulNode); -DEFINE_BIOP_EXPR_MUTATE_(DivNode); -DEFINE_BIOP_EXPR_MUTATE_(ModNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); -DEFINE_BIOP_EXPR_MUTATE_(MinNode); -DEFINE_BIOP_EXPR_MUTATE_(MaxNode); -DEFINE_BIOP_EXPR_MUTATE_(EQNode); -DEFINE_BIOP_EXPR_MUTATE_(NENode); -DEFINE_BIOP_EXPR_MUTATE_(LTNode); -DEFINE_BIOP_EXPR_MUTATE_(LENode); -DEFINE_BIOP_EXPR_MUTATE_(GTNode); -DEFINE_BIOP_EXPR_MUTATE_(GENode); -DEFINE_BIOP_EXPR_MUTATE_(AndNode); -DEFINE_BIOP_EXPR_MUTATE_(OrNode); +DEFINE_BIOP_EXPR_MUTATE_(Add); +DEFINE_BIOP_EXPR_MUTATE_(Sub); +DEFINE_BIOP_EXPR_MUTATE_(Mul); +DEFINE_BIOP_EXPR_MUTATE_(Div); +DEFINE_BIOP_EXPR_MUTATE_(Mod); +DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); +DEFINE_BIOP_EXPR_MUTATE_(FloorMod); +DEFINE_BIOP_EXPR_MUTATE_(Min); +DEFINE_BIOP_EXPR_MUTATE_(Max); +DEFINE_BIOP_EXPR_MUTATE_(EQ); +DEFINE_BIOP_EXPR_MUTATE_(NE); +DEFINE_BIOP_EXPR_MUTATE_(LT); +DEFINE_BIOP_EXPR_MUTATE_(LE); +DEFINE_BIOP_EXPR_MUTATE_(GT); +DEFINE_BIOP_EXPR_MUTATE_(GE); +DEFINE_BIOP_EXPR_MUTATE_(And); +DEFINE_BIOP_EXPR_MUTATE_(Or); PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { auto fitervar = [this](const IterVar& v) { @@ -214,8 +214,7 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { if (min.same_as(r->min) && extent.same_as(r->extent)) { return v; } else { - return IterVarNode::make(Range::make_by_min_extent(min, extent), v->var, v->iter_type, - v->thread_tag); + return IterVar(Range::make_by_min_extent(min, extent), v->var, v->iter_type, v->thread_tag); } }; Array axis = MutateArray(op->axis, fitervar); @@ -228,7 +227,7 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) { return GetRef(op); } else { - return ReduceNode::make(op->combiner, source, axis, condition, op->value_index); + return Reduce(op->combiner, source, axis, condition, op->value_index); } } @@ -237,7 +236,7 @@ PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return CastNode::make(op->dtype, value); + return Cast(op->dtype, value); } } @@ -246,7 +245,7 @@ PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { if (a.same_as(op->a)) { return GetRef(op); } else { - return NotNode::make(a); + return Not(a); } } @@ -258,7 +257,7 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { false_value.same_as(op->false_value)) { return GetRef(op); } else { - return SelectNode::make(condition, true_value, false_value); + return Select(condition, true_value, false_value); } } @@ -268,7 +267,7 @@ PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { - return RampNode::make(base, stride, op->lanes); + return Ramp(base, stride, op->lanes); } } @@ -277,7 +276,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return BroadcastNode::make(value, op->lanes); + return Broadcast(value, op->lanes); } } @@ -287,7 +286,7 @@ PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { if (vectors.same_as(op->vectors)) { return GetRef(op); } else { - return ShuffleNode::make(vectors, op->indices); + return Shuffle(vectors, op->indices); } } diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index 2757c2f..5ac9f59 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -67,14 +67,13 @@ Type GetType(const PrimExpr& expr) { // simple cast that only checks if type matches and cast inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::CastNode::make(t, value); + return tir::Cast(t, value); } PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { - return tir::CallNode::make( - t, tir::intrinsic::tvm_large_uint_imm, - {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, - tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::intrinsic::tvm_large_uint_imm, + {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, + tir::CallNode::PureIntrinsic); } // The public function with a quick checking path. @@ -83,9 +82,9 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); if (ltype.lanes() == 1 && rtype.lanes() != 1) { - lhs = tir::BroadcastNode::make(lhs, rtype.lanes()); + lhs = tir::Broadcast(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { - rhs = tir::BroadcastNode::make(rhs, ltype.lanes()); + rhs = tir::Broadcast(rhs, ltype.lanes()); } else { CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } @@ -227,7 +226,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } - return tir::CastNode::make(t, value); + return tir::Cast(t, value); } else { if (value.dtype().lanes() == 1) { // manually unroll cast @@ -238,27 +237,27 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { - value = tir::CastNode::make(vtype, value); + value = tir::Cast(vtype, value); } } - return tir::BroadcastNode::make(value, t.lanes()); + return tir::Broadcast(value, t.lanes()); } else { CHECK(value.dtype().lanes() == t.lanes()); - return tir::CastNode::make(t, value); + return tir::Cast(t, value); } } } PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::CallNode::make(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); } PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::AddNode::make(a, b); + return tir::Add(a, b); } // negation @@ -274,23 +273,23 @@ PrimExpr operator-(PrimExpr a) { PrimExpr operator-(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::SubNode::make(a, b); + return tir::Sub(a, b); } PrimExpr operator*(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::MulNode::make(a, b); + return tir::Mul(a, b); } PrimExpr div(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::DivNode::make(a, b); + return tir::Div(a, b); } PrimExpr truncdiv(PrimExpr a, PrimExpr b) { @@ -301,9 +300,9 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b) { PrimExpr truncmod(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::ModNode::make(a, b); + return tir::Mod(a, b); } PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } @@ -319,18 +318,18 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::FloorDivNode::make(a, b); + return tir::FloorDiv(a, b); } PrimExpr floormod(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::FloorModNode::make(a, b); + return tir::FloorMod(a, b); } PrimExpr min(PrimExpr a, PrimExpr b) { @@ -342,9 +341,9 @@ PrimExpr min(PrimExpr a, PrimExpr b) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::MinNode::make(a, b); + return tir::Min(a, b); } PrimExpr max(PrimExpr a, PrimExpr b) { @@ -356,9 +355,9 @@ PrimExpr max(PrimExpr a, PrimExpr b) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::MaxNode::make(a, b); + return tir::Max(a, b); } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { @@ -372,79 +371,78 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) return false_value; } } - return tir::CallNode::make(true_value.dtype(), tir::intrinsic::tvm_if_then_else, - {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); + return tir::Call(true_value.dtype(), tir::intrinsic::tvm_if_then_else, + {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::CallNode::make(cond.dtype(), tir::CallNode::likely, {cond}, - tir::CallNode::PureIntrinsic); + return tir::Call(cond.dtype(), tir::CallNode::likely, {cond}, tir::CallNode::PureIntrinsic); } PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::GTNode::make(a, b); + return tir::GT(a, b); } PrimExpr operator>=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::GENode::make(a, b); + return tir::GE(a, b); } PrimExpr operator<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::LTNode::make(a, b); + return tir::LT(a, b); } PrimExpr operator<=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::LENode::make(a, b); + return tir::LE(a, b); } PrimExpr operator==(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::EQNode::make(a, b); + return tir::EQ(a, b); } PrimExpr operator!=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::NENode::make(a, b); + return tir::NE(a, b); } PrimExpr operator&&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::AndNode::make(a, b); + return tir::And(a, b); } PrimExpr operator||(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::OrNode::make(a, b); + return tir::Or(a, b); } PrimExpr operator!(PrimExpr a) { CHECK(a.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a); + PrimExpr ret = arith::TryConstFold(a); if (ret.defined()) return ret; - return tir::NotNode::make(a); + return tir::Not(a); } PrimExpr operator>>(PrimExpr a, PrimExpr b) { @@ -462,8 +460,7 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::CallNode::make(a.dtype(), tir::CallNode::shift_right, {a, b}, - tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::shift_right, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator<<(PrimExpr a, PrimExpr b) { @@ -481,8 +478,7 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::CallNode::make(a.dtype(), tir::CallNode::shift_left, {a, b}, - tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::shift_left, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator&(PrimExpr a, PrimExpr b) { @@ -493,8 +489,7 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); - return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_and, {a, b}, - tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_and, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator|(PrimExpr a, PrimExpr b) { @@ -505,8 +500,7 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); - return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_or, {a, b}, - tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_or, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator^(PrimExpr a, PrimExpr b) { @@ -517,20 +511,18 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); - return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, - tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::CallNode::make(a.dtype(), tir::CallNode::bitwise_not, {a}, - tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_not, {a}, tir::CallNode::PureIntrinsic); } PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return tir::CallNode::make(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); } PrimExpr abs(PrimExpr x) { @@ -540,14 +532,14 @@ PrimExpr abs(PrimExpr x) { if (px) { return IntImm(x.dtype(), std::abs(px->value)); } - return tir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); + return tir::Select(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } - return tir::CallNode::make(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { @@ -568,11 +560,11 @@ PrimExpr isnan(PrimExpr x) { return make_const(t, std::isnan(fx->value)); } if (x.dtype().bits() == 16) { - return tir::CallNode::make(t, tir::CallNode::isnan, - {cast(DataType::Float(32, t.lanes()), std::move(x))}, - tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::isnan, + {cast(DataType::Float(32, t.lanes()), std::move(x))}, + tir::CallNode::PureIntrinsic); } else { - return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; @@ -597,58 +589,58 @@ PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); } PrimExpr sum(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::AddNode::make(x, y); + PrimExpr result = tir::Add(x, y); PrimExpr identity_element = make_zero(source.dtype()); - tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr all(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::AndNode::make(x, y); + PrimExpr result = tir::And(x, y); PrimExpr identity_element = make_const(source.dtype(), true); - tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr any(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::OrNode::make(x, y); + PrimExpr result = tir::Or(x, y); PrimExpr identity_element = make_const(source.dtype(), false); - tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr max(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::MaxNode::make(x, y); + PrimExpr result = tir::Max(x, y); PrimExpr identity_element = min_value(source.dtype()); - tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr min(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::MinNode::make(x, y); + PrimExpr result = tir::Min(x, y); PrimExpr identity_element = max_value(source.dtype()); - tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr prod(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::MulNode::make(x, y); + PrimExpr result = tir::Mul(x, y); PrimExpr identity_element = make_const(source.dtype(), 1); - tir::CommReducer combiner = tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return tir::CallNode::make(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); } PrimExpr floor(PrimExpr x) { @@ -658,7 +650,7 @@ PrimExpr floor(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); - return tir::CallNode::make(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); } PrimExpr ceil(PrimExpr x) { @@ -668,7 +660,7 @@ PrimExpr ceil(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); - return tir::CallNode::make(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); } PrimExpr round(PrimExpr x) { @@ -678,7 +670,7 @@ PrimExpr round(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::CallNode::make(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); } PrimExpr nearbyint(PrimExpr x) { @@ -688,7 +680,7 @@ PrimExpr nearbyint(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::CallNode::make(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); } PrimExpr trunc(PrimExpr x) { @@ -700,7 +692,7 @@ PrimExpr trunc(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } - return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); } // expose basic functions to node namespace @@ -716,8 +708,6 @@ TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); -TVM_REGISTER_GLOBAL("node.String").set_body_typed(tir::StringImmNode::make); - TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 094abc3..46c4b09 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -67,7 +67,7 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { TVM_REGISTER_GLOBAL("tir.AssertStmt") .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { if (const auto* str = message.as()) { - auto msg = StringImmNode::make(str->data); + auto msg = StringImm(str->data); return AssertStmtNode::make(condition, msg, body); } else { return AssertStmtNode::make(condition, Downcast(message), body); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 6d0a60f..06958a2 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -488,7 +488,7 @@ class IRSubstitue : public StmtExprMutator { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { - return LoadNode::make(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); + return Load(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); } else { return ret; } diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 384d459..14452a6 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -42,8 +42,8 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg if (!is_one(scond)) { std::ostringstream os; os << "Argument " << arg_name << " has an unsatisfied constraint"; - asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()), - EvaluateNode::make(0))); + asserts->emplace_back( + AssertStmtNode::make(scond, tvm::tir::StringImm(os.str()), EvaluateNode::make(0))); } } @@ -157,7 +157,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); - auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str()); + auto msg = tvm::tir::StringImm(ndim_err_msg.str()); asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); // type checks DataType dtype = buffer->dtype; @@ -170,7 +170,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == IntImm(DataType::UInt(16), dtype.lanes())); if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { - auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str()); + auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop)); } @@ -195,19 +195,18 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } std::ostringstream field_name; field_name << v_shape->name_hint << '[' << k << ']'; - Bind_( - buffer->shape[k], - cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), - field_name.str(), true); + Bind_(buffer->shape[k], + cast(buffer->shape[k].dtype(), + Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), + field_name.str(), true); } // strides field Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back(LetStmtNode::make( v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); - PrimExpr is_null = CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, - CallNode::PureIntrinsic); + PrimExpr is_null = + Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -215,8 +214,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = cast(stype, LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr svalue = + cast(stype, Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -224,11 +223,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, stride_err_msg << arg_name << ".strides:" << " expected to be compact array"; if (conds.size() != 0) { - auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str()); + auto stride_msg = tvm::tir::StringImm(stride_err_msg.str()); auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg, EvaluateNode::make(0)); - check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); + check = IfThenElseNode::make(Not(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { @@ -238,9 +237,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, size_t k = i - 1; std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; - PrimExpr value = cast( - buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr value = + cast(buffer->shape[k].dtype(), + Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); @@ -249,16 +248,15 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } else { std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; - asserts_.emplace_back(AssertStmtNode::make( - NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop)); + asserts_.emplace_back( + AssertStmtNode::make(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop)); for (size_t k = 0; k < buffer->strides.size(); ++k) { std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), - const_true(1))), + Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } } diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 2e1e5b9..55a8131 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -87,7 +87,7 @@ class BoundChecker : public StmtExprMutator { if (!condition.as()) { Stmt nop = EvaluateNode::make(1); Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); - Stmt else_case = AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop); + Stmt else_case = AssertStmtNode::make(condition, StringImm(error_message_), nop); Stmt body = IfThenElseNode::make(condition, then_case, else_case); return body; } @@ -121,12 +121,12 @@ class BoundChecker : public StmtExprMutator { } // Scalarize the shape. - PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[0])); + PrimExpr shape = + Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overlow at frist. - shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[i]))); + shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()), + Cast(DataType::UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } @@ -163,8 +163,7 @@ class BoundChecker : public StmtExprMutator { if (const RampNode* ramp_index = index.as()) { // In case index is base + stride * i. // Non inclusive range. - index = AddNode::make(ramp_index->base, MulNode::make(ramp_index->stride, - make_const(ramp_index->stride.dtype(), + index = Add(ramp_index->base, Mul(ramp_index->stride, make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); } @@ -173,15 +172,14 @@ class BoundChecker : public StmtExprMutator { upper_bound = analyzer_.Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. - index = CastNode::make(DataType::Int(64), index); - upper_bound = CastNode::make(DataType::Int(64), upper_bound); + index = Cast(DataType::Int(64), index); + upper_bound = Cast(DataType::Int(64), upper_bound); // Looks like a lower bound should always be zero after normalization. PrimExpr lower_bound = make_zero(DataType::Int(64)); - PrimExpr current_condition = - AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound)); - condition = !i ? current_condition : AndNode::make(condition, current_condition); + PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); + condition = !i ? current_condition : And(condition, current_condition); } return condition; } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 41e1124..3072c0d 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -195,8 +195,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return { - EvaluateNode::make(CallNode::make(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; + return {EvaluateNode::make(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -332,9 +331,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor { CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; - return EvaluateNode::make(CallNode::make( - DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, - CallNode::Intrinsic)); + return EvaluateNode::make(Call(DataType::Int(32), func, + {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, + CallNode::Intrinsic)); } // Write barrier name bool read_barrier_{false}; @@ -557,15 +556,15 @@ class CoProcInstDepDetector : public StmtVisitor { Stmt MakePush(int from, int to) { return EvaluateNode::make( - CallNode::make(DataType::Int(32), sync_push_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + Call(DataType::Int(32), sync_push_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { return EvaluateNode::make( - CallNode::make(DataType::Int(32), sync_pop_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + Call(DataType::Int(32), sync_pop_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } // sync states. SyncState first_state_, last_state_, curr_state_; diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 5af1a39..c201b8f 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -120,7 +120,7 @@ class CopyIntrinInjector : public StmtMutator { DataType t = loop_vars[i].dtype(); PrimExpr svalue = src_shape[i]; if (min_value.defined()) { - PrimExpr pbefore = analyzer_.Simplify(MaxNode::make(min_value, make_zero(t))); + PrimExpr pbefore = analyzer_.Simplify(Max(min_value, make_zero(t))); src_elem_offset = src_elem_offset + pbefore * load_strides[i]; svalue = svalue - pbefore; pad_before.push_back(pbefore); diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index c405b1f..3f53022 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -125,9 +125,8 @@ class DoubleBufferInjector : public StmtExprMutator { } CHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back(AttrStmtNode::make(op->buffer_var, attr::storage_scope, - StringImmNode::make(it->second.scope), - EvaluateNode::make(0))); + alloc_nest.emplace_back(AttrStmtNode::make( + op->buffer_var, attr::storage_scope, StringImm(it->second.scope), EvaluateNode::make(0))); alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents, op->condition, EvaluateNode::make(0))); return op->body; @@ -205,8 +204,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(e.stride.defined()); CHECK(e.switch_read_var.defined()); - return LoadNode::make(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, - op->predicate); + return Load(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, + op->predicate); } else { return expr; } diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 6528e97..f9088e3 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -212,8 +212,7 @@ class VTInjector : public StmtExprMutator { } auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return LoadNode::make(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), - op->predicate); + return Load(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate); } else { return expr; } @@ -231,8 +230,8 @@ class VTInjector : public StmtExprMutator { PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return CallNode::make(op->dtype, op->name, - {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); + return Call(op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]}, + op->call_type); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { return allow_share_ ? GetRef(op) : var_; } else { diff --git a/src/tir/transforms/ir_util.cc b/src/tir/transforms/ir_util.cc index ff3e941..28f347e 100644 --- a/src/tir/transforms/ir_util.cc +++ b/src/tir/transforms/ir_util.cc @@ -103,7 +103,7 @@ class IRConvertSSA final : public StmtExprMutator { scope_[v.get()].push_back(new_var); PrimExpr body = this->VisitExpr(op->body); scope_[v.get()].pop_back(); - return LetNode::make(new_var, value, body); + return Let(new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitExpr_(op); @@ -113,8 +113,7 @@ class IRConvertSSA final : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { - return LoadNode::make(op->dtype, scope_[op->buffer_var.get()].back(), op->index, - op->predicate); + return Load(op->dtype, scope_[op->buffer_var.get()].back(), op->index, op->predicate); } else { return expr; } diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 69b5a39..4fbd2a0 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -86,7 +86,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, intrinsic::TVMStructFieldKind kind) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; - return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); + return Call(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); } /*! @@ -96,11 +96,10 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return CallNode::make( - DataType::Handle(), intrinsic::tvm_address_of, - {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + return Call(DataType::Handle(), intrinsic::tvm_address_of, + {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), + const_true(dtype.lanes()))}, + CallNode::PureIntrinsic); } /*! @@ -112,11 +111,10 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { if (dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); - offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return CallNode::make(DataType::Handle(), intrinsic::tvm_address_of, - {LoadNode::make(dtype, handle, offset, const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + return Call(DataType::Handle(), intrinsic::tvm_address_of, + {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } /*! @@ -132,7 +130,7 @@ inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind ki Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); + Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } /*! diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index c72928b..b06bb8a 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -246,22 +246,22 @@ class PartitionFinder : public StmtExprVisitor { PrimExpr inverse_cond; if (const LTNode* op = cond.as()) { // a < b -> a >= b - inverse_cond = GENode::make(op->a, op->b); + inverse_cond = GE(op->a, op->b); } else if (const GTNode* op = cond.as()) { // a > b -> a <= b - inverse_cond = LENode::make(op->a, op->b); + inverse_cond = LE(op->a, op->b); } else if (const LENode* op = cond.as()) { // a <= b -> a > b - inverse_cond = GTNode::make(op->a, op->b); + inverse_cond = GT(op->a, op->b); } else if (const GENode* op = cond.as()) { // a >= b -> a < b - inverse_cond = LTNode::make(op->a, op->b); + inverse_cond = LT(op->a, op->b); } else if (const EQNode* op = cond.as()) { // a == b -> a != b - inverse_cond = NENode::make(op->a, op->b); + inverse_cond = NE(op->a, op->b); // a != b -> a == b } else if (const NENode* op = cond.as()) { - inverse_cond = EQNode::make(op->a, op->b); + inverse_cond = EQ(op->a, op->b); } return inverse_cond; } @@ -509,7 +509,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; - body_begin = MaxNode::make(body_begin, min); + body_begin = Max(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; } @@ -534,7 +534,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var PrimExpr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; - post_doubt_begin = MinNode::make(post_doubt_begin, max + 1); + post_doubt_begin = Min(post_doubt_begin, max + 1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 92b463c..4a15501 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -92,7 +92,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { load = expr.as(); if (toBeLowered) { auto new_load_type = DataType::UInt(load->dtype.bits()); - return LoadNode::make(new_load_type, load->buffer_var, load->index, load->predicate); + return Load(new_load_type, load->buffer_var, load->index, load->predicate); } return expr; } diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 7df8fd2..c7aa949 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -98,7 +98,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return tir::SelectNode::make(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); + return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -108,8 +108,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) PrimExpr rdiv = truncdiv(op->a, op->b); PrimExpr rmod = truncmod(op->a, op->b); - return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, - rdiv - make_const(dtype, 1)); + return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1)); } } @@ -144,7 +144,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // -> rmod >= 0 ? 0 : b return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - return tir::SelectNode::make(rmod >= 0, rmod, rmod + op->b); + return tir::Select(rmod >= 0, rmod, rmod + op->b); } } } else { @@ -155,8 +155,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return tir::SelectNode::make((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, - rmod + op->b); + return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b); } } @@ -217,8 +216,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { }; if (should_swap()) { - PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes); - return CastNode::make(bcast->dtype, new_bcast); + PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes); + return Cast(bcast->dtype, new_bcast); } } } @@ -231,13 +230,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = - (*fma_)(CallNode::make(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = (*fma_)(Call(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { - PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs)); - return AddNode::make(mul, this->VisitExpr(c)); + PrimExpr mul = this->VisitExpr(Mul(lhs, rhs)); + return Add(mul, this->VisitExpr(c)); } } return IRMutatorWithAnalyzer::VisitExpr_(op); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 7f9a329..f6daabd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -86,15 +86,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (warp_allocs_.count(repl)) { stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, - StringImmNode::make("local"), stmt); + stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); } else { // use volatile access to shared buffer. stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body); stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, - StringImmNode::make("shared"), stmt); + stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); } return stmt; } else { @@ -139,7 +137,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { values[idx] = call->args[1 + idx]; if (!is_one(cond)) { - values[idx] = SelectNode::make(cond, values[idx], inits[idx]); + values[idx] = Select(cond, values[idx], inits[idx]); } types[idx] = values[idx].dtype(); } @@ -232,8 +230,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var mask_var("mask", DataType::UInt(32)); { PrimExpr pred = const_true(1); - PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, - CallNode::Intrinsic); + PrimExpr mask = + Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); seq.emplace_back(StoreNode::make(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. @@ -249,7 +247,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - PrimExpr val = LoadNode::make(types[i], var, index, pred); + PrimExpr val = Load(types[i], var, index, pred); a.push_back(val); // __shfl_*sync calls shall not appear in if_then_else expressions @@ -271,7 +269,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt s = StoreNode::make(repl->buffer_var, other, index, pred); seq.push_back(s); - PrimExpr load = LoadNode::make(types[i], repl->buffer_var, index, pred); + PrimExpr load = Load(types[i], repl->buffer_var, index, pred); b.push_back(load); } @@ -296,7 +294,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); const char* shfl_func = intrinsic::tvm_warp_shuffle; - PrimExpr val = LoadNode::make(types[i], var, index, pred); + PrimExpr val = Load(types[i], var, index, pred); PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); seq.push_back(StoreNode::make(var, splat, index, pred)); } @@ -306,7 +304,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { CHECK(!load_remap_.count(buffers[i])); PrimExpr pred = const_true(types[i].lanes()); Var var = shared_bufs[i]; - load_remap_[buffers[i]] = LoadNode::make(types[i], var, index, pred); + load_remap_[buffers[i]] = Load(types[i], var, index, pred); Array extents{PrimExpr(1)}; auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0)); alloc_remap_[buffers[i]] = node; @@ -343,9 +341,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { CHECK(!load_remap_.count(buffers[idx])); PrimExpr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = LoadNode::make( - types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); + load_remap_[buffers[idx]] = + Load(types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); alloc_remap_[buffers[idx]] = AllocateNode::make( shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, EvaluateNode::make(0)); @@ -359,8 +357,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (repl) { body = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, - StringImmNode::make("local"), body); + body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), body); } } @@ -385,10 +382,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto freduce = [&](int offset) { Array a, b; for (size_t i = 0; i < size; ++i) { - b.push_back(LoadNode::make(types[i], shared_bufs[i], - BufIndex(reduce_index + offset, group_index, reduce_extent), - const_true())); - a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true())); + b.push_back(Load(types[i], shared_bufs[i], + BufIndex(reduce_index + offset, group_index, reduce_extent), + const_true())); + a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true())); } Array ret = (*combiner)(a, b); std::vector stores(size); @@ -459,18 +456,18 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync)}, CallNode::Intrinsic)); + return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync)}, CallNode::Intrinsic)); } // Emit warp shuffle intrinsic calls. PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { PrimExpr pred = const_true(1); PrimExpr index(0); - PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred); + PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; - return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic); + return Call(val.dtype(), name, args, CallNode::Intrinsic); } // Check if this is a reduction on threadIdx.x and its extent matches diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 88c4363..0e52802 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -39,8 +39,8 @@ inline PrimExpr ConstInt32(size_t index) { } inline PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {StringImmNode::make(type), ConstInt32(num)}; - return CallNode::make(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); + Array args = {StringImm(type), ConstInt32(num)}; + return Call(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -102,29 +102,27 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); + Stmt throw_last_error = EvaluateNode::make( + Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); - Stmt body = SeqStmt( - {IfThenElseNode::make(CallNode::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), - throw_last_error), - op->body}); + Stmt body = SeqStmt({IfThenElseNode::make(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), + throw_last_error), + op->body}); Stmt alloca = LetStmtNode::make( op->buffer_var, - CallNode::make( - op->buffer_var.dtype(), "TVMBackendAllocWorkspace", - {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), - cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}, - CallNode::Extern), + Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", + {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), + cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, + CallNode::Extern), body); - PrimExpr free_op = CallNode::make(DataType::Int(32), "TVMBackendFreeWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), op->buffer_var}, - CallNode::Extern); + PrimExpr free_op = Call(DataType::Int(32), "TVMBackendFreeWorkspace", + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), op->buffer_var}, + CallNode::Extern); Stmt free_stmt = IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); @@ -226,7 +224,7 @@ class BuiltinLower : public StmtExprMutator { DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = CastNode::make(api_type, arg); + arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), intrinsic::kTVMValueContent, arg)); @@ -248,8 +246,8 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], stack_value_, stack_tcode_, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, - CallNode::Intrinsic); + return Call(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, + CallNode::Intrinsic); } PrimExpr MakeCallTracePacked(const CallNode* op) { @@ -267,7 +265,7 @@ class BuiltinLower : public StmtExprMutator { DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = CastNode::make(api_type, arg); + arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), intrinsic::kTVMValueContent, arg)); @@ -290,8 +288,8 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. op->args[args_size - 1]}; - return CallNode::make(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, - CallNode::Intrinsic); + return Call(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, + CallNode::Intrinsic); } private: diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index fb86bc2..7294c01 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -249,11 +249,11 @@ class WarpAccessRewriter : protected StmtExprMutator { CHECK(!ExprUseVar(local_index, warp_index_)) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; - PrimExpr load_value = LoadNode::make(op->dtype, op->buffer_var, local_index, op->predicate); - PrimExpr mask = CallNode::make(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, - CallNode::Intrinsic); - return CallNode::make(load_value.dtype(), intrinsic::tvm_warp_shuffle, - {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); + PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); + PrimExpr mask = + Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + return Call(load_value.dtype(), intrinsic::tvm_warp_shuffle, + {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); } @@ -271,8 +271,7 @@ class WarpAccessRewriter : protected StmtExprMutator { CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); std::tie(local_index, group) = SplitIndexByGroup(base.Eval()); - local_index = - RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); + local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } PrimExpr m = make_const(index.dtype(), warp_coeff_); @@ -374,7 +373,7 @@ class WarpMemoryRewriter : private StmtMutator { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); - return AttrStmtNode::make(op->node, op->attr_key, StringImmNode::make("local"), op->body); + return AttrStmtNode::make(op->node, op->attr_key, StringImm("local"), op->body); } } return StmtMutator::VisitStmt_(op); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b6314ad..0fdfb85 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -41,8 +41,7 @@ namespace tvm { namespace tir { inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { - return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg), - EvaluateNode::make(0)); + return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImm(msg), EvaluateNode::make(0)); } PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { @@ -86,11 +85,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - PrimExpr res = - CallNode::make(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); + PrimExpr res = Call(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { - res = CastNode::make(t, res); + res = Cast(t, res); } return res; }; @@ -127,11 +125,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back( - LetStmtNode::make(tcode, - LoadNode::make(DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back(LetStmtNode::make(tcode, + Load(DataType::Int(32), v_packed_arg_type_ids, + IntImm(DataType::Int(32), i), const_true(1)), + nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; @@ -139,18 +136,18 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { seq_check.emplace_back( AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, - tvm::tir::StringImmNode::make(msg.str()), nop)); + tvm::tir::StringImm(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop)); + AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); } else { CHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop)); + AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); } } else { args.push_back(v_arg); @@ -186,18 +183,18 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope, - StringImmNode::make(name_hint + "_compute_"), func_ptr->body); + StringImm(name_hint + "_compute_"), func_ptr->body); // Set device context if (vmap.count(device_id.get())) { - PrimExpr node = StringImmNode::make("default"); + PrimExpr node = StringImm("default"); seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop)); seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop)); if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { - Stmt set_device = EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - {StringImmNode::make(runtime::symbol::tvm_set_device), device_type, device_id}, - CallNode::Intrinsic)); + Stmt set_device = EvaluateNode::make( + Call(DataType::Int(32), intrinsic::tvm_call_packed, + {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, + CallNode::Intrinsic)); body = SeqStmt({set_device, body}); } } diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index ad86e45..af2886e 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -235,7 +235,7 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { - ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); + ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag); } return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); } @@ -266,7 +266,7 @@ class DataTypeRewriter : public StmtExprMutator { is_index_ = true; PrimExpr index = this->VisitExpr(op->index); is_index_ = false; - PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate); + PrimExpr e = Load(op->dtype, op->buffer_var, index, op->predicate); return StmtExprMutator::VisitExpr_(e.as()); } @@ -285,7 +285,7 @@ class DataTypeRewriter : public StmtExprMutator { const CastNode* new_op = e.as(); CHECK(new_op != nullptr) << "Expected type to be CastNode" << ", but get " << e->GetTypeKey(); - return CastNode::make(visitor_.vmap[op], new_op->value); + return Cast(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 149cda9..701f0ce 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -104,8 +104,8 @@ class UnsafeSelectRewriter : public StmtExprMutator { bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return CallNode::make(op->dtype, intrinsic::tvm_if_then_else, - {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); + return Call(op->dtype, intrinsic::tvm_if_then_else, + {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 98577a7..1806265 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -107,7 +107,7 @@ class VarUseDefAnalysis : public StmtExprMutator { if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } } @@ -230,15 +230,15 @@ class HostDeviceSplitter : public StmtMutator { // generate calls to the device function Array call_args; - call_args.push_back(StringImmNode::make(kernel_symbol)); + call_args.push_back(StringImm(kernel_symbol)); for (PrimExpr arg : arguments) { call_args.push_back(arg); } for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - return EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, - call_args, CallNode::Intrinsic)); + return EvaluateNode::make( + Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); } // target ir module diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 8f2a95b..21ddaaf 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -228,8 +228,8 @@ class StorageFlattener : public StmtExprMutator { ret = AllocateNode::make(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmtNode::make(e.buffer->data, attr::storage_scope, - StringImmNode::make(e.buffer->scope), ret); + ret = + AttrStmtNode::make(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound, @@ -246,7 +246,7 @@ class StorageFlattener : public StmtExprMutator { if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); - return LoadNode::make(op->dtype, buf_var, op->index, op->predicate); + return Load(op->dtype, buf_var, op->index, op->predicate); } else { return expr; } @@ -324,9 +324,9 @@ class StorageFlattener : public StmtExprMutator { } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); PrimExpr address = - CallNode::make(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); - PrimExpr prefetch = CallNode::make(op->buffer->dtype, CallNode::prefetch, - {address, 0, 3, 1}, CallNode::Intrinsic); + Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); + PrimExpr prefetch = + Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); stmt = EvaluateNode::make(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -481,10 +481,9 @@ class StorageFlattener : public StmtExprMutator { PrimExpr MakeBound(const DataType& type, const Array& shape) { // We have already checked the shape size to be greater then 0. - PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); + PrimExpr bound = Mul(make_const(shape[0].dtype(), type.lanes()), shape[0]); for (size_t i = 1; i < shape.size(); ++i) { - bound = - MulNode::make(bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); + bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); } return bound; } diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 365ff75..952d273 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -351,7 +351,7 @@ class StoragePlanRewriter : public StmtExprMutator { // CHECK_EQ(e->scope.rank, 0); if (e->new_alloc.defined()) { nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), + StringImm(e->scope.to_string()), EvaluateNode::make(0))); nest.push_back(e->new_alloc); } @@ -373,8 +373,8 @@ class StoragePlanRewriter : public StmtExprMutator { op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; - return LoadNode::make(op->dtype, it->second->alloc_var, - RemapIndex(op->dtype, op->index, it->second), op->predicate); + return Load(op->dtype, it->second->alloc_var, RemapIndex(op->dtype, op->index, it->second), + op->predicate); } PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); @@ -404,9 +404,8 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return CallNode::make(op->dtype, op->name, - {op->args[0], se->alloc_var, offset, extent, op->args[4]}, - op->call_type); + return Call(op->dtype, op->name, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + op->call_type); } else { return StmtExprMutator::VisitExpr_(op); } @@ -500,7 +499,7 @@ class StoragePlanRewriter : public StmtExprMutator { for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), + StringImm(e->scope.to_string()), EvaluateNode::make(0))); nest.push_back(e->new_alloc); } diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 8650d2c..bd66fc0 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -187,12 +187,12 @@ class InferFragmenter : public StmtMutator { // Add shape attribute to all fragments std::string shape = std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); - PrimExpr shape_expr = StringImmNode::make(shape); + PrimExpr shape_expr = StringImm(shape); Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { // Add shape attribute to matrix_a and matrix_b Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout, - StringImmNode::make(info.layout), shape_attr); + StringImm(info.layout), shape_attr); return layout_attr; } else { return shape_attr; diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 4383ecf..266ada0 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -209,9 +209,9 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = EvaluateNode::make(CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string())}, - CallNode::Intrinsic)); + barrier = + EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); @@ -298,9 +298,9 @@ class ThreadSyncInserter : public StmtExprMutator { // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { CHECK(op != nullptr); - Array pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)}; + Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); + Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; @@ -309,8 +309,8 @@ class ThreadSyncInserter : public StmtExprMutator { } } rw_stats_.clear(); - Stmt kinit = EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); + Stmt kinit = EvaluateNode::make( + Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); body = AttrStmtNode::make(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); @@ -333,10 +333,9 @@ class ThreadSyncInserter : public StmtExprMutator { } else { CHECK_EQ(num_work_dim_, thread_extents_.size()); } - return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string()), is_lead_, num_blocks_}, - CallNode::Intrinsic)); + return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, + CallNode::Intrinsic)); } // data structure. StorageScope sync_scope_; diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index c0a546d..290a3a4 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -39,12 +39,12 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { if (e.dtype().lanes() == lanes) return e; if (const BroadcastNode* op = e.as()) { if (lanes % op->lanes == 0) { - return BroadcastNode::make(op->value, lanes); + return Broadcast(op->value, lanes); } } CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " << lanes; - return BroadcastNode::make(e, lanes); + return Broadcast(e, lanes); } // Rewrite vectorized allocation access @@ -64,8 +64,7 @@ class VecAllocAccess : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { - return LoadNode::make(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, - op->predicate); + return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate); } else { return expr; } @@ -94,7 +93,7 @@ class VecAllocAccess : public StmtExprMutator { class Vectorizer : public StmtExprMutator { public: Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { - ramp_ = RampNode::make(0, 1, var_lanes); + ramp_ = Ramp(0, 1, var_lanes); } Stmt VisitStmt(const Stmt& stmt) final { @@ -127,37 +126,37 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { - return RampNode::make(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); + return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { - return RampNode::make(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); + return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } } - return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } - return BinaryVec(op); + return BinaryVec(op); } - PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } - PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec
(op); } + PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const RampNode* base_ramp = base.as(); if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { - return RampNode::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); + return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); @@ -165,10 +164,10 @@ class Vectorizer : public StmtExprMutator { stride = BroadcastTo(stride, lanes); Array elems; for (int i = 0; i < lanes; ++i) { - elems.push_back(RampNode::make(ShuffleNode::make_extract_element(base, i), - ShuffleNode::make_extract_element(stride, i), op->lanes)); + elems.push_back( + Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes)); } - return ShuffleNode::make_concat(elems); + return Shuffle::Concat(elems); } PrimExpr VisitExpr_(const SelectNode* op) final { PrimExpr cond = this->VisitExpr(op->condition); @@ -178,7 +177,7 @@ class Vectorizer : public StmtExprMutator { return GetRef(op); } else { int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); - return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } PrimExpr VisitExpr_(const CastNode* op) final { @@ -186,7 +185,7 @@ class Vectorizer : public StmtExprMutator { if (value.same_as(op->value)) { return GetRef(op); } else { - return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value); + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); } } // Variable @@ -214,7 +213,7 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return CallNode::make(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); + return Call(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); } } // Call @@ -236,7 +235,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, new_args, op->call_type); + return Call(op->dtype, op->name, new_args, op->call_type); } } else { int lane = 0; @@ -245,7 +244,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); + return Call(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); } } } @@ -257,8 +256,8 @@ class Vectorizer : public StmtExprMutator { return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); - return LoadNode::make(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return Load(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // Let @@ -268,13 +267,13 @@ class Vectorizer : public StmtExprMutator { if (value.dtype().lanes() != op->value.dtype().lanes()) { Var v(op->var->name_hint, value.dtype()); lets_[op->var.get()] = v; - return LetNode::make(v, value, this->VisitExpr(op->body)); + return Let(v, value, this->VisitExpr(op->body)); } else { PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } } @@ -406,15 +405,16 @@ class Vectorizer : public StmtExprMutator { if (!changed) return arr; return Array(new_arr); } - template + template PrimExpr BinaryVec(const T* op) { + static_assert(std::is_same::value, "constraint"); PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } template @@ -429,15 +429,14 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a.dtype().lanes() == 1 && b_ramp) { - return RampNode::make(fcompute(a, b_ramp->base), - fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), - b_ramp->lanes); + return Ramp(fcompute(a, b_ramp->base), + fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } if (b.dtype().lanes() == 1 && a_ramp) { - return RampNode::make(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } - return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } }; diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 8e9d7bc..341d9f8 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -49,7 +49,7 @@ TEST(Simplify, Mod) { // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = ana.canonical_simplify(tvm::tir::ModNode::make(x, y)); + auto mod = ana.canonical_simplify(tvm::tir::Mod(x, y)); auto es = ana.canonical_simplify(mod - x); CHECK(tvm::tir::is_zero(es)); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 052cba1..b9f5b9c 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -192,8 +192,7 @@ TEST(IRF, StmtMutator) { } { - auto body = - EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = EvaluateNode::make(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 59d0a43..5063509 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -50,6 +50,7 @@ TEST(Pattern, Basic) { CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); + CHECK((px + min(py, px)).Match(z + min(y, z))); CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); @@ -65,7 +66,7 @@ TEST(Pattern, Basic) { CHECK((px >= py && px < pz).Match(x >= y && x < z)); CHECK((!(px > py || px != py)).Match(!(x > y || x != y))); { - CHECK(select(px >= pz, py, py + pz).Match(tir::SelectNode::make((x + 1) >= 1, y, y + 1))); + CHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1))); CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics @@ -81,13 +82,13 @@ TEST(Pattern, Basic) { CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); // select { - CHECK(select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 1, y, y + 1))); + CHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } - CHECK(!select(px > pz, py, py + pz).Match(tir::SelectNode::make(x > 2, y, y + 1))); - CHECK(!select(px > pz, py, py).Match(tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1))); { - CHECK(select(px, py, pz).Match(tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1))); CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else @@ -97,28 +98,26 @@ TEST(Pattern, Basic) { } // cast pattern { - CHECK(!cast(PConst(DataType::Int(32)), px) - .Match(tir::CastNode::make(DataType::Float(64), x))); - CHECK(cast(pt, px).Match(tir::CastNode::make(DataType::Float(64), x))); + CHECK(!cast(PConst(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x))); + CHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x))); CHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); CHECK((cast(pt, px) - cast(pt, py)) - .Match(tir::CastNode::make(DataType::Float(64), x) - - tir::CastNode::make(DataType::Int(64), x))); - auto expr = tir::CastNode::make(DataType::Int(32), tir::CastNode::make(DataType::Float(64), x)); + .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x))); + auto expr = tir::Cast(DataType::Int(32), tir::Cast(DataType::Float(64), x)); CHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern { - CHECK(ramp(px, PConst(1), planes).Match(tir::RampNode::make(x, 1, 10))); + CHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, 10))); CHECK(planes.Eval() == 10); - CHECK(!ramp(px, PConst(1), planes).Match(tir::RampNode::make(x, 2, 10))); + CHECK(!ramp(px, PConst(1), planes).Match(tir::Ramp(x, 2, 10))); } // broadcast pattern { - CHECK(broadcast(px, planes).Match(tir::BroadcastNode::make(x, 10))); + CHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10))); CHECK(planes.Eval() == 10); - CHECK(broadcast(px * py, planes).Match(tir::BroadcastNode::make(x * 10, 10))); + CHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10))); } } diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index c9c9f88..8823134 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -54,9 +54,9 @@ TEST(MicroStandaloneRuntime, BuildModule) { auto a = relay::VarNode::make("a", tensor_type); auto b = relay::VarNode::make("b", tensor_type); auto add_op = relay::Op::Get("add"); - auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); + auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {}); auto c = relay::VarNode::make("c", tensor_type); - auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); + auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {}); auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index 0c04eaf..30ad525 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -49,7 +49,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, b return make_extern( {{n, m}}, {lhs->dtype}, {lhs, rhs}, [&](Array ins, Array outs) { - return call_packed({StringImmNode::make("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), + return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, "C", "", {})[0]; @@ -71,13 +71,13 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tra auto n = transa ? lhs->shape[2] : lhs->shape[1]; auto m = transb ? rhs->shape[1] : rhs->shape[2]; - return make_extern({{b, n, m}}, {lhs->dtype}, {lhs, rhs}, - [&](Array ins, Array outs) { - return call_packed({StringImmNode::make("tvm.contrib.cublas.batch_matmul"), - pack_buffer(ins[0]), pack_buffer(ins[1]), - pack_buffer(outs[0]), transa, transb}); - }, - "C", "", {})[0]; + return make_extern( + {{b, n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index 3baf105..988c375 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -49,7 +49,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, return make_extern( {{n, m}}, {lhs->dtype}, {lhs, rhs}, [&](Array ins, Array outs) { - return call_packed({StringImmNode::make("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), + return call_packed({StringImm("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); }, "C", "", {})[0]; diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 104faa6..f53693b 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -113,14 +113,12 @@ inline Array make_extern(const Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = - tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); PrimExpr strides; if (buf->strides.size() > 0) { - strides = - tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + strides = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); } else { strides = 0; } @@ -130,8 +128,8 @@ inline PrimExpr pack_buffer(Buffer buf) { make_const(DataType::Int(32), static_cast(buf->shape.size())), make_const(buf->dtype, 0), buf->elem_offset}; - return tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, - pack_args, tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, pack_args, + tvm::tir::CallNode::CallType::Intrinsic); } /*! @@ -144,8 +142,8 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, - tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::Call(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, + tvm::tir::CallNode::CallType::Intrinsic); } } // namespace detail diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 70daac2..a92d21c 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -213,8 +213,8 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag PrimExpr zero = make_zero(x->dtype); PrimExpr one = make_const(x->dtype, 1); PrimExpr minus_one = make_const(x->dtype, -1); - auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero); - auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1); + auto s1 = tvm::tir::Select((x(i) < zero), minus_one, zero); + auto s2 = tvm::tir::Select((x(i) > zero), one, s1); return s2; }, name, tag); @@ -279,13 +279,13 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", std::string tag = kElementWise) { return compute( x->shape, - [&](const Array& i) { + [&](const Array& i) -> PrimExpr { auto expr = x(i); if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { if (expr.dtype().lanes() == type.lanes()) { return expr; } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { - return tvm::tir::BroadcastNode::make(expr, type.lanes()); + return tvm::tir::Broadcast(expr, type.lanes()); } } @@ -309,8 +309,7 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te return compute( x->shape, [&](const Array& i) { - return tvm::tir::CallNode::make(type, "reinterpret", {x(i)}, - tvm::tir::CallNode::PureIntrinsic); + return tvm::tir::Call(type, "reinterpret", {x(i)}, tvm::tir::CallNode::PureIntrinsic); }, name, tag); } diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 7fbe7eb..2a195b3 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -37,19 +37,6 @@ namespace topi { using namespace tvm; using namespace tvm::te; -namespace detail { - -template -tvm::PrimExpr Map(const tvm::Array& exprs, T op) { - CHECK_GE(exprs.size(), 1); - tvm::PrimExpr res = exprs[0]; - for (size_t i = 1; i < exprs.size(); ++i) { - res = op(res, exprs[i]); - } - return res; -} - -} // namespace detail /*! * \brief Creates an operation that performs a rectified linear unit @@ -91,7 +78,7 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, [&](const tvm::Array& i) { auto value = t(i); auto calpha = tvm::tir::make_const(value.dtype(), alpha); - return tvm::tir::SelectNode::make(value > 0, value, value * calpha); + return tvm::tir::Select(value > 0, value, value * calpha); }, name, tag); } @@ -118,7 +105,7 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl x->shape, [&](const tvm::Array& indices) { auto xval = x(indices); - return tvm::tir::SelectNode::make(xval > 0, xval, xval * slope(indices[axis])); + return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis])); }, name, tag); } @@ -226,10 +213,11 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& kernel_size, } else { PrimExpr h_start = output[height_axis] * stride_height - pad_top; PrimExpr w_start = output[width_axis] * stride_width - pad_left; - PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); - PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); - h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); - w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); - PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::DataType::Int(32), 1)); + + PrimExpr h_end = min(h_start + kernel_height, height); + PrimExpr w_end = min(w_start + kernel_width, width); + h_start = max(h_start, make_const(DataType::DataType::Int(32), 0)); + w_start = max(w_start, make_const(DataType::DataType::Int(32), 0)); + PrimExpr divide_factor = max((h_end - h_start) * (w_end - w_start), + make_const(DataType::DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } }, @@ -266,19 +267,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); - PrimExpr out_idx_lower_h = tir::SelectNode::make( + PrimExpr out_idx_lower_h = tir::Select( pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); - PrimExpr out_idx_lower_w = tir::SelectNode::make( + PrimExpr out_idx_lower_w = tir::Select( pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( - tvm::if_then_else( - tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[width_axis] >= out_idx_lower_w), - mp_inds(out_idx) == idx), - out_grad(out_idx), make_const(x->dtype, 0)), + tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h, + out_idx[width_axis] >= out_idx_lower_w), + mp_inds(out_idx) == idx), + out_grad(out_idx), make_const(x->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); @@ -298,11 +298,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); PrimExpr out_idx_lower_h = - tir::SelectNode::make(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), - (pad_h_idx - kernel_height) / stride_height + 1); + tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), + (pad_h_idx - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = - tir::SelectNode::make(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), - (pad_w_idx - kernel_width) / stride_width + 1); + tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), + (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements if (count_include_pad) { @@ -310,20 +310,20 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, } else { PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top; PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left; - PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); - PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); - h_start = tir::MaxNode::make(h_start, make_const(DataType::Int(32), 0)); - w_start = tir::MaxNode::make(w_start, make_const(DataType::Int(32), 0)); - divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::Int(32), 1)); + + PrimExpr h_end = min(h_start + kernel_height, height); + PrimExpr w_end = min(w_start + kernel_width, width); + h_start = max(h_start, make_const(DataType::Int(32), 0)); + w_start = max(w_start, make_const(DataType::Int(32), 0)); + divide_factor = + max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1)); } return tvm::sum( - tvm::if_then_else( - tir::AndNode::make(tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[height_axis] < out_height), - tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w, - out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), + tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h, + out_idx[height_axis] < out_height), + tir::And(out_idx[width_axis] >= out_idx_lower_w, + out_idx[width_axis] < out_width)), + out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -462,7 +462,7 @@ inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const Pr inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) { PrimExpr tmp = indexdiv((out_index + 1) * idim, odim); - return tvm::tir::SelectNode::make(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1); + return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1); } /*! @@ -743,13 +743,12 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, for (int i = 0; i < k_size; i++) { int ii = axis[i]; start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]); - start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); + end[i] = min(start[i] + kernel[i], x->shape[ii]); + start[i] = max(start[i], make_const(DataType::Int(32), 0)); kernel_size *= (end[i] - start[i]); } - PrimExpr divide_factor = - tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); + PrimExpr divide_factor = max(kernel_size, make_const(DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } }, diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index c45bb50..8555500 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -280,11 +280,10 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, auto id_elem = fidentity(dtypes); auto cond = condition != nullptr ? *condition : tir::const_true(); - auto combiner = tvm::tir::CommReducerNode::make(lhs, rhs, result, id_elem); + auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem); Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { - outputs.push_back( - tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); + outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast(i))); } return outputs; }; @@ -442,8 +441,8 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi bool atleast1d = false) { auto fcombine = [](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val + result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { @@ -459,8 +458,8 @@ inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdi inline FCommReduce MakeArgmaxReducer() { auto fcombine = [](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val + result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 4b7f7ca..9aa4e35 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -867,7 +867,7 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, out = compute( oshape, [&](const Array& indices) { - return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices)); + return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices)); }, name, tag); } else { @@ -878,7 +878,7 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, oshape, [&](const Array& indices) { Array condition_idx{indices[0]}; - return tvm::tir::SelectNode::make(condition(condition_idx) != 0, x(indices), y(indices)); + return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices)); }, name, tag); } @@ -1312,8 +1312,7 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim } auto idx = iter_vars[true_axis]; - return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, - off_value_cast); + return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast); }, name, tag); } -- 2.7.4