From: Tianqi Chen Date: Fri, 12 Jun 2020 15:24:18 +0000 (-0700) Subject: [REFACTOR][API-Change] Migrate all Object construction to constructor. (#5784) X-Git-Tag: upstream/0.7.0~577 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1a79dc9f43a99bae2efcae5e4fa302e8223f529a;p=platform%2Fupstream%2Ftvm.git [REFACTOR][API-Change] Migrate all Object construction to constructor. (#5784) This PR migrates all the remaining object constructions to the new constructor style that is consistent with the rest of the codebase and changes the affected files accordingly. Other changes: - ThreadScope::make -> ThreadScope::Create - StorageScope::make -> StorageScope::Create --- diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index a66328f..8674c8e 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``. :: inline Schedule create_schedule(Array ops) { - return ScheduleNode::make(ops); + return Schedule(ops); } ``Schedule`` consists of collections of ``Stage`` and output ``Operation``. diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst index 2fc4636..a82ae4f 100644 --- a/docs/dev/relay_add_pass.rst +++ b/docs/dev/relay_add_pass.rst @@ -138,7 +138,7 @@ is shown below. if (g->tuple == t) { return GetRef(g); } else { - return TupleGetItemNode::make(t, g->index); + return TupleGetItem(t, g->index); } } diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index 6c2b139..446a91b 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -344,13 +344,13 @@ registration. .. code:: c++ // Create a simple Relay program. - auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool()); - auto x = relay::VarNode::make("x", relay::Type()); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); + auto tensor_type = relay::TensorType({}, tvm::Bool()); + auto x = relay::Var("x", relay::Type()); + auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); - auto y = relay::VarNode::make("y", tensor_type); + auto y = relay::Var("y", tensor_type); auto call = relay::Call(f, tvm::Array{ y }); - auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); + auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); // Create a module for optimization. auto mod = IRModule::FromExpr(fx); diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 1ed6848..84d6a7b 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -97,14 +97,14 @@ class SpanNode : public Object { equal(col_offset, other->col_offset); } - TVM_DLL static Span make(SourceName source, int lineno, int col_offset); - static constexpr const char* _type_key = "Span"; TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; class Span : public ObjectRef { public: + TVM_DLL Span(SourceName source, int lineno, int col_offset); + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index c161cc9..4b7037a 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -177,13 +177,23 @@ class PlaceholderOpNode : public OperationNode { v->Visit("shape", &shape); v->Visit("dtype", &dtype); } - static Operation make(std::string name, Array shape, DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! + * \brief Managed reference to PlaceholderOpNode + * \sa PlaceholderOpNode + */ +class PlaceholderOp : public Operation { + public: + TVM_DLL PlaceholderOp(std::string name, Array shape, DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode); +}; + +/*! * \brief A Compute op that compute a tensor on certain domain. * This is the base class for ComputeOp (operating on a scalar at a time) and * TensorComputeOp (operating on a TensorSlice at a time) @@ -237,14 +247,24 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { v->Visit("reduce_axis", &reduce_axis); v->Visit("body", &body); } - static Operation make(std::string name, std::string tag, Map attrs, - Array axis, Array body); static constexpr const char* _type_key = "ComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; /*! + * \brief Managed reference to ComputeOpNode + * \sa ComputeOpNode + */ +class ComputeOp : public Operation { + public: + TVM_DLL ComputeOp(std::string name, std::string tag, Map attrs, + Array axis, Array body); + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode); +}; + +/*! * \brief A TenorCompute op that compute a tensor with an tensor intrinsic. */ class TensorComputeOpNode : public BaseComputeOpNode { @@ -285,16 +305,26 @@ class TensorComputeOpNode : public BaseComputeOpNode { v->Visit("input_regions", &input_regions); v->Visit("scalar_inputs", &scalar_inputs); } - static Operation make(std::string name, std::string tag, Array axis, - Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, - Array tensors, Array regions, - Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); }; /*! + * \brief Managed reference to TensorComputeOpNode + * \sa TensorComputeOpNode + */ +class TensorComputeOp : public Operation { + public: + TVM_DLL TensorComputeOp(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, + Array scalar_inputs); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode); +}; + +/*! * \brief Symbolic scan. */ class ScanOpNode : public OperationNode { @@ -353,15 +383,25 @@ class ScanOpNode : public OperationNode { v->Visit("inputs", &inputs); v->Visit("spatial_axis_", &spatial_axis_); } - static Operation make(std::string name, std::string tag, Map attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array input); static constexpr const char* _type_key = "ScanOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; /*! + * \brief Managed reference to ScanOpNode + * \sa ScanOpNode + */ +class ScanOp : public Operation { + public: + TVM_DLL ScanOp(std::string name, std::string tag, Map attrs, IterVar axis, + Array init, Array update, Array state_placeholder, + Array input); + + TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode); +}; + +/*! * \brief External computation that cannot be splitted. */ class ExternOpNode : public OperationNode { @@ -404,15 +444,25 @@ class ExternOpNode : public OperationNode { v->Visit("output_placeholders", &output_placeholders); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body); static constexpr const char* _type_key = "ExternOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; /*! + * \brief Managed reference to ExternOpNode + * \sa ExternOpNode + */ +class ExternOp : public Operation { + public: + TVM_DLL ExternOp(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode); +}; + +/*! * \brief A computation operator that generated by hybrid script. */ class HybridOpNode : public OperationNode { @@ -459,14 +509,24 @@ class HybridOpNode : public OperationNode { v->Visit("axis", &axis); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, std::string tag, Map attrs, - Array inputs, Array outputs, Stmt body); static constexpr const char* _type_key = "HybridOp"; TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); }; /*! + * \brief Managed reference to HybridOpNode + * \sa HybridOpNode + */ +class HybridOp : public Operation { + public: + TVM_DLL HybridOp(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode); +}; + +/*! * \brief Construct a new Var expression * \param name_hint The name hint for the expression * \param t The type of the expression diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index f74a008..ee4fb33 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -278,6 +278,12 @@ class Schedule : public ObjectRef { Schedule() {} explicit Schedule(ObjectPtr n) : ObjectRef(n) {} /*! + * \brief Create a schedule for array of ops(and their dependencies). + * \param ops The ops to be scheduled. + * \return sch The created Schedule. + */ + TVM_DLL explicit Schedule(Array ops); + /*! * \brief Get a copy of current schedule. * \return The copied schedule. */ @@ -553,13 +559,6 @@ class ScheduleNode : public Object { */ TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); } - /*! - * \brief Create a schedule for array of ops(and their dependencies). - * \param ops The ops to be scheduled. - * \return sch The created Schedule. - */ - TVM_DLL static Schedule make(Array ops); - static constexpr const char* _type_key = "Schedule"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object); }; @@ -569,7 +568,7 @@ class ScheduleNode : public Object { * \param ops The ops to be scheduled. * \return sch The created Schedule. */ -inline Schedule create_schedule(Array ops) { return ScheduleNode::make(ops); } +inline Schedule create_schedule(Array ops) { return Schedule(ops); } /*! \brief node container for IterVar attr */ class IterVarAttrNode : public Object { @@ -648,14 +647,22 @@ class SplitNode : public IterVarRelationNode { v->Visit("nparts", &nparts); } - static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, - PrimExpr nparts); - static constexpr const char* _type_key = "Split"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode); }; /*! + * \brief Managed reference to SplitNode + * \sa SplitNode + */ +class Split : public IterVarRelation { + public: + TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); + + TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode); +}; + +/*! * \brief Fuse two domains into one domain. */ class FuseNode : public IterVarRelationNode { @@ -673,13 +680,22 @@ class FuseNode : public IterVarRelationNode { v->Visit("fused", &fused); } - static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused); - static constexpr const char* _type_key = "Fuse"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); }; /*! + * \brief Managed reference to FuseNode + * \sa FuseNode + */ +class Fuse : public IterVarRelation { + public: + TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused); + + TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode); +}; + +/*! * \brief Rebase the iteration to make min to be 0. * This is useful to normalize the Schedule * to make every leaf variable's min to be 0. @@ -696,13 +712,22 @@ class RebaseNode : public IterVarRelationNode { v->Visit("rebased", &rebased); } - static IterVarRelation make(IterVar parent, IterVar rebased); - static constexpr const char* _type_key = "Rebase"; TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; /*! + * \brief Managed reference to RebaseNode + * \sa RebaseNode + */ +class Rebase : public IterVarRelation { + public: + TVM_DLL Rebase(IterVar parent, IterVar rebased); + + TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode); +}; + +/*! * \brief Singleton iterator [0, 1) */ class SingletonNode : public IterVarRelationNode { @@ -712,12 +737,21 @@ class SingletonNode : public IterVarRelationNode { void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } - static IterVarRelation make(IterVar iter); - static constexpr const char* _type_key = "Singleton"; TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to SingletonNode + * \sa SingletonNode + */ +class Singleton : public IterVarRelation { + public: + TVM_DLL explicit Singleton(IterVar iter); + + TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode); +}; + /*! \brief Container for specialization conditions. */ class SpecializedConditionNode : public Object { public: diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 045d186..0c4af4b 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -40,25 +40,68 @@ namespace te { using arith::IntSet; using namespace tvm::tir; -// Internal node container of Tensor -class TensorNode; // internal node container for Operation class OperationNode; +/*! \brief Operation that produces tensors */ +class Operation : public tir::FunctionRef { + public: + /*! \brief default constructor */ + Operation() {} + explicit Operation(ObjectPtr n) : FunctionRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const OperationNode* operator->() const; + /*! + * \brief get the i-th output of the operation. + * \param i the output index. + * \return The i-th output. + */ + TVM_DLL Tensor output(size_t i) const; + /*! \brief specify container node */ + using ContainerType = OperationNode; +}; + +/*! \brief Node to represent a tensor */ +class TensorNode : public DataProducerNode { + public: + /*! \brief The shape of the tensor */ + Array shape; + /*! \brief data type in the content of the tensor */ + DataType dtype; + /*! \brief the source operation, can be None */ + Operation op; + /*! \brief the output index from source operation */ + int value_index{0}; + /*! \brief constructor */ + TensorNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("op", &op); + v->Visit("value_index", &value_index); + } + + Array GetShape() const final { return shape; } + + DataType GetDataType() const final { return dtype; } + + TVM_DLL String GetNameHint() const final; + + static constexpr const char* _type_key = "Tensor"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); +}; + /*! * \brief Tensor structure representing a possible input, * or intermediate computation result. */ class Tensor : public DataProducer { public: - /*! \brief default constructor, used internally */ - Tensor() {} - explicit Tensor(ObjectPtr n) : DataProducer(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorNode* operator->() const; + TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. @@ -131,69 +174,11 @@ class Tensor : public DataProducer { * \return the subsequent slice. */ inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } - /*! \brief specify container node */ - using ContainerType = TensorNode; -}; -/*! \brief Operation that produces tensors */ -class Operation : public tir::FunctionRef { - public: - /*! \brief default constructor */ - Operation() {} - explicit Operation(ObjectPtr n) : FunctionRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const OperationNode* operator->() const; - /*! - * \brief get the i-th output of the operation. - * \param i the output index. - * \return The i-th output. - */ - TVM_DLL Tensor output(size_t i) const; - /*! \brief specify container node */ - using ContainerType = OperationNode; -}; - -/*! \brief Node to represent a tensor */ -class TensorNode : public DataProducerNode { - public: - /*! \brief The shape of the tensor */ - Array shape; - /*! \brief data type in the content of the tensor */ - DataType dtype; - /*! \brief the source operation, can be None */ - Operation op; - /*! \brief the output index from source operation */ - int value_index{0}; - /*! \brief constructor */ - TensorNode() {} - - void VisitAttrs(AttrVisitor* v) { - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); - v->Visit("op", &op); - v->Visit("value_index", &value_index); - } - - Array GetShape() const final { return shape; } - - DataType GetDataType() const final { return dtype; } - - TVM_DLL String GetNameHint() const final; - - TVM_DLL static Tensor make(Array shape, DataType dtype, Operation op, int value_index); - - static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); + TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode); }; // Implementations of inline functions -inline const TensorNode* Tensor::operator->() const { - return static_cast(get()); -} - inline size_t Tensor::ndim() const { return (*this)->shape.size(); } inline bool Tensor::operator==(const Tensor& other) const { diff --git a/include/tvm/te/tensor_intrin.h b/include/tvm/te/tensor_intrin.h index 7e76efe..22f29de 100644 --- a/include/tvm/te/tensor_intrin.h +++ b/include/tvm/te/tensor_intrin.h @@ -32,24 +32,6 @@ namespace tvm { namespace te { -// Internal node container of tensor intrinsics. -class TensorIntrinNode; - -/*! \brief Tensor intrinsic node. */ -class TensorIntrin : public ObjectRef { - public: - TensorIntrin() {} - explicit TensorIntrin(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorIntrinNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = TensorIntrinNode; -}; - /*! \brief Node to represent a Tensor intrinsic operator */ class TensorIntrinNode : public Object { public: @@ -100,17 +82,21 @@ class TensorIntrinNode : public Object { v->Visit("reduce_update", &reduce_update); } - TVM_DLL static TensorIntrin make(std::string name, Operation op, Array inputs, - Array buffers, Array scalar_params, Stmt body, - Stmt reduce_init, Stmt reduce_update); - static constexpr const char* _type_key = "TensorIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); }; -inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(get()); -} +/*! + * \brief Managed reference to TensorIntrinNode + * \sa TensorIntrinNode + */ +class TensorIntrin : public ObjectRef { + public: + TVM_DLL TensorIntrin(std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode); +}; class TensorIntrinCallNode : public Object { public: diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 6904f2a..5b07cc5 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -32,8 +32,6 @@ namespace tvm { namespace tir { -// Internal node container Buffer -class BufferNode; // forward declare Stmt class Stmt; @@ -45,62 +43,6 @@ enum BufferType : int { kAutoBroadcast = 2, }; -/*! - * \brief Buffer is a symbolic n-darray structure. - * It is a composition of primitive symbolic types, - * used to specify the memory layout of the Tensor used in program input. - */ -class Buffer : public ObjectRef { - public: - Buffer() {} - explicit Buffer(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief Return a new buffer that is equivalent with current one - * but always add stride field. - * \return The strided version of the buffer. - */ - TVM_DLL Buffer MakeStrideView() const; - /*! - * \brief Make a new symbolic buffer representing a slice of the buffer. - * \param begins The beginning position of each dimension. - * \param extents The extent of each dimension. - * \note This function will make target buffer as compact as possible. - * If stride is not needed in the slice, it won't be presented - * \return the result buffer. - */ - TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; - /*! - * \brief Get access ptr to the entire buffer. - * \param access_mask The access mask - * \param ptr_type The type of the pointer. - * \param content_lanes The number of lanes for the (data) type. - * \param offset The offset of ptr. - */ - TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), - int content_lanes = 1, - PrimExpr offset = IntImm(DataType::Int(32), 0)) const; - /*! - * \brief Create an Expr that does a vector load at begin index. - * \param begin The beginning index - * \param dtype The data type to be loaded. - */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; - /*! - * \brief Create a Stmt that does a vector store at begin index. - * \param begin The beginning index - * \param value The value to be stored. - */ - TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const BufferNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = BufferNode; -}; - /*! \brief Node to represent a buffer */ class BufferNode : public Object { public: @@ -176,22 +118,65 @@ class BufferNode : public Object { return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); } - // User can specify data_alignment and offset_factor to be 0 - // A default value will be picked. - TVM_DLL static Buffer make(Var ptr, DataType dtype, Array shape, - Array strides, PrimExpr elem_offset, std::string name, - std::string scope, int data_alignment, int offset_factor, - BufferType buffer_type); - static constexpr const char* _type_key = "Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); }; -inline const BufferNode* Buffer::operator->() const { - return static_cast(get()); -} +/*! + * \brief Buffer is a symbolic n-darray structure. + * It is a composition of primitive symbolic types, + * used to specify the memory layout of the Tensor used in program input. + */ +class Buffer : public ObjectRef { + public: + // User can specify data_alignment and offset_factor to be 0 + // A default value will be picked. + TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, + int offset_factor, BufferType buffer_type); + + /*! + * \brief Return a new buffer that is equivalent with current one + * but always add stride field. + * \return The strided version of the buffer. + */ + TVM_DLL Buffer MakeStrideView() const; + /*! + * \brief Make a new symbolic buffer representing a slice of the buffer. + * \param begins The beginning position of each dimension. + * \param extents The extent of each dimension. + * \note This function will make target buffer as compact as possible. + * If stride is not needed in the slice, it won't be presented + * \return the result buffer. + */ + TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; + /*! + * \brief Get access ptr to the entire buffer. + * \param access_mask The access mask + * \param ptr_type The type of the pointer. + * \param content_lanes The number of lanes for the (data) type. + * \param offset The offset of ptr. + */ + TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), + int content_lanes = 1, + PrimExpr offset = IntImm(DataType::Int(32), 0)) const; + /*! + * \brief Create an Expr that does a vector load at begin index. + * \param begin The beginning index + * \param dtype The data type to be loaded. + */ + TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; + /*! + * \brief Create a Stmt that does a vector store at begin index. + * \param begin The beginning index + * \param value The value to be stored. + */ + TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + + TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); +}; /*! * \brief Construct a new buffer given shape, and dtype. @@ -199,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const { * \param dtype The content data type. * \param name The name of the buffer * \return The created buffer. - * \sa BufferNode::make for complete constructor. + * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), std::string name = "buffer"); diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 0a20db6..f705247 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -37,6 +37,8 @@ namespace tvm { namespace tir { +class Layout; + class LayoutAxis { public: static const LayoutAxis& Get(const char name); @@ -45,7 +47,7 @@ class LayoutAxis { static const LayoutAxis& Get(const tir::IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). - static const LayoutAxis& make(const std::string& name); + static const LayoutAxis& Get(const std::string& name); inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } inline std::string name() const { return std::string(1, name_); } @@ -83,8 +85,16 @@ class LayoutAxis { const char name_; }; -class Layout; -// Internal node container Buffer +/*! + * \brief Layout is to describe how data is organized within an N-dimention tensor. + * It is composed of upper cases, lower cases and numbers, + * where upper case indicates a primal axis and + * the corresponding lower case with factor size indicates the subordinate axis. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). + * Layout for scalar is defined, while both its name and axes have size 0. + */ class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ @@ -102,29 +112,16 @@ class LayoutNode : public Object { v->Visit("axes", &axes); } - TVM_DLL static Layout make(const std::string& layout); - static constexpr const char* _type_key = "Layout"; TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; /*! - * \brief Layout is to describe how data is organized within an N-dimention tensor. - * It is composed of upper cases, lower cases and numbers, - * where upper case indicates a primal axis and - * the corresponding lower case with factor size indicates the subordinate axis. - * For example, NCHW16c can describe a 5-D tensor of - * [batch_size, channel, height, width, channel_block]. - * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). - * Layout for scalar is defined, while both its name and axes have size 0. + * \brief Managed reference to LayoutNode + * \sa LayoutNode */ class Layout : public ObjectRef { public: - explicit Layout(ObjectPtr n) : ObjectRef(n) {} - - /*! \brief default constructor */ - Layout() = default; - explicit Layout(const Array& axes); /*! \brief construct from a string */ @@ -138,13 +135,7 @@ class Layout : public ObjectRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - Layout(const std::string& name); // NOLINT(*) - - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - const LayoutNode* operator->() const { return static_cast(get()); } + TVM_DLL Layout(const std::string& name); // NOLINT(*) /*! * \brief access the internal node container @@ -292,10 +283,9 @@ class Layout : public ObjectRef { return os; } - using ContainerType = LayoutNode; + TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); }; -class BijectiveLayout; // Internal node container BijectiveLayout class BijectiveLayoutNode : public Object { public: @@ -329,8 +319,6 @@ class BijectiveLayoutNode : public Object { */ class BijectiveLayout : public ObjectRef { public: - BijectiveLayout() = default; - explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} /*! * \brief The constructor * \param src_layout The source layout @@ -347,19 +335,9 @@ class BijectiveLayout : public ObjectRef { // Given the destination indices, recover the source indices. TVM_DLL Array BackwardIndex(const Array& dst_index) const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const BijectiveLayoutNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = BijectiveLayoutNode; + TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode); }; -inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(get()); -} } // namespace tir } // namespace tvm diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c667b49..9d2a11c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -88,8 +88,8 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std elem_offset = PrimExpr(); } - return tir::BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor, buffer_type); + return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, + offset_factor, buffer_type); } void GetBinds(const Array& args, bool compact, diff --git a/src/ir/span.cc b/src/ir/span.cc index 742c985..565439f 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -61,17 +61,19 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode) return static_cast(n)->name; }); -Span SpanNode::make(SourceName source, int lineno, int col_offset) { +Span::Span(SourceName source, int lineno, int col_offset) { auto n = make_object(); n->source = std::move(source); n->lineno = lineno; n->col_offset = col_offset; - return Span(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed(SpanNode::make); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { + return Span(source, lineno, col_offset); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index be749fd..3687b75 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -217,8 +217,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> // Skip fcompute for device copy operators as it is not registered. if (op == device_copy_op_) { const auto* copy_input = inputs[0].operator->(); - outputs.push_back( - te::TensorNode::make(copy_input->shape, copy_input->dtype, te::Operation(), 0)); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index d9be91d..9a75c0a 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -181,22 +181,25 @@ class InterpreterStateObj : public Object { v->Visit("stack", &stack); } - static InterpreterState make(Expr current_expr, Stack stack); - static constexpr const char* _type_key = "relay.InterpreterState"; TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object); }; class InterpreterState : public ObjectRef { public: + using Frame = tvm::Map; + using Stack = tvm::Array; + + InterpreterState(Expr current_expr, Stack stack); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj); }; -InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { +InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack stack) { ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); - return InterpreterState(n); + data_ = std::move(n); } // NOTE: the current interpreter assumes A-normal form. @@ -701,7 +704,7 @@ class Interpreter : public ExprFunctor, InterpreterStateObj::Frame frame = fr.locals; stack.push_back(frame); } - auto state = InterpreterStateObj::make(e, stack); + auto state = InterpreterState(e, stack); return state; } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 92e12b5..1917096 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -112,11 +112,11 @@ struct StorageScope { } } /*! - * \brief make storage scope from string + * \brief Create storage scope from string * \param s The string to be parsed. * \return The storage scope. */ - static StorageScope make(const std::string& s) { + static StorageScope Create(const std::string& s) { StorageScope r; if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; @@ -153,11 +153,11 @@ struct ThreadScope { /*! \brief the dimension index under the rank */ int dim_index{0}; /*! - * \brief make storage scope from string + * \brief Create storage scope from string * \param s The string to be parsed. * \return The storage scope. */ - static ThreadScope make(const std::string& s) { + static ThreadScope Create(const std::string& s) { ThreadScope r; if (s == "vthread" || s == "cthread") { // virtual thread at the same level as local @@ -199,7 +199,7 @@ class ThreadAxisConfig { std::vector filled(6, false); for (size_t i = 0; i < thread_axis_tags.size(); ++i) { const std::string& tag = thread_axis_tags[i]; - ThreadScope ts = ThreadScope::make(tag); + ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); filled[ts.rank * 3 + ts.dim_index] = true; } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 280c999..8e6b3a2 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -125,7 +125,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Return the thread index via intrinsics. llvm::Value* GetThreadIndex(const IterVar& iv) final { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; if (ts.rank == 1) { switch (ts.dim_index) { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index b43e988..3af9fc3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1260,7 +1260,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = op->node.as(); CHECK(v); alloc_storage_info_[v].scope = - runtime::StorageScope::make(op->value.as()->value); + runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); CHECK(v); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 353f322..bc47ce1 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -101,7 +101,7 @@ class CodeGenNVPTX : public CodeGenLLVM { // Return the thread index via intrinsics. llvm::Value* GetThreadIndex(const IterVar& iv) final { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; if (ts.rank == 1) { switch (ts.dim_index) { diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index e381afb..2c26ee9 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -122,7 +122,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); work_dim = std::max(work_dim, scope.dim_index + 1); } if (work_dim != 0) { diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 746d418..8616853 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -75,7 +75,7 @@ std::string CodeGenOpenCL::Finish() { void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); std::ostringstream os; if (ts.rank == 1) { os << "get_local_id(" << ts.dim_index << ")"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 364a62f..699d395 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -92,7 +92,7 @@ void CodeGenSPIRV::InitFuncState() { } spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); spirv::Value v; if (ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); @@ -580,7 +580,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); - storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); + storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index a8a9a0b..1834aa3 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -340,8 +340,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { new_bodies.push_back(new_body); } - auto new_op = - ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); + auto new_op = ComputeOp(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); // Jacobian shape = output.shape + input.shape Array new_shape = output->shape; @@ -349,7 +348,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { new_shape.push_back(e); } - return TensorNode::make(new_shape, output->dtype, new_op, value_index); + return Tensor(new_shape, output->dtype, new_op, value_index); } } // namespace te diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 7f957b5..1fc0520 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -99,7 +99,7 @@ Tensor compute(Array shape, FCompute fcompute, std::string name, std:: args.push_back(axis.back()->var); } - return ComputeOpNode::make(name, tag, attrs, axis, {fcompute(args)}).output(0); + return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0); } Array compute(Array shape, FBatchCompute fcompute, std::string name, @@ -116,7 +116,7 @@ Array compute(Array shape, FBatchCompute fcompute, std::string args.push_back(axis.back()->var); } - Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args)); + Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args)); Array outputs; for (int idx = 0; idx < op->num_outputs(); ++idx) { outputs.push_back(op.output(idx)); @@ -124,8 +124,8 @@ Array compute(Array shape, FBatchCompute fcompute, std::string return outputs; } -Operation ComputeOpNode::make(std::string name, std::string tag, Map attrs, - Array axis, Array body) { +ComputeOp::ComputeOp(std::string name, std::string tag, Map attrs, + Array axis, Array body) { if (!attrs.defined()) { attrs = Map(); } @@ -140,10 +140,13 @@ Operation ComputeOpNode::make(std::string name, std::string tag, Mapreduce_axis = reduce->axis; } VerifyComputeOp(n.get()); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ComputeOp").set_body_typed(ComputeOpNode::make); +TVM_REGISTER_GLOBAL("te.ComputeOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array axis, + Array body) { return ComputeOp(name, tag, attrs, axis, body); }); // The schedule related logics Array ComputeOpNode::InputTensors() const { @@ -188,7 +191,7 @@ Operation ComputeOpNode::ReplaceInputs(const Operation& self, UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); }); } if (!arr.same_as(this->body)) { - return ComputeOpNode::make(this->name, this->tag, this->attrs, this->axis, arr); + return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr); } else { return self; } @@ -331,7 +334,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { // grab the nest structure - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); // Normal loop structure n.init_nest.emplace_back(MakeIfNest(n.init_predicates)); n.main_nest.emplace_back(MakeIfNest(n.main_predicates)); @@ -424,9 +427,9 @@ Stmt ComputeOpNode::BuildProvide(const Stage& stage, } } -ComputeLoopNest ComputeLoopNest::make(const BaseComputeOpNode* self, const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { CHECK_EQ(stage->op.operator->(), self); ComputeLoopNest ret; // make main loop nest diff --git a/src/te/operation/compute_op.h b/src/te/operation/compute_op.h index 610c014..2661eb9 100644 --- a/src/te/operation/compute_op.h +++ b/src/te/operation/compute_op.h @@ -59,9 +59,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The constructed loop nest */ - static ComputeLoopNest make(const BaseComputeOpNode* self, const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); + static ComputeLoopNest Create(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); }; /*! diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 0933e30..ef55c44 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -50,9 +50,9 @@ DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } -Operation ExternOpNode::make(std::string name, std::string tag, Map attrs, - Array inputs, Array input_placeholders, - Array output_placeholders, Stmt body) { +ExternOp::ExternOp(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -73,10 +73,15 @@ Operation ExternOpNode::make(std::string name, std::string tag, Mapinput_placeholders = std::move(input_placeholders); n->output_placeholders = std::move(output_placeholders); n->body = std::move(body); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ExternOp").set_body_typed(ExternOpNode::make); +TVM_REGISTER_GLOBAL("te.ExternOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { + return ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body); + }); Array ExternOpNode::InputTensors() const { return inputs; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 9b3a79f..9be474d 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -57,8 +57,8 @@ DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } -Operation HybridOpNode::make(std::string name, std::string tag, Map attrs, - Array inputs, Array outputs, Stmt body) { +HybridOp::HybridOp(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { attrs = Map(); } @@ -70,11 +70,13 @@ Operation HybridOpNode::make(std::string name, std::string tag, Mapoutputs = std::move(outputs); n->axis = te::GatherLoopVars(body); n->body = std::move(body); - Operation res = Operation(n); - return res; + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.HybridOp").set_body_typed(HybridOpNode::make); +TVM_REGISTER_GLOBAL("te.HybridOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, + Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); }); Array HybridOpNode::InputTensors() const { // Because input tensors could be potentially inlined into hybrid scripts, diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index f1b0527..61b7826 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -156,9 +156,9 @@ std::vector > MakeLoopNest(const Stage& stage, if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else { - runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag); if (stage->scope == "" || - static_cast(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) { + static_cast(runtime::StorageScope::Create(stage->scope).rank) <= ts.rank) { value_map[iv] = var; } else if (stage->scope == "warp" && ts.rank == 1) { // To determine whether a thread index is inside or outside a warp, we need diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 9c536eb..5b7ede3 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -50,16 +50,16 @@ Array PlaceholderOpNode::output_shape(size_t i) const { return shape; } -Operation PlaceholderOpNode::make(std::string name, Array shape, DataType dtype) { +PlaceholderOp::PlaceholderOp(std::string name, Array shape, DataType dtype) { auto n = make_object(); n->name = name; n->shape = shape; n->dtype = dtype; - return Operation(n); + data_ = std::move(n); } Tensor placeholder(Array shape, DataType dtype, std::string name) { - return PlaceholderOpNode::make(name, shape, dtype).output(0); + return PlaceholderOp(name, shape, dtype).output(0); } TVM_REGISTER_GLOBAL("te.Placeholder") diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 45e86e2..cc86d0f 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -55,9 +55,9 @@ Array ScanOpNode::output_shape(size_t i) const { return state_placeholder[i]->shape; } -Operation ScanOpNode::make(std::string name, std::string tag, Map attrs, - IterVar axis, Array init, Array update, - Array state_placeholder, Array inputs) { +ScanOp::ScanOp(std::string name, std::string tag, Map attrs, IterVar axis, + Array init, Array update, Array state_placeholder, + Array inputs) { if (!attrs.defined()) { attrs = Map(); } @@ -104,10 +104,15 @@ Operation ScanOpNode::make(std::string name, std::string tag, Mapupdate = std::move(update); n->state_placeholder = std::move(state_placeholder); n->inputs = std::move(inputs); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make); +TVM_REGISTER_GLOBAL("te.ScanOp") + .set_body_typed([](std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array inputs) { + return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); + }); Array scan(Array init, Array update, Array state_placeholder, Array inputs, std::string name, std::string tag, @@ -115,8 +120,7 @@ Array scan(Array init, Array update, Array state IterVar scan_axis = IterVar(Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), Var(name + ".idx"), kOrdered); - Operation op = - ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); + Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index c8dfce8..8d5265b 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -52,10 +52,10 @@ DataType TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } -Operation TensorComputeOpNode::make(std::string name, std::string tag, Array axis, - Array reduce_axis, int schedulable_ndim, - TensorIntrin intrin, Array tensors, - Array regions, Array scalar_inputs) { +TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, + TensorIntrin intrin, Array tensors, Array regions, + Array scalar_inputs) { auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); @@ -66,10 +66,17 @@ Operation TensorComputeOpNode::make(std::string name, std::string tag, Arrayinputs = std::move(tensors); n->input_regions = std::move(regions); n->scalar_inputs = std::move(scalar_inputs); - return Operation(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.TensorComputeOp").set_body_typed(TensorComputeOpNode::make); +TVM_REGISTER_GLOBAL("te.TensorComputeOp") + .set_body_typed([](std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, + Array scalar_inputs) { + return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors, + regions, scalar_inputs); + }); Array TensorComputeOpNode::InputTensors() const { return inputs; } @@ -191,7 +198,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, binder.BindArray(sp_expr, user_expr, this->name); size_t tloc = stage->leaf_iter_vars.size(); - ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop); if (this->reduce_axis.size() == 0) { std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index af4b08e..82832c9 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -347,7 +347,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin; CHECK(intrin.defined()); - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin); // Start bind data. diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 01d4f93..099f488 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -59,7 +59,7 @@ bool NeedRelax(const IterVar& iv, bool found_attach, if (tag.length() == 0 || tag == "pipeline") { return !found_attach; } - ThreadScope ts = ThreadScope::make(tag); + ThreadScope ts = ThreadScope::Create(tag); // When there is warp memory // threadIdx.x must be set to be warp index. @@ -72,14 +72,14 @@ bool NeedRelax(const IterVar& iv, bool found_attach, // infer storage scope, if not given StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) { if (stage->scope.length() != 0) { - return StorageScope::make(stage->scope); + return StorageScope::Create(stage->scope); } int max_rank = -1; for (IterVar iv : ctx.attach_path.at(stage->op)) { auto it = ctx.bind_map.find(iv); const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag != "pipeline" && tag.length() != 0) { - max_rank = std::max(max_rank, ThreadScope::make(tag).rank); + max_rank = std::max(max_rank, ThreadScope::Create(tag).rank); } } StorageScope s; diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c360513..af72d3b 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -324,16 +324,16 @@ Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_a args.push_back(value_map.at(iv)); } } - Operation cache_op = ComputeOpNode::make(compute->name + "." + scope, compute->tag, - compute->attrs, new_axis, body_list); + Operation cache_op = + ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list); Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, - compute->axis, cache_expr_list); + Operation orig_new_op = + ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } @@ -380,10 +380,10 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& te new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } - Operation cache_op = TensorComputeOpNode::make(tensor_op->name + "." + scope, tensor_op->tag, - new_axis, tensor_op->reduce_axis, - tensor_op->schedulable_ndim, tensor_op->intrin, - tensor_op->inputs, new_regions, new_scalar_inputs); + Operation cache_op = + TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis, + tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin, + tensor_op->inputs, new_regions, new_scalar_inputs); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; @@ -419,7 +419,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& te cache_expr_list.push_back(cache_tensor(args)); } Operation orig_new_op = - ComputeOpNode::make(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); + ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } @@ -468,7 +468,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { if (idx < leaf_vars->size()) { // insert rebase IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type); - s->relations.push_back(RebaseNode::make(iv, rebased)); + s->relations.push_back(te::Rebase(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); } @@ -583,8 +583,7 @@ void InjectInline(ScheduleNode* sch) { CHECK(compute); Operation op = s->op; if (changed[i]) { - op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, compute->axis, - new_body[i]); + op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]); } op = op->ReplaceInputs(op, repl); if (!op.same_as(s->op)) { @@ -596,8 +595,8 @@ void InjectInline(ScheduleNode* sch) { } else if (hybrid_changed[i]) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); CHECK(hybrid); - Operation op = HybridOpNode::make(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, + hybrid->outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index 24d9102..707d52f 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -55,8 +55,8 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) return 0; } -void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer, - IterVar* p_inner) { +void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, + IterVar* p_outer, IterVar* p_inner) { // Check if split is valid. CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) @@ -69,7 +69,7 @@ void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, It Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); - self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts)); + self->relations.push_back(Split(parent, outer, inner, factor, nparts)); // add vars to all vars all_vars.push_back(outer); all_vars.push_back(inner); @@ -206,13 +206,13 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); + SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); + SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } @@ -242,7 +242,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT } CHECK_EQ(pos_inner, pos_outer + 1) << "Can only fuse iterations that are consecutive between each other"; - self->relations.push_back(FuseNode::make(outer, inner, fused)); + self->relations.push_back(Fuse(outer, inner, fused)); all_vars.push_back(fused); leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1); leaf_vars.insert(leaf_vars.begin() + pos_outer, fused); @@ -263,7 +263,7 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* // insert at the outer most loop IterVar singleton = IterVar(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar); - self->relations.push_back(SingletonNode::make(singleton)); + self->relations.push_back(Singleton(singleton)); Array& all_vars = self->all_iter_vars; Array& leaf_vars = self->leaf_iter_vars; all_vars.push_back(singleton); @@ -624,9 +624,9 @@ bool ScheduleNode::Contain(const Operation& op) const { return stage_map.find(op) != stage_map.end(); } -Schedule ScheduleNode::make(Array ops) { +Schedule::Schedule(Array ops) { auto n = make_object(); - Schedule sch(n); + data_ = n; n->outputs = ops; auto g = te::CreateReadGraph(n->outputs); Array post_order = te::PostDFSOrder(n->outputs, g); @@ -650,7 +650,7 @@ Schedule ScheduleNode::make(Array ops) { inputs.push_back(t); } // Create the scan group. - Stage scan_group = sch.create_group(scan->update, inputs, false); + Stage scan_group = this->create_group(scan->update, inputs, false); scan_group->attach_type = kScanUpdate; scan_group->attach_stage = stage; @@ -660,39 +660,37 @@ Schedule ScheduleNode::make(Array ops) { } } } - return sch; } -IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, - PrimExpr nparts) { +Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; n->factor = factor; n->nparts = nparts; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation FuseNode::make(IterVar outer, IterVar inner, IterVar fused) { +Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) { auto n = make_object(); n->outer = outer; n->inner = inner; n->fused = fused; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { +Rebase::Rebase(IterVar parent, IterVar rebased) { auto n = make_object(); n->parent = parent; n->rebased = rebased; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation SingletonNode::make(IterVar iter) { +Singleton::Singleton(IterVar iter) { auto n = make_object(); n->iter = iter; - return IterVarRelation(n); + data_ = std::move(n); } SpecializedCondition::SpecializedCondition(Array conditions) { diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 7e7f648..e66b963 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -66,28 +66,32 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor TensorNode::make(Array shape, DataType dtype, Operation op, int value_index) { +Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; n->value_index = value_index; - return Tensor(n); + data_ = std::move(n); } +TVM_REGISTER_GLOBAL("te.Tensor") + .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { + return Tensor(shape, dtype, op, value_index); + }); + +TVM_REGISTER_NODE_TYPE(TensorNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* t = static_cast(node.get()); p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; }); -TVM_REGISTER_NODE_TYPE(TensorNode); - // TensorIntrin - -TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array inputs, - Array buffers, Array scalar_params, Stmt body, - Stmt reduce_init, Stmt reduce_update) { +TensorIntrin::TensorIntrin(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update) { auto n = make_object(); n->name = std::move(name); n->op = std::move(op); @@ -97,17 +101,24 @@ TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Arraybody = std::move(body); n->reduce_init = std::move(reduce_init); n->reduce_update = std::move(reduce_update); - return TensorIntrin(n); + data_ = std::move(n); } +TVM_REGISTER_GLOBAL("te.TensorIntrin") + .set_body_typed([](std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { + return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init, + reduce_update); + }); + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; }); -TVM_REGISTER_NODE_TYPE(TensorIntrinNode); - // TensorIntrinCall TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array tensors, Array regions, Array reduce_axis, @@ -135,10 +146,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); -TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make); - -TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make); - +// Other tensor ops. TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 46f4160..2b64f11 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -45,8 +45,8 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { } Buffer decl_buffer(Array shape, DataType dtype, std::string name) { - return BufferNode::make(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), - PrimExpr(), name, "", 0, 0, kDefault); + return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), + PrimExpr(), name, "", 0, 0, kDefault); } // Split the given expression w.r.t the add operator @@ -348,8 +348,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", - n->scope, n->data_alignment, 0, n->buffer_type); + return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, + n->data_alignment, 0, n->buffer_type); } PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, @@ -379,9 +379,9 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); } -Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, std::string name, std::string scope, - int data_alignment, int offset_factor, BufferType buffer_type) { +Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, + int offset_factor, BufferType buffer_type) { auto n = make_object(); n->data = std::move(data); n->dtype = dtype; @@ -410,7 +410,7 @@ Buffer BufferNode::make(Var data, DataType dtype, Array shape, Array

strides.push_back(Var("stride", n->shape[i].dtype())); } } - return Buffer(n); + data_ = std::move(n); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -425,8 +425,8 @@ TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size(), 10); auto buffer_type = args[9].operator std::string(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], - args[8], type); + *ret = + Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], type); }); TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 1f17c35..bc777db 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -66,7 +66,7 @@ const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { return LayoutAxis::Get(axis[0]); } -const LayoutAxis& LayoutAxis::make(const std::string& name) { +const LayoutAxis& LayoutAxis::Get(const std::string& name) { CHECK_EQ(name.length(), 1) << "Invalid axis " << name; return LayoutAxis::Get(name[0]); } @@ -144,8 +144,6 @@ Layout::Layout(const std::string& name) { // NOLINT(*) data_ = std::move(node); } -Layout LayoutNode::make(const std::string& layout) { return Layout(layout); } - Layout Layout::SubLayout(size_t pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); if (len == 0) return Layout(Array()); @@ -365,15 +363,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make); +TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); }); TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::make(axis)); + return layout.IndexOf(LayoutAxis::Get(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") .set_body_typed([](Layout layout, std::string axis) -> int { - return layout.FactorOf(LayoutAxis::make(axis)); + return layout.FactorOf(LayoutAxis::Get(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index c201b8f..416358c 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -147,12 +147,12 @@ class CopyIntrinInjector : public StmtMutator { src_strides.push_back(make_const(DataType::Int(32), 1)); dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = BufferNode::make(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, - store_strides[loop_var_size], store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); - Buffer src = BufferNode::make(load->buffer_var, load->dtype, src_shape, src_strides, - src_elem_offset, load->buffer_var->name_hint, - GetStorageScope(load->buffer_var.get()), 0, 0, kDefault); + Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, + store_strides[loop_var_size], store->buffer_var->name_hint, + GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, + load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0, + kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 7dbf0fc..3b2580c 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -113,7 +113,7 @@ class CandidateSelector final : public StmtExprVisitor { const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); if ((scope.rank == 0) && (!is_const(op->value) || partition_const_loop_)) { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); @@ -361,7 +361,7 @@ class LoopPartitioner : public StmtMutator { } // normal path when loop parittion fails. - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); Stmt res; if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 0b87757..9d6b47a 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -63,7 +63,7 @@ class StorageAccessInfoLower : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + StorageScope scope = StorageScope::Create(op->value.as()->value); StorageEntry e; e.scope = scope; if (scope.tag.length() != 0) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 058014c..ee17f08 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -165,7 +165,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (const AttrStmtNode* attr : thread_extents_) { ThreadEntry e; IterVar iv = Downcast(attr->node); - e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.scope = runtime::ThreadScope::Create(iv->thread_tag); e.iv = iv; CHECK_LE(e.scope.rank, 1); CHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction"; @@ -516,7 +516,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { IterVar iv = Downcast(op->node); ThreadEntry e; - e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.scope = runtime::ThreadScope::Create(iv->thread_tag); e.extent = 0; if (auto ptr = op->value.as()) { e.extent = static_cast(ptr->value); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index a0ddf26..92f9ab5 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -367,7 +367,7 @@ class WarpMemoryRewriter : private StmtMutator { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 3a42137..20cc640 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -92,7 +92,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::Create(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); @@ -215,11 +215,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { - StorageScope scope = StorageScope::make(s); + StorageScope scope = StorageScope::Create(s); AccessEntry e; e.threads = env_threads(); e.type = kSync; - e.scope = StorageScope::make(s); + e.scope = StorageScope::Create(s); curr_stmt_.access.emplace_back(std::move(e)); } } else { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 4c3de58..e29d978 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -91,7 +91,7 @@ class StorageFlattener : public StmtExprMutator { return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); - ThreadScope ts = ThreadScope::make(iv->thread_tag); + ThreadScope ts = ThreadScope::Create(iv->thread_tag); curr_thread_scope_.push_back(ts); Stmt stmt = StmtExprMutator::VisitStmt_(op); curr_thread_scope_.pop_back(); @@ -165,7 +165,7 @@ class StorageFlattener : public StmtExprMutator { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } } else { - skey = StorageScope::make(strkey); + skey = StorageScope::Create(strkey); } // use small alignment for small arrays @@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = BufferNode::make(Var(op->buffer->data->name_hint, DataType::Handle()), - op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, - skey.to_string(), align, 0, kDefault); + e.buffer = + Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape, + strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 2d09e8b..283ab0f 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -178,7 +178,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); + alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index e5b4bdd..b8575d2 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -251,7 +251,7 @@ class ThreadSyncInserter : public StmtExprMutator { return ret; } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::Create(op->value.as()->value); return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -321,7 +321,7 @@ class ThreadSyncInserter : public StmtExprMutator { num_work_dim_ = thread_extents_.size(); for (const AttrStmtNode* attr : thread_extents_) { IterVar iv = Downcast(attr->node); - runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); if (s.rank == 0) { num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); } else if (s.rank == 1) { @@ -353,7 +353,7 @@ class ThreadSyncInserter : public StmtExprMutator { }; Stmt ThreadSync(Stmt stmt, std::string storage_scope) { - StorageScope sync_scope = StorageScope::make(storage_scope); + StorageScope sync_scope = StorageScope::Create(storage_scope); ThreadSyncPlanner planner(sync_scope); planner(stmt); return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 8823134..70709b0 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -51,11 +51,11 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* TEST(MicroStandaloneRuntime, BuildModule) { using namespace tvm; auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32)); - auto a = relay::VarNode::make("a", tensor_type); - auto b = relay::VarNode::make("b", tensor_type); + auto a = relay::Var("a", tensor_type); + auto b = relay::Var("b", tensor_type); auto add_op = relay::Op::Get("add"); auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {}); - auto c = relay::VarNode::make("c", tensor_type); + auto c = relay::Var("c", tensor_type); auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {}); auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index 25b3800..b84fbc7 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -46,8 +46,7 @@ using namespace tvm::te; inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, - kDefault); + return Buffer(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); } /*! @@ -93,8 +92,7 @@ inline Array make_extern(const Array >& out_shapes, auto body = fextern(input_placeholders, output_placeholders); auto body_stmt = tvm::tir::Evaluate(body); - auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders, - body_stmt); + auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt); Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 9aa4e35..e830e09 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1194,8 +1194,8 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, const std::string& dst_layout, const std::string name = "T_layout_trans", const std::string tag = kInjective) { - Layout src_layout_struct = LayoutNode::make(src_layout); - Layout dst_layout_struct = LayoutNode::make(dst_layout); + Layout src_layout_struct(src_layout); + Layout dst_layout_struct(dst_layout); if (src_layout_struct.Equals(dst_layout_struct)) { return src;