[REFACTOR][API-Change] Migrate all Object construction to constructor. (#5784)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 12 Jun 2020 15:24:18 +0000 (08:24 -0700)
committerGitHub <noreply@github.com>
Fri, 12 Jun 2020 15:24:18 +0000 (08:24 -0700)
This PR migrates all the remaining object constructions to the new constructor style
that is consistent with the rest of the codebase and changes the affected files accordingly.

Other changes:

- ThreadScope::make -> ThreadScope::Create
- StorageScope::make -> StorageScope::Create

49 files changed:
docs/dev/codebase_walkthrough.rst
docs/dev/relay_add_pass.rst
docs/dev/relay_pass_infra.rst
include/tvm/ir/span.h
include/tvm/te/operation.h
include/tvm/te/schedule.h
include/tvm/te/tensor.h
include/tvm/te/tensor_intrin.h
include/tvm/tir/buffer.h
include/tvm/tir/data_layout.h
src/driver/driver_api.cc
src/ir/span.cc
src/relay/backend/compile_engine.cc
src/relay/backend/interpreter.cc
src/runtime/thread_storage_scope.h
src/target/llvm/codegen_amdgpu.cc
src/target/llvm/codegen_llvm.cc
src/target/llvm/codegen_nvptx.cc
src/target/source/codegen_metal.cc
src/target/source/codegen_opencl.cc
src/target/spirv/codegen_spirv.cc
src/te/autodiff/jacobian.cc
src/te/operation/compute_op.cc
src/te/operation/compute_op.h
src/te/operation/extern_op.cc
src/te/operation/hybrid_op.cc
src/te/operation/op_util.cc
src/te/operation/placeholder_op.cc
src/te/operation/scan_op.cc
src/te/operation/tensor_compute_op.cc
src/te/operation/tensorize.cc
src/te/schedule/bound.cc
src/te/schedule/schedule_dataflow_rewrite.cc
src/te/schedule/schedule_lang.cc
src/te/tensor.cc
src/tir/ir/buffer.cc
src/tir/ir/data_layout.cc
src/tir/transforms/inject_copy_intrin.cc
src/tir/transforms/loop_partition.cc
src/tir/transforms/lower_device_storage_access_info.cc
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/lower_warp_memory.cc
src/tir/transforms/storage_access.cc
src/tir/transforms/storage_flatten.cc
src/tir/transforms/storage_rewrite.cc
src/tir/transforms/thread_storage_sync.cc
tests/cpp/utvm_runtime_standalone_test.cc
topi/include/topi/detail/extern.h
topi/include/topi/transform.h

index a66328f..8674c8e 100644 (file)
@@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
 ::
 
    inline Schedule create_schedule(Array<Operation> ops) {
-     return ScheduleNode::make(ops);
+     return Schedule(ops);
    }
 
 ``Schedule`` consists of collections of ``Stage`` and output ``Operation``.
index 2fc4636..a82ae4f 100644 (file)
@@ -138,7 +138,7 @@ is shown below.
       if (g->tuple == t) {
         return GetRef<Expr>(g);
       } else {
-        return TupleGetItemNode::make(t, g->index);
+        return TupleGetItem(t, g->index);
       }
     }
 
index 6c2b139..446a91b 100644 (file)
@@ -344,13 +344,13 @@ registration.
 .. code:: c++
 
     // Create a simple Relay program.
-    auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool());
-    auto x = relay::VarNode::make("x", relay::Type());
-    auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
+    auto tensor_type = relay::TensorType({}, tvm::Bool());
+    auto x = relay::Var("x", relay::Type());
+    auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
 
-    auto y = relay::VarNode::make("y", tensor_type);
+    auto y = relay::Var("y", tensor_type);
     auto call = relay::Call(f, tvm::Array<relay::Expr>{ y });
-    auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
+    auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
 
     // Create a module for optimization.
     auto mod = IRModule::FromExpr(fx);
index 1ed6848..84d6a7b 100644 (file)
@@ -97,14 +97,14 @@ class SpanNode : public Object {
            equal(col_offset, other->col_offset);
   }
 
-  TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
-
   static constexpr const char* _type_key = "Span";
   TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
 };
 
 class Span : public ObjectRef {
  public:
+  TVM_DLL Span(SourceName source, int lineno, int col_offset);
+
   TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
 };
 
index c161cc9..4b7037a 100644 (file)
@@ -177,13 +177,23 @@ class PlaceholderOpNode : public OperationNode {
     v->Visit("shape", &shape);
     v->Visit("dtype", &dtype);
   }
-  static Operation make(std::string name, Array<PrimExpr> shape, DataType dtype);
 
   static constexpr const char* _type_key = "PlaceholderOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
 };
 
 /*!
+ * \brief Managed reference to PlaceholderOpNode
+ * \sa PlaceholderOpNode
+ */
+class PlaceholderOp : public Operation {
+ public:
+  TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
+};
+
+/*!
  * \brief A Compute op that compute a tensor on certain domain.
  * This is the base class for ComputeOp (operating on a scalar at a time) and
  * TensorComputeOp (operating on a TensorSlice at a time)
@@ -237,14 +247,24 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
     v->Visit("reduce_axis", &reduce_axis);
     v->Visit("body", &body);
   }
-  static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                        Array<IterVar> axis, Array<PrimExpr> body);
 
   static constexpr const char* _type_key = "ComputeOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
 };
 
 /*!
+ * \brief Managed reference to ComputeOpNode
+ * \sa ComputeOpNode
+ */
+class ComputeOp : public Operation {
+ public:
+  TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                    Array<IterVar> axis, Array<PrimExpr> body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
+};
+
+/*!
  * \brief A TenorCompute op that compute a tensor with an tensor intrinsic.
  */
 class TensorComputeOpNode : public BaseComputeOpNode {
@@ -285,16 +305,26 @@ class TensorComputeOpNode : public BaseComputeOpNode {
     v->Visit("input_regions", &input_regions);
     v->Visit("scalar_inputs", &scalar_inputs);
   }
-  static Operation make(std::string name, std::string tag, Array<IterVar> axis,
-                        Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
-                        Array<Tensor> tensors, Array<Region> regions,
-                        Array<PrimExpr> scalar_inputs);
 
   static constexpr const char* _type_key = "TensorComputeOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
 };
 
 /*!
+ * \brief Managed reference to TensorComputeOpNode
+ * \sa TensorComputeOpNode
+ */
+class TensorComputeOp : public Operation {
+ public:
+  TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
+                          Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
+                          Array<Tensor> tensors, Array<Region> regions,
+                          Array<PrimExpr> scalar_inputs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
+};
+
+/*!
  * \brief Symbolic scan.
  */
 class ScanOpNode : public OperationNode {
@@ -353,15 +383,25 @@ class ScanOpNode : public OperationNode {
     v->Visit("inputs", &inputs);
     v->Visit("spatial_axis_", &spatial_axis_);
   }
-  static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                        IterVar axis, Array<Tensor> init, Array<Tensor> update,
-                        Array<Tensor> state_placeholder, Array<Tensor> input);
 
   static constexpr const char* _type_key = "ScanOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
 };
 
 /*!
+ * \brief Managed reference to ScanOpNode
+ * \sa ScanOpNode
+ */
+class ScanOp : public Operation {
+ public:
+  TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
+                 Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
+                 Array<Tensor> input);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
+};
+
+/*!
  * \brief External computation that cannot be splitted.
  */
 class ExternOpNode : public OperationNode {
@@ -404,15 +444,25 @@ class ExternOpNode : public OperationNode {
     v->Visit("output_placeholders", &output_placeholders);
     v->Visit("body", &body);
   }
-  TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                                Array<Tensor> inputs, Array<Buffer> input_placeholders,
-                                Array<Buffer> output_placeholders, Stmt body);
 
   static constexpr const char* _type_key = "ExternOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
 };
 
 /*!
+ * \brief Managed reference to ExternOpNode
+ * \sa ExternOpNode
+ */
+class ExternOp : public Operation {
+ public:
+  TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                   Array<Tensor> inputs, Array<Buffer> input_placeholders,
+                   Array<Buffer> output_placeholders, Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
+};
+
+/*!
  * \brief A computation operator that generated by hybrid script.
  */
 class HybridOpNode : public OperationNode {
@@ -459,14 +509,24 @@ class HybridOpNode : public OperationNode {
     v->Visit("axis", &axis);
     v->Visit("body", &body);
   }
-  TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                                Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);
 
   static constexpr const char* _type_key = "HybridOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
 };
 
 /*!
+ * \brief Managed reference to HybridOpNode
+ * \sa HybridOpNode
+ */
+class HybridOp : public Operation {
+ public:
+  TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
+};
+
+/*!
  * \brief Construct a new Var expression
  * \param name_hint The name hint for the expression
  * \param t The type of the expression
index f74a008..ee4fb33 100644 (file)
@@ -278,6 +278,12 @@ class Schedule : public ObjectRef {
   Schedule() {}
   explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
   /*!
+   * \brief Create a schedule for array of ops(and their dependencies).
+   * \param ops The ops to be scheduled.
+   * \return sch The created Schedule.
+   */
+  TVM_DLL explicit Schedule(Array<Operation> ops);
+  /*!
    * \brief Get a copy of current schedule.
    * \return The copied schedule.
    */
@@ -553,13 +559,6 @@ class ScheduleNode : public Object {
    */
   TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }
 
-  /*!
-   * \brief Create a schedule for array of ops(and their dependencies).
-   * \param ops The ops to be scheduled.
-   * \return sch The created Schedule.
-   */
-  TVM_DLL static Schedule make(Array<Operation> ops);
-
   static constexpr const char* _type_key = "Schedule";
   TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
 };
@@ -569,7 +568,7 @@ class ScheduleNode : public Object {
  * \param ops The ops to be scheduled.
  * \return sch The created Schedule.
  */
-inline Schedule create_schedule(Array<Operation> ops) { return ScheduleNode::make(ops); }
+inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
 
 /*! \brief node container for IterVar attr */
 class IterVarAttrNode : public Object {
@@ -648,14 +647,22 @@ class SplitNode : public IterVarRelationNode {
     v->Visit("nparts", &nparts);
   }
 
-  static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
-                              PrimExpr nparts);
-
   static constexpr const char* _type_key = "Split";
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
 };
 
 /*!
+ * \brief Managed reference to SplitNode
+ * \sa SplitNode
+ */
+class Split : public IterVarRelation {
+ public:
+  TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
+};
+
+/*!
  * \brief Fuse two domains into one domain.
  */
 class FuseNode : public IterVarRelationNode {
@@ -673,13 +680,22 @@ class FuseNode : public IterVarRelationNode {
     v->Visit("fused", &fused);
   }
 
-  static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused);
-
   static constexpr const char* _type_key = "Fuse";
   TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
 };
 
 /*!
+ * \brief Managed reference to FuseNode
+ * \sa FuseNode
+ */
+class Fuse : public IterVarRelation {
+ public:
+  TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
+};
+
+/*!
  * \brief Rebase the iteration to make min to be 0.
  *  This is useful to normalize the Schedule
  *  to make every leaf variable's min to be 0.
@@ -696,13 +712,22 @@ class RebaseNode : public IterVarRelationNode {
     v->Visit("rebased", &rebased);
   }
 
-  static IterVarRelation make(IterVar parent, IterVar rebased);
-
   static constexpr const char* _type_key = "Rebase";
   TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
 };
 
 /*!
+ * \brief Managed reference to RebaseNode
+ * \sa RebaseNode
+ */
+class Rebase : public IterVarRelation {
+ public:
+  TVM_DLL Rebase(IterVar parent, IterVar rebased);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
+};
+
+/*!
  * \brief Singleton iterator [0, 1)
  */
 class SingletonNode : public IterVarRelationNode {
@@ -712,12 +737,21 @@ class SingletonNode : public IterVarRelationNode {
 
   void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }
 
-  static IterVarRelation make(IterVar iter);
-
   static constexpr const char* _type_key = "Singleton";
   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
 };
 
+/*!
+ * \brief Managed reference to SingletonNode
+ * \sa SingletonNode
+ */
+class Singleton : public IterVarRelation {
+ public:
+  TVM_DLL explicit Singleton(IterVar iter);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
+};
+
 /*! \brief Container for specialization conditions. */
 class SpecializedConditionNode : public Object {
  public:
index 045d186..0c4af4b 100644 (file)
@@ -40,25 +40,68 @@ namespace te {
 using arith::IntSet;
 using namespace tvm::tir;
 
-// Internal node container of Tensor
-class TensorNode;
 // internal node container for Operation
 class OperationNode;
 
+/*! \brief Operation that produces tensors */
+class Operation : public tir::FunctionRef {
+ public:
+  /*! \brief default constructor  */
+  Operation() {}
+  explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OperationNode* operator->() const;
+  /*!
+   * \brief get the i-th output of the operation.
+   * \param i the output index.
+   * \return The i-th output.
+   */
+  TVM_DLL Tensor output(size_t i) const;
+  /*! \brief specify container node */
+  using ContainerType = OperationNode;
+};
+
+/*! \brief Node to represent a tensor */
+class TensorNode : public DataProducerNode {
+ public:
+  /*! \brief The shape of the tensor */
+  Array<PrimExpr> shape;
+  /*! \brief data type in the content of the tensor */
+  DataType dtype;
+  /*! \brief the source operation, can be None */
+  Operation op;
+  /*! \brief the output index from source operation */
+  int value_index{0};
+  /*! \brief constructor */
+  TensorNode() {}
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("shape", &shape);
+    v->Visit("dtype", &dtype);
+    v->Visit("op", &op);
+    v->Visit("value_index", &value_index);
+  }
+
+  Array<PrimExpr> GetShape() const final { return shape; }
+
+  DataType GetDataType() const final { return dtype; }
+
+  TVM_DLL String GetNameHint() const final;
+
+  static constexpr const char* _type_key = "Tensor";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
+};
+
 /*!
  * \brief Tensor structure representing a possible input,
  *  or intermediate computation result.
  */
 class Tensor : public DataProducer {
  public:
-  /*! \brief default constructor, used internally */
-  Tensor() {}
-  explicit Tensor(ObjectPtr<Object> n) : DataProducer(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const TensorNode* operator->() const;
+  TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
   /*!
    * \brief check if two tensors equals each other.
    * \param other tensor to be checked.
@@ -131,69 +174,11 @@ class Tensor : public DataProducer {
    * \return the subsequent slice.
    */
   inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); }
-  /*! \brief specify container node */
-  using ContainerType = TensorNode;
-};
 
-/*! \brief Operation that produces tensors */
-class Operation : public tir::FunctionRef {
- public:
-  /*! \brief default constructor  */
-  Operation() {}
-  explicit Operation(ObjectPtr<Object> n) : FunctionRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const OperationNode* operator->() const;
-  /*!
-   * \brief get the i-th output of the operation.
-   * \param i the output index.
-   * \return The i-th output.
-   */
-  TVM_DLL Tensor output(size_t i) const;
-  /*! \brief specify container node */
-  using ContainerType = OperationNode;
-};
-
-/*! \brief Node to represent a tensor */
-class TensorNode : public DataProducerNode {
- public:
-  /*! \brief The shape of the tensor */
-  Array<PrimExpr> shape;
-  /*! \brief data type in the content of the tensor */
-  DataType dtype;
-  /*! \brief the source operation, can be None */
-  Operation op;
-  /*! \brief the output index from source operation */
-  int value_index{0};
-  /*! \brief constructor */
-  TensorNode() {}
-
-  void VisitAttrs(AttrVisitor* v) {
-    v->Visit("shape", &shape);
-    v->Visit("dtype", &dtype);
-    v->Visit("op", &op);
-    v->Visit("value_index", &value_index);
-  }
-
-  Array<PrimExpr> GetShape() const final { return shape; }
-
-  DataType GetDataType() const final { return dtype; }
-
-  TVM_DLL String GetNameHint() const final;
-
-  TVM_DLL static Tensor make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
-
-  static constexpr const char* _type_key = "Tensor";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
+  TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode);
 };
 
 // Implementations of inline functions
-inline const TensorNode* Tensor::operator->() const {
-  return static_cast<const TensorNode*>(get());
-}
-
 inline size_t Tensor::ndim() const { return (*this)->shape.size(); }
 
 inline bool Tensor::operator==(const Tensor& other) const {
index 7e76efe..22f29de 100644 (file)
 namespace tvm {
 namespace te {
 
-// Internal node container of tensor intrinsics.
-class TensorIntrinNode;
-
-/*! \brief Tensor intrinsic node. */
-class TensorIntrin : public ObjectRef {
- public:
-  TensorIntrin() {}
-  explicit TensorIntrin(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const TensorIntrinNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = TensorIntrinNode;
-};
-
 /*! \brief Node to represent a Tensor intrinsic operator */
 class TensorIntrinNode : public Object {
  public:
@@ -100,17 +82,21 @@ class TensorIntrinNode : public Object {
     v->Visit("reduce_update", &reduce_update);
   }
 
-  TVM_DLL static TensorIntrin make(std::string name, Operation op, Array<Tensor> inputs,
-                                   Array<Buffer> buffers, Array<Var> scalar_params, Stmt body,
-                                   Stmt reduce_init, Stmt reduce_update);
-
   static constexpr const char* _type_key = "TensorIntrin";
   TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
 };
 
-inline const TensorIntrinNode* TensorIntrin::operator->() const {
-  return static_cast<const TensorIntrinNode*>(get());
-}
+/*!
+ * \brief Managed reference to TensorIntrinNode
+ * \sa TensorIntrinNode
+ */
+class TensorIntrin : public ObjectRef {
+ public:
+  TVM_DLL TensorIntrin(std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers,
+                       Array<Var> scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode);
+};
 
 class TensorIntrinCallNode : public Object {
  public:
index 6904f2a..5b07cc5 100644 (file)
@@ -32,8 +32,6 @@
 
 namespace tvm {
 namespace tir {
-// Internal node container Buffer
-class BufferNode;
 
 // forward declare Stmt
 class Stmt;
@@ -45,62 +43,6 @@ enum BufferType : int {
   kAutoBroadcast = 2,
 };
 
-/*!
- * \brief Buffer is a symbolic n-darray structure.
- *  It is a composition of primitive symbolic types,
- *  used to specify the memory layout of the Tensor used in program input.
- */
-class Buffer : public ObjectRef {
- public:
-  Buffer() {}
-  explicit Buffer(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief Return a new buffer that is equivalent with current one
-   *  but always add stride field.
-   * \return The strided version of the buffer.
-   */
-  TVM_DLL Buffer MakeStrideView() const;
-  /*!
-   * \brief Make a new symbolic buffer representing a slice of the buffer.
-   * \param begins The beginning position of each dimension.
-   * \param extents The extent of each dimension.
-   * \note This function will make target buffer as compact as possible.
-   *  If stride is not needed in the slice, it won't be presented
-   * \return the result buffer.
-   */
-  TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
-  /*!
-   * \brief Get access ptr to the entire buffer.
-   * \param access_mask The access mask
-   * \param ptr_type The type of the pointer.
-   * \param content_lanes The number of lanes for the (data) type.
-   * \param offset The offset of ptr.
-   */
-  TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
-                              int content_lanes = 1,
-                              PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
-  /*!
-   * \brief Create an Expr that does a vector load at begin index.
-   * \param begin The beginning index
-   * \param dtype The data type to be loaded.
-   */
-  TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
-  /*!
-   * \brief Create a Stmt that does a vector store at begin index.
-   * \param begin The beginning index
-   * \param value The value to be stored.
-   */
-  TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const BufferNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = BufferNode;
-};
-
 /*! \brief Node to represent a buffer */
 class BufferNode : public Object {
  public:
@@ -176,22 +118,65 @@ class BufferNode : public Object {
     return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
   }
 
-  // User can specify data_alignment and offset_factor to be 0
-  // A default value will be picked.
-  TVM_DLL static Buffer make(Var ptr, DataType dtype, Array<PrimExpr> shape,
-                             Array<PrimExpr> strides, PrimExpr elem_offset, std::string name,
-                             std::string scope, int data_alignment, int offset_factor,
-                             BufferType buffer_type);
-
   static constexpr const char* _type_key = "Buffer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
 };
 
-inline const BufferNode* Buffer::operator->() const {
-  return static_cast<const BufferNode*>(get());
-}
+/*!
+ * \brief Buffer is a symbolic n-darray structure.
+ *  It is a composition of primitive symbolic types,
+ *  used to specify the memory layout of the Tensor used in program input.
+ */
+class Buffer : public ObjectRef {
+ public:
+  // User can specify data_alignment and offset_factor to be 0
+  // A default value will be picked.
+  TVM_DLL Buffer(Var ptr, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
+                 PrimExpr elem_offset, std::string name, std::string scope, int data_alignment,
+                 int offset_factor, BufferType buffer_type);
+
+  /*!
+   * \brief Return a new buffer that is equivalent with current one
+   *  but always add stride field.
+   * \return The strided version of the buffer.
+   */
+  TVM_DLL Buffer MakeStrideView() const;
+  /*!
+   * \brief Make a new symbolic buffer representing a slice of the buffer.
+   * \param begins The beginning position of each dimension.
+   * \param extents The extent of each dimension.
+   * \note This function will make target buffer as compact as possible.
+   *  If stride is not needed in the slice, it won't be presented
+   * \return the result buffer.
+   */
+  TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
+  /*!
+   * \brief Get access ptr to the entire buffer.
+   * \param access_mask The access mask
+   * \param ptr_type The type of the pointer.
+   * \param content_lanes The number of lanes for the (data) type.
+   * \param offset The offset of ptr.
+   */
+  TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(),
+                              int content_lanes = 1,
+                              PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
+  /*!
+   * \brief Create an Expr that does a vector load at begin index.
+   * \param begin The beginning index
+   * \param dtype The data type to be loaded.
+   */
+  TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
+  /*!
+   * \brief Create a Stmt that does a vector store at begin index.
+   * \param begin The beginning index
+   * \param value The value to be stored.
+   */
+  TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
+};
 
 /*!
  * \brief Construct a new buffer given shape, and dtype.
@@ -199,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const {
  * \param dtype The content data type.
  * \param name The name of the buffer
  * \return The created buffer.
- * \sa BufferNode::make for complete constructor.
+ * \sa Buffer for complete constructor.
  */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
                            std::string name = "buffer");
index 0a20db6..f705247 100644 (file)
@@ -37,6 +37,8 @@
 namespace tvm {
 namespace tir {
 
+class Layout;
+
 class LayoutAxis {
  public:
   static const LayoutAxis& Get(const char name);
@@ -45,7 +47,7 @@ class LayoutAxis {
   static const LayoutAxis& Get(const tir::IterVar& itvar);
 
   // Get the singleton LayoutAxis using name[0] (size of name must be 1).
-  static const LayoutAxis& make(const std::string& name);
+  static const LayoutAxis& Get(const std::string& name);
 
   inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
   inline std::string name() const { return std::string(1, name_); }
@@ -83,8 +85,16 @@ class LayoutAxis {
   const char name_;
 };
 
-class Layout;
-// Internal node container Buffer
+/*!
+ * \brief Layout is to describe how data is organized within an N-dimention tensor.
+ *  It is composed of upper cases, lower cases and numbers,
+ *  where upper case indicates a primal axis and
+ *  the corresponding lower case with factor size indicates the subordinate axis.
+ *  For example, NCHW16c can describe a 5-D tensor of
+ *  [batch_size, channel, height, width, channel_block].
+ *  Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
+ *  Layout for scalar is defined, while both its name and axes have size 0.
+ */
 class LayoutNode : public Object {
  public:
   /*! \brief string representation of layout, "" for scalar. */
@@ -102,29 +112,16 @@ class LayoutNode : public Object {
     v->Visit("axes", &axes);
   }
 
-  TVM_DLL static Layout make(const std::string& layout);
-
   static constexpr const char* _type_key = "Layout";
   TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
 };
 
 /*!
- * \brief Layout is to describe how data is organized within an N-dimention tensor.
- *  It is composed of upper cases, lower cases and numbers,
- *  where upper case indicates a primal axis and
- *  the corresponding lower case with factor size indicates the subordinate axis.
- *  For example, NCHW16c can describe a 5-D tensor of
- *  [batch_size, channel, height, width, channel_block].
- *  Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
- *  Layout for scalar is defined, while both its name and axes have size 0.
+ * \brief Managed reference to LayoutNode
+ * \sa LayoutNode
  */
 class Layout : public ObjectRef {
  public:
-  explicit Layout(ObjectPtr<Object> n) : ObjectRef(n) {}
-
-  /*! \brief default constructor */
-  Layout() = default;
-
   explicit Layout(const Array<tir::IterVar>& axes);
 
   /*! \brief construct from a string */
@@ -138,13 +135,7 @@ class Layout : public ObjectRef {
    *        indicates the split dimension.
    *        return undefined layout if "__undef__" is passed.
    */
-  Layout(const std::string& name);  // NOLINT(*)
-
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  const LayoutNode* operator->() const { return static_cast<const LayoutNode*>(get()); }
+  TVM_DLL Layout(const std::string& name);  // NOLINT(*)
 
   /*!
    * \brief access the internal node container
@@ -292,10 +283,9 @@ class Layout : public ObjectRef {
     return os;
   }
 
-  using ContainerType = LayoutNode;
+  TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
 };
 
-class BijectiveLayout;
 // Internal node container BijectiveLayout
 class BijectiveLayoutNode : public Object {
  public:
@@ -329,8 +319,6 @@ class BijectiveLayoutNode : public Object {
  */
 class BijectiveLayout : public ObjectRef {
  public:
-  BijectiveLayout() = default;
-  explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
   /*!
    * \brief The constructor
    * \param src_layout The source layout
@@ -347,19 +335,9 @@ class BijectiveLayout : public ObjectRef {
   // Given the destination indices, recover the source indices.
   TVM_DLL Array<PrimExpr> BackwardIndex(const Array<PrimExpr>& dst_index) const;
 
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  inline const BijectiveLayoutNode* operator->() const;
-
-  /*! \brief specify container node */
-  using ContainerType = BijectiveLayoutNode;
+  TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode);
 };
 
-inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
-  return static_cast<const BijectiveLayoutNode*>(get());
-}
 }  // namespace tir
 }  // namespace tvm
 
index c667b49..9d2a11c 100644 (file)
@@ -88,8 +88,8 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std
     elem_offset = PrimExpr();
   }
 
-  return tir::BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "",
-                               data_alignment, offset_factor, buffer_type);
+  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", data_alignment,
+                     offset_factor, buffer_type);
 }
 
 void GetBinds(const Array<te::Tensor>& args, bool compact,
index 742c985..565439f 100644 (file)
@@ -61,17 +61,19 @@ TVM_REGISTER_NODE_TYPE(SourceNameNode)
       return static_cast<const SourceNameNode*>(n)->name;
     });
 
-Span SpanNode::make(SourceName source, int lineno, int col_offset) {
+Span::Span(SourceName source, int lineno, int col_offset) {
   auto n = make_object<SpanNode>();
   n->source = std::move(source);
   n->lineno = lineno;
   n->col_offset = col_offset;
-  return Span(n);
+  data_ = std::move(n);
 }
 
 TVM_REGISTER_NODE_TYPE(SpanNode);
 
-TVM_REGISTER_GLOBAL("ir.Span").set_body_typed(SpanNode::make);
+TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) {
+  return Span(source, lineno, col_offset);
+});
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
index be749fd..3687b75 100644 (file)
@@ -217,8 +217,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     // Skip fcompute for device copy operators as it is not registered.
     if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
-      outputs.push_back(
-          te::TensorNode::make(copy_input->shape, copy_input->dtype, te::Operation(), 0));
+      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
     } else {
       LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
       outputs = lowered_out->outputs;
index d9be91d..9a75c0a 100644 (file)
@@ -181,22 +181,25 @@ class InterpreterStateObj : public Object {
     v->Visit("stack", &stack);
   }
 
-  static InterpreterState make(Expr current_expr, Stack stack);
-
   static constexpr const char* _type_key = "relay.InterpreterState";
   TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object);
 };
 
 class InterpreterState : public ObjectRef {
  public:
+  using Frame = tvm::Map<Var, ObjectRef>;
+  using Stack = tvm::Array<Frame>;
+
+  InterpreterState(Expr current_expr, Stack stack);
+
   TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj);
 };
 
-InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) {
+InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack stack) {
   ObjectPtr<InterpreterStateObj> n = make_object<InterpreterStateObj>();
   n->current_expr = std::move(current_expr);
   n->stack = std::move(stack);
-  return InterpreterState(n);
+  data_ = std::move(n);
 }
 
 // NOTE: the current interpreter assumes A-normal form.
@@ -701,7 +704,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
       InterpreterStateObj::Frame frame = fr.locals;
       stack.push_back(frame);
     }
-    auto state = InterpreterStateObj::make(e, stack);
+    auto state = InterpreterState(e, stack);
     return state;
   }
 
index 92e12b5..1917096 100644 (file)
@@ -112,11 +112,11 @@ struct StorageScope {
     }
   }
   /*!
-   * \brief make storage scope from string
+   * \brief Create storage scope from string
    * \param s The string to be parsed.
    * \return The storage scope.
    */
-  static StorageScope make(const std::string& s) {
+  static StorageScope Create(const std::string& s) {
     StorageScope r;
     if (s.compare(0, 6, "global") == 0) {
       r.rank = StorageRank::kGlobal;
@@ -153,11 +153,11 @@ struct ThreadScope {
   /*! \brief the dimension index under the rank */
   int dim_index{0};
   /*!
-   * \brief make storage scope from string
+   * \brief Create storage scope from string
    * \param s The string to be parsed.
    * \return The storage scope.
    */
-  static ThreadScope make(const std::string& s) {
+  static ThreadScope Create(const std::string& s) {
     ThreadScope r;
     if (s == "vthread" || s == "cthread") {
       // virtual thread at the same level as local
@@ -199,7 +199,7 @@ class ThreadAxisConfig {
     std::vector<bool> filled(6, false);
     for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
       const std::string& tag = thread_axis_tags[i];
-      ThreadScope ts = ThreadScope::make(tag);
+      ThreadScope ts = ThreadScope::Create(tag);
       arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
       filled[ts.rank * 3 + ts.dim_index] = true;
     }
index 280c999..8e6b3a2 100644 (file)
@@ -125,7 +125,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
 
   // Return the thread index via intrinsics.
   llvm::Value* GetThreadIndex(const IterVar& iv) final {
-    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
+    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
     llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x;
     if (ts.rank == 1) {
       switch (ts.dim_index) {
index b43e988..3af9fc3 100644 (file)
@@ -1260,7 +1260,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
     const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
     alloc_storage_info_[v].scope =
-        runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
+        runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == tir::attr::storage_alignment) {
     const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
index 353f322..bc47ce1 100644 (file)
@@ -101,7 +101,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
 
   // Return the thread index via intrinsics.
   llvm::Value* GetThreadIndex(const IterVar& iv) final {
-    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
+    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
     llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
     if (ts.rank == 1) {
       switch (ts.dim_index) {
index e381afb..2c26ee9 100644 (file)
@@ -122,7 +122,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis).value();
 
   for (IterVar iv : thread_axis) {
-    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+    runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
     work_dim = std::max(work_dim, scope.dim_index + 1);
   }
   if (work_dim != 0) {
index 746d418..8616853 100644 (file)
@@ -75,7 +75,7 @@ std::string CodeGenOpenCL::Finish() {
 
 void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) {
   CHECK(!var_idmap_.count(iv->var.get()));
-  runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
+  runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
   std::ostringstream os;
   if (ts.rank == 1) {
     os << "get_local_id(" << ts.dim_index << ")";
index 364a62f..699d395 100644 (file)
@@ -92,7 +92,7 @@ void CodeGenSPIRV::InitFuncState() {
 }
 
 spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
-  runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
+  runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
   spirv::Value v;
   if (ts.rank == 1) {
     v = builder_->GetLocalID(ts.dim_index);
@@ -580,7 +580,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
   } else if (op->attr_key == tir::attr::storage_scope) {
     const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
-    storage_info_[v].scope = runtime::StorageScope::make(op->value.as<StringImmNode>()->value);
+    storage_info_[v].scope = runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == tir::attr::volatile_scope) {
     const VarNode* v = op->node.as<VarNode>();
     CHECK(v);
index a8a9a0b..1834aa3 100644 (file)
@@ -340,8 +340,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
     new_bodies.push_back(new_body);
   }
 
-  auto new_op =
-      ComputeOpNode::make(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies);
+  auto new_op = ComputeOp(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies);
 
   // Jacobian shape = output.shape + input.shape
   Array<PrimExpr> new_shape = output->shape;
@@ -349,7 +348,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
     new_shape.push_back(e);
   }
 
-  return TensorNode::make(new_shape, output->dtype, new_op, value_index);
+  return Tensor(new_shape, output->dtype, new_op, value_index);
 }
 
 }  // namespace te
index 7f957b5..1fc0520 100644 (file)
@@ -99,7 +99,7 @@ Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::
     args.push_back(axis.back()->var);
   }
 
-  return ComputeOpNode::make(name, tag, attrs, axis, {fcompute(args)}).output(0);
+  return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0);
 }
 
 Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string name,
@@ -116,7 +116,7 @@ Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string
     args.push_back(axis.back()->var);
   }
 
-  Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
+  Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args));
   Array<Tensor> outputs;
   for (int idx = 0; idx < op->num_outputs(); ++idx) {
     outputs.push_back(op.output(idx));
@@ -124,8 +124,8 @@ Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string
   return outputs;
 }
 
-Operation ComputeOpNode::make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                              Array<IterVar> axis, Array<PrimExpr> body) {
+ComputeOp::ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                     Array<IterVar> axis, Array<PrimExpr> body) {
   if (!attrs.defined()) {
     attrs = Map<String, ObjectRef>();
   }
@@ -140,10 +140,13 @@ Operation ComputeOpNode::make(std::string name, std::string tag, Map<String, Obj
     n->reduce_axis = reduce->axis;
   }
   VerifyComputeOp(n.get());
-  return Operation(n);
+  data_ = std::move(n);
 }
 
-TVM_REGISTER_GLOBAL("te.ComputeOp").set_body_typed(ComputeOpNode::make);
+TVM_REGISTER_GLOBAL("te.ComputeOp")
+    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                       Array<IterVar> axis,
+                       Array<PrimExpr> body) { return ComputeOp(name, tag, attrs, axis, body); });
 
 // The schedule related logics
 Array<Tensor> ComputeOpNode::InputTensors() const {
@@ -188,7 +191,7 @@ Operation ComputeOpNode::ReplaceInputs(const Operation& self,
         UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); });
   }
   if (!arr.same_as(this->body)) {
-    return ComputeOpNode::make(this->name, this->tag, this->attrs, this->axis, arr);
+    return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr);
   } else {
     return self;
   }
@@ -331,7 +334,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
                      const std::unordered_map<IterVar, Range>& dom_map,
                      bool debug_keep_trivial_loop) {
   // grab the nest structure
-  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
+  ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop);
   // Normal loop structure
   n.init_nest.emplace_back(MakeIfNest(n.init_predicates));
   n.main_nest.emplace_back(MakeIfNest(n.main_predicates));
@@ -424,9 +427,9 @@ Stmt ComputeOpNode::BuildProvide(const Stage& stage,
   }
 }
 
-ComputeLoopNest ComputeLoopNest::make(const BaseComputeOpNode* self, const Stage& stage,
-                                      const std::unordered_map<IterVar, Range>& dom_map,
-                                      bool debug_keep_trivial_loop) {
+ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage,
+                                        const std::unordered_map<IterVar, Range>& dom_map,
+                                        bool debug_keep_trivial_loop) {
   CHECK_EQ(stage->op.operator->(), self);
   ComputeLoopNest ret;
   // make main loop nest
index 610c014..2661eb9 100644 (file)
@@ -59,9 +59,9 @@ struct ComputeLoopNest {
    * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
    * \return The constructed loop nest
    */
-  static ComputeLoopNest make(const BaseComputeOpNode* self, const Stage& stage,
-                              const std::unordered_map<IterVar, Range>& dom_map,
-                              bool debug_keep_trivial_loop);
+  static ComputeLoopNest Create(const BaseComputeOpNode* self, const Stage& stage,
+                                const std::unordered_map<IterVar, Range>& dom_map,
+                                bool debug_keep_trivial_loop);
 };
 
 /*!
index 0933e30..ef55c44 100644 (file)
@@ -50,9 +50,9 @@ DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders
 
 Array<PrimExpr> ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; }
 
-Operation ExternOpNode::make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                             Array<Tensor> inputs, Array<Buffer> input_placeholders,
-                             Array<Buffer> output_placeholders, Stmt body) {
+ExternOp::ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                   Array<Tensor> inputs, Array<Buffer> input_placeholders,
+                   Array<Buffer> output_placeholders, Stmt body) {
   if (!attrs.defined()) {
     attrs = Map<String, ObjectRef>();
   }
@@ -73,10 +73,15 @@ Operation ExternOpNode::make(std::string name, std::string tag, Map<String, Obje
   n->input_placeholders = std::move(input_placeholders);
   n->output_placeholders = std::move(output_placeholders);
   n->body = std::move(body);
-  return Operation(n);
+  data_ = std::move(n);
 }
 
-TVM_REGISTER_GLOBAL("te.ExternOp").set_body_typed(ExternOpNode::make);
+TVM_REGISTER_GLOBAL("te.ExternOp")
+    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                       Array<Tensor> inputs, Array<Buffer> input_placeholders,
+                       Array<Buffer> output_placeholders, Stmt body) {
+      return ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body);
+    });
 
 Array<Tensor> ExternOpNode::InputTensors() const { return inputs; }
 
index 9b3a79f..9be474d 100644 (file)
@@ -57,8 +57,8 @@ DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype;
 
 Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }
 
-Operation HybridOpNode::make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                             Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
+HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                   Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
   if (!attrs.defined()) {
     attrs = Map<String, ObjectRef>();
   }
@@ -70,11 +70,13 @@ Operation HybridOpNode::make(std::string name, std::string tag, Map<String, Obje
   n->outputs = std::move(outputs);
   n->axis = te::GatherLoopVars(body);
   n->body = std::move(body);
-  Operation res = Operation(n);
-  return res;
+  data_ = std::move(n);
 }
 
-TVM_REGISTER_GLOBAL("te.HybridOp").set_body_typed(HybridOpNode::make);
+TVM_REGISTER_GLOBAL("te.HybridOp")
+    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                       Array<Tensor> inputs, Array<Tensor> outputs,
+                       Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); });
 
 Array<Tensor> HybridOpNode::InputTensors() const {
   // Because input tensors could be potentially inlined into hybrid scripts,
index f1b0527..61b7826 100644 (file)
@@ -156,9 +156,9 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else {
-        runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
+        runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag);
         if (stage->scope == "" ||
-            static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
+            static_cast<int>(runtime::StorageScope::Create(stage->scope).rank) <= ts.rank) {
           value_map[iv] = var;
         } else if (stage->scope == "warp" && ts.rank == 1) {
           // To determine whether a thread index is inside or outside a warp, we need
index 9c536eb..5b7ede3 100644 (file)
@@ -50,16 +50,16 @@ Array<PrimExpr> PlaceholderOpNode::output_shape(size_t i) const {
   return shape;
 }
 
-Operation PlaceholderOpNode::make(std::string name, Array<PrimExpr> shape, DataType dtype) {
+PlaceholderOp::PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype) {
   auto n = make_object<PlaceholderOpNode>();
   n->name = name;
   n->shape = shape;
   n->dtype = dtype;
-  return Operation(n);
+  data_ = std::move(n);
 }
 
 Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
-  return PlaceholderOpNode::make(name, shape, dtype).output(0);
+  return PlaceholderOp(name, shape, dtype).output(0);
 }
 
 TVM_REGISTER_GLOBAL("te.Placeholder")
index 45e86e2..cc86d0f 100644 (file)
@@ -55,9 +55,9 @@ Array<PrimExpr> ScanOpNode::output_shape(size_t i) const {
   return state_placeholder[i]->shape;
 }
 
-Operation ScanOpNode::make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
-                           IterVar axis, Array<Tensor> init, Array<Tensor> update,
-                           Array<Tensor> state_placeholder, Array<Tensor> inputs) {
+ScanOp::ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
+               Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
+               Array<Tensor> inputs) {
   if (!attrs.defined()) {
     attrs = Map<String, ObjectRef>();
   }
@@ -104,10 +104,15 @@ Operation ScanOpNode::make(std::string name, std::string tag, Map<String, Object
   n->update = std::move(update);
   n->state_placeholder = std::move(state_placeholder);
   n->inputs = std::move(inputs);
-  return Operation(n);
+  data_ = std::move(n);
 }
 
-TVM_REGISTER_GLOBAL("te.ScanOp").set_body_typed(ScanOpNode::make);
+TVM_REGISTER_GLOBAL("te.ScanOp")
+    .set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
+                       IterVar axis, Array<Tensor> init, Array<Tensor> update,
+                       Array<Tensor> state_placeholder, Array<Tensor> inputs) {
+      return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs);
+    });
 
 Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
                    Array<Tensor> inputs, std::string name, std::string tag,
@@ -115,8 +120,7 @@ Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update, Array<Tensor> state
   IterVar scan_axis =
       IterVar(Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
               Var(name + ".idx"), kOrdered);
-  Operation op =
-      ScanOpNode::make(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs);
+  Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs);
   Array<Tensor> res;
   for (int i = 0; i < op->num_outputs(); ++i) {
     res.push_back(op.output(i));
index c8dfce8..8d5265b 100644 (file)
@@ -52,10 +52,10 @@ DataType TensorComputeOpNode::output_dtype(size_t i) const {
   return this->intrin->buffers[this->inputs.size() + i]->dtype;
 }
 
-Operation TensorComputeOpNode::make(std::string name, std::string tag, Array<IterVar> axis,
-                                    Array<IterVar> reduce_axis, int schedulable_ndim,
-                                    TensorIntrin intrin, Array<Tensor> tensors,
-                                    Array<Region> regions, Array<PrimExpr> scalar_inputs) {
+TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
+                                 Array<IterVar> reduce_axis, int schedulable_ndim,
+                                 TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions,
+                                 Array<PrimExpr> scalar_inputs) {
   auto n = make_object<TensorComputeOpNode>();
   n->name = std::move(name);
   n->tag = std::move(tag);
@@ -66,10 +66,17 @@ Operation TensorComputeOpNode::make(std::string name, std::string tag, Array<Ite
   n->inputs = std::move(tensors);
   n->input_regions = std::move(regions);
   n->scalar_inputs = std::move(scalar_inputs);
-  return Operation(n);
+  data_ = std::move(n);
 }
 
-TVM_REGISTER_GLOBAL("te.TensorComputeOp").set_body_typed(TensorComputeOpNode::make);
+TVM_REGISTER_GLOBAL("te.TensorComputeOp")
+    .set_body_typed([](std::string name, std::string tag, Array<IterVar> axis,
+                       Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
+                       Array<Tensor> tensors, Array<Region> regions,
+                       Array<PrimExpr> scalar_inputs) {
+      return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors,
+                             regions, scalar_inputs);
+    });
 
 Array<Tensor> TensorComputeOpNode::InputTensors() const { return inputs; }
 
@@ -191,7 +198,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage,
   binder.BindArray(sp_expr, user_expr, this->name);
 
   size_t tloc = stage->leaf_iter_vars.size();
-  ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
+  ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop);
 
   if (this->reduce_axis.size() == 0) {
     std::vector<std::vector<Stmt> > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
index af4b08e..82832c9 100644 (file)
@@ -347,7 +347,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
   size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region);
   TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin;
   CHECK(intrin.defined());
-  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
+  ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop);
   VerifyTensorizeLoopNest(self, stage, n, tloc);
   VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
   // Start bind data.
index 01d4f93..099f488 100644 (file)
@@ -59,7 +59,7 @@ bool NeedRelax(const IterVar& iv, bool found_attach,
   if (tag.length() == 0 || tag == "pipeline") {
     return !found_attach;
   }
-  ThreadScope ts = ThreadScope::make(tag);
+  ThreadScope ts = ThreadScope::Create(tag);
 
   // When there is warp memory
   // threadIdx.x must be set to be warp index.
@@ -72,14 +72,14 @@ bool NeedRelax(const IterVar& iv, bool found_attach,
 // infer storage scope, if not given
 StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) {
   if (stage->scope.length() != 0) {
-    return StorageScope::make(stage->scope);
+    return StorageScope::Create(stage->scope);
   }
   int max_rank = -1;
   for (IterVar iv : ctx.attach_path.at(stage->op)) {
     auto it = ctx.bind_map.find(iv);
     const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
     if (tag != "pipeline" && tag.length() != 0) {
-      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
+      max_rank = std::max(max_rank, ThreadScope::Create(tag).rank);
     }
   }
   StorageScope s;
index c360513..af72d3b 100644 (file)
@@ -324,16 +324,16 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_a
       args.push_back(value_map.at(iv));
     }
   }
-  Operation cache_op = ComputeOpNode::make(compute->name + "." + scope, compute->tag,
-                                           compute->attrs, new_axis, body_list);
+  Operation cache_op =
+      ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list);
 
   Array<PrimExpr> cache_expr_list;
   for (size_t i = 0; i < tensor_size; i++) {
     Tensor cache_tensor = cache_op.output(i);
     cache_expr_list.push_back(cache_tensor(args));
   }
-  Operation orig_new_op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs,
-                                              compute->axis, cache_expr_list);
+  Operation orig_new_op =
+      ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list);
   return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
 }
 
@@ -380,10 +380,10 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& te
     new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
   }
 
-  Operation cache_op = TensorComputeOpNode::make(tensor_op->name + "." + scope, tensor_op->tag,
-                                                 new_axis, tensor_op->reduce_axis,
-                                                 tensor_op->schedulable_ndim, tensor_op->intrin,
-                                                 tensor_op->inputs, new_regions, new_scalar_inputs);
+  Operation cache_op =
+      TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis,
+                      tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin,
+                      tensor_op->inputs, new_regions, new_scalar_inputs);
 
   // axis will be used in generating compute op
   Array<IterVar> compute_axis = tensor_op->axis;
@@ -419,7 +419,7 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& te
     cache_expr_list.push_back(cache_tensor(args));
   }
   Operation orig_new_op =
-      ComputeOpNode::make(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
+      ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
   return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
 }
 
@@ -468,7 +468,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
       if (idx < leaf_vars->size()) {
         // insert rebase
         IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
-        s->relations.push_back(RebaseNode::make(iv, rebased));
+        s->relations.push_back(te::Rebase(iv, rebased));
         if (s->iter_var_attrs.count(iv)) {
           s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
         }
@@ -583,8 +583,7 @@ void InjectInline(ScheduleNode* sch) {
       CHECK(compute);
       Operation op = s->op;
       if (changed[i]) {
-        op = ComputeOpNode::make(compute->name, compute->tag, compute->attrs, compute->axis,
-                                 new_body[i]);
+        op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]);
       }
       op = op->ReplaceInputs(op, repl);
       if (!op.same_as(s->op)) {
@@ -596,8 +595,8 @@ void InjectInline(ScheduleNode* sch) {
     } else if (hybrid_changed[i]) {
       const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
       CHECK(hybrid);
-      Operation op = HybridOpNode::make(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
-                                        hybrid->outputs, new_hybrid_body[i]);
+      Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
+                              hybrid->outputs, new_hybrid_body[i]);
       op = op->ReplaceInputs(op, repl);
       for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
         repl[s->op.output(idx)] = op.output(idx);
index 24d9102..707d52f 100644 (file)
@@ -55,8 +55,8 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
   return 0;
 }
 
-void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, IterVar* p_outer,
-           IterVar* p_inner) {
+void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts,
+                 IterVar* p_outer, IterVar* p_inner) {
   // Check if split is valid.
   CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce ||
         parent->iter_type == kOrdered)
@@ -69,7 +69,7 @@ void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, It
   Array<IterVar>& all_vars = self->all_iter_vars;
   Array<IterVar>& leaf_vars = self->leaf_iter_vars;
   size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent);
-  self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts));
+  self->relations.push_back(Split(parent, outer, inner, factor, nparts));
   // add vars to all vars
   all_vars.push_back(outer);
   all_vars.push_back(inner);
@@ -206,13 +206,13 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) {
 
 Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer,
                     IterVar* p_inner) {  // NOLINT(*)
-  Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
+  SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
   return *this;
 }
 
 Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
                               IterVar* p_inner) {  // NOLINT(*)
-  Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
+  SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
   return *this;
 }
 
@@ -242,7 +242,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // NOLINT
   }
   CHECK_EQ(pos_inner, pos_outer + 1)
       << "Can only fuse iterations that are consecutive between each other";
-  self->relations.push_back(FuseNode::make(outer, inner, fused));
+  self->relations.push_back(Fuse(outer, inner, fused));
   all_vars.push_back(fused);
   leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1);
   leaf_vars.insert(leaf_vars.begin() + pos_outer, fused);
@@ -263,7 +263,7 @@ Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // NOLINT(*
     // insert at the outer most loop
     IterVar singleton =
         IterVar(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar);
-    self->relations.push_back(SingletonNode::make(singleton));
+    self->relations.push_back(Singleton(singleton));
     Array<IterVar>& all_vars = self->all_iter_vars;
     Array<IterVar>& leaf_vars = self->leaf_iter_vars;
     all_vars.push_back(singleton);
@@ -624,9 +624,9 @@ bool ScheduleNode::Contain(const Operation& op) const {
   return stage_map.find(op) != stage_map.end();
 }
 
-Schedule ScheduleNode::make(Array<Operation> ops) {
+Schedule::Schedule(Array<Operation> ops) {
   auto n = make_object<ScheduleNode>();
-  Schedule sch(n);
+  data_ = n;
   n->outputs = ops;
   auto g = te::CreateReadGraph(n->outputs);
   Array<Operation> post_order = te::PostDFSOrder(n->outputs, g);
@@ -650,7 +650,7 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
         inputs.push_back(t);
       }
       // Create the scan group.
-      Stage scan_group = sch.create_group(scan->update, inputs, false);
+      Stage scan_group = this->create_group(scan->update, inputs, false);
       scan_group->attach_type = kScanUpdate;
       scan_group->attach_stage = stage;
 
@@ -660,39 +660,37 @@ Schedule ScheduleNode::make(Array<Operation> ops) {
       }
     }
   }
-  return sch;
 }
 
-IterVarRelation SplitNode::make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
-                                PrimExpr nparts) {
+Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) {
   auto n = make_object<SplitNode>();
   n->parent = parent;
   n->outer = outer;
   n->inner = inner;
   n->factor = factor;
   n->nparts = nparts;
-  return IterVarRelation(n);
+  data_ = std::move(n);
 }
 
-IterVarRelation FuseNode::make(IterVar outer, IterVar inner, IterVar fused) {
+Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) {
   auto n = make_object<FuseNode>();
   n->outer = outer;
   n->inner = inner;
   n->fused = fused;
-  return IterVarRelation(n);
+  data_ = std::move(n);
 }
 
-IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
+Rebase::Rebase(IterVar parent, IterVar rebased) {
   auto n = make_object<RebaseNode>();
   n->parent = parent;
   n->rebased = rebased;
-  return IterVarRelation(n);
+  data_ = std::move(n);
 }
 
-IterVarRelation SingletonNode::make(IterVar iter) {
+Singleton::Singleton(IterVar iter) {
   auto n = make_object<SingletonNode>();
   n->iter = iter;
-  return IterVarRelation(n);
+  data_ = std::move(n);
 }
 
 SpecializedCondition::SpecializedCondition(Array<PrimExpr> conditions) {
index 7e7f648..e66b963 100644 (file)
@@ -66,28 +66,32 @@ Tensor Operation::output(size_t i) const {
   return Tensor(node);
 }
 
-Tensor TensorNode::make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
+Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
   auto n = make_object<TensorNode>();
   n->shape = std::move(shape);
   n->dtype = dtype;
   n->op = op;
   n->value_index = value_index;
-  return Tensor(n);
+  data_ = std::move(n);
 }
 
+TVM_REGISTER_GLOBAL("te.Tensor")
+    .set_body_typed([](Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
+      return Tensor(shape, dtype, op, value_index);
+    });
+
+TVM_REGISTER_NODE_TYPE(TensorNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<TensorNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* t = static_cast<const TensorNode*>(node.get());
       p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')';
     });
 
-TVM_REGISTER_NODE_TYPE(TensorNode);
-
 // TensorIntrin
-
-TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array<Tensor> inputs,
-                                    Array<Buffer> buffers, Array<Var> scalar_params, Stmt body,
-                                    Stmt reduce_init, Stmt reduce_update) {
+TensorIntrin::TensorIntrin(std::string name, Operation op, Array<Tensor> inputs,
+                           Array<Buffer> buffers, Array<Var> scalar_params, Stmt body,
+                           Stmt reduce_init, Stmt reduce_update) {
   auto n = make_object<TensorIntrinNode>();
   n->name = std::move(name);
   n->op = std::move(op);
@@ -97,17 +101,24 @@ TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array<Tensor
   n->body = std::move(body);
   n->reduce_init = std::move(reduce_init);
   n->reduce_update = std::move(reduce_update);
-  return TensorIntrin(n);
+  data_ = std::move(n);
 }
 
+TVM_REGISTER_GLOBAL("te.TensorIntrin")
+    .set_body_typed([](std::string name, Operation op, Array<Tensor> inputs, Array<Buffer> buffers,
+                       Array<Var> scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) {
+      return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init,
+                          reduce_update);
+    });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<TensorIntrinNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const TensorIntrinNode*>(node.get());
       p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")";
     });
 
-TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
-
 // TensorIntrinCall
 TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array<Tensor> tensors,
                                    Array<Region> regions, Array<IterVar> reduce_axis,
@@ -135,10 +146,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
 
-TVM_REGISTER_GLOBAL("te.Tensor").set_body_typed(TensorNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrin").set_body_typed(TensorIntrinNode::make);
-
+// Other tensor ops.
 TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==);
 
 TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t {
index 46f4160..2b64f11 100644 (file)
@@ -45,8 +45,8 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
 }
 
 Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, std::string name) {
-  return BufferNode::make(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array<PrimExpr>(),
-                          PrimExpr(), name, "", 0, 0, kDefault);
+  return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array<PrimExpr>(),
+                PrimExpr(), name, "", 0, 0, kDefault);
 }
 
 // Split the given expression w.r.t the add operator
@@ -348,8 +348,8 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
       return MakeStrideView().MakeSlice(begins, extents);
     }
   }
-  return BufferNode::make(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice",
-                          n->scope, n->data_alignment, 0, n->buffer_type);
+  return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope,
+                n->data_alignment, 0, n->buffer_type);
 }
 
 PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
@@ -379,9 +379,9 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane
   return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic);
 }
 
-Buffer BufferNode::make(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
-                        PrimExpr elem_offset, std::string name, std::string scope,
-                        int data_alignment, int offset_factor, BufferType buffer_type) {
+Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr> shape, Array<PrimExpr> strides,
+               PrimExpr elem_offset, std::string name, std::string scope, int data_alignment,
+               int offset_factor, BufferType buffer_type) {
   auto n = make_object<BufferNode>();
   n->data = std::move(data);
   n->dtype = dtype;
@@ -410,7 +410,7 @@ Buffer BufferNode::make(Var data, DataType dtype, Array<PrimExpr> shape, Array<P
       n->strides.push_back(Var("stride", n->shape[i].dtype()));
     }
   }
-  return Buffer(n);
+  data_ = std::move(n);
 }
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -425,8 +425,8 @@ TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
   CHECK_EQ(args.size(), 10);
   auto buffer_type = args[9].operator std::string();
   BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault;
-  *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7],
-                          args[8], type);
+  *ret =
+      Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], type);
 });
 
 TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr);
index 1f17c35..bc777db 100644 (file)
@@ -66,7 +66,7 @@ const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) {
   return LayoutAxis::Get(axis[0]);
 }
 
-const LayoutAxis& LayoutAxis::make(const std::string& name) {
+const LayoutAxis& LayoutAxis::Get(const std::string& name) {
   CHECK_EQ(name.length(), 1) << "Invalid axis " << name;
   return LayoutAxis::Get(name[0]);
 }
@@ -144,8 +144,6 @@ Layout::Layout(const std::string& name) {  // NOLINT(*)
   data_ = std::move(node);
 }
 
-Layout LayoutNode::make(const std::string& layout) { return Layout(layout); }
-
 Layout Layout::SubLayout(size_t pos, size_t len) const {
   if (!defined() || pos > ndim()) return Layout::Undef();
   if (len == 0) return Layout(Array<IterVar>());
@@ -365,15 +363,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ")";
     });
 
-TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed(LayoutNode::make);
+TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); });
 
 TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
-  return layout.IndexOf(LayoutAxis::make(axis));
+  return layout.IndexOf(LayoutAxis::Get(axis));
 });
 
 TVM_REGISTER_GLOBAL("tir.LayoutFactorOf")
     .set_body_typed([](Layout layout, std::string axis) -> int {
-      return layout.FactorOf(LayoutAxis::make(axis));
+      return layout.FactorOf(LayoutAxis::Get(axis));
     });
 
 TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int {
index c201b8f..416358c 100644 (file)
@@ -147,12 +147,12 @@ class CopyIntrinInjector : public StmtMutator {
       src_strides.push_back(make_const(DataType::Int(32), 1));
       dst_strides.push_back(make_const(DataType::Int(32), 1));
     }
-    Buffer dst = BufferNode::make(store->buffer_var, store->value.dtype(), dst_shape, dst_strides,
-                                  store_strides[loop_var_size], store->buffer_var->name_hint,
-                                  GetStorageScope(store->buffer_var.get()), 0, 0, kDefault);
-    Buffer src = BufferNode::make(load->buffer_var, load->dtype, src_shape, src_strides,
-                                  src_elem_offset, load->buffer_var->name_hint,
-                                  GetStorageScope(load->buffer_var.get()), 0, 0, kDefault);
+    Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides,
+                        store_strides[loop_var_size], store->buffer_var->name_hint,
+                        GetStorageScope(store->buffer_var.get()), 0, 0, kDefault);
+    Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset,
+                        load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0,
+                        kDefault);
     *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
     CHECK(out->defined()) << "flower function did not return correct stmt";
     return true;
index 7dbf0fc..3b2580c 100644 (file)
@@ -113,7 +113,7 @@ class CandidateSelector final : public StmtExprVisitor {
       const IterVarNode* iv = op->node.as<IterVarNode>();
       CHECK(iv);
       Var var = iv->var;
-      runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+      runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
       if ((scope.rank == 0) && (!is_const(op->value) || partition_const_loop_)) {
         record_.insert({var.get(), false});
         StmtExprVisitor::VisitStmt_(op);
@@ -361,7 +361,7 @@ class LoopPartitioner : public StmtMutator {
     }
 
     // normal path when loop parittion fails.
-    runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
+    runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
     Stmt res;
     if (scope.rank == 1) {
       // threadIdx should be put into relax map, in case of divergence.
index 0b87757..9d6b47a 100644 (file)
@@ -63,7 +63,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::storage_scope) {
       const VarNode* buf = op->node.as<VarNode>();
-      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
+      StorageScope scope = StorageScope::Create(op->value.as<StringImmNode>()->value);
       StorageEntry e;
       e.scope = scope;
       if (scope.tag.length() != 0) {
index 058014c..ee17f08 100644 (file)
@@ -165,7 +165,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (const AttrStmtNode* attr : thread_extents_) {
       ThreadEntry e;
       IterVar iv = Downcast<IterVar>(attr->node);
-      e.scope = runtime::ThreadScope::make(iv->thread_tag);
+      e.scope = runtime::ThreadScope::Create(iv->thread_tag);
       e.iv = iv;
       CHECK_LE(e.scope.rank, 1);
       CHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction";
@@ -516,7 +516,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
     IterVar iv = Downcast<IterVar>(op->node);
     ThreadEntry e;
-    e.scope = runtime::ThreadScope::make(iv->thread_tag);
+    e.scope = runtime::ThreadScope::Create(iv->thread_tag);
     e.extent = 0;
     if (auto ptr = op->value.as<IntImmNode>()) {
       e.extent = static_cast<int>(ptr->value);
index a0ddf26..92f9ab5 100644 (file)
@@ -367,7 +367,7 @@ class WarpMemoryRewriter : private StmtMutator {
     using runtime::StorageScope;
     if (op->attr_key == attr::storage_scope) {
       const VarNode* buf = op->node.as<VarNode>();
-      StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
+      StorageScope scope = StorageScope::Create(op->value.as<StringImmNode>()->value);
       if (scope.rank == runtime::StorageRank::kWarp) {
         warp_buffer_.insert(buf);
         Stmt ret = StmtMutator::VisitStmt_(op);
index 3a42137..20cc640 100644 (file)
@@ -92,7 +92,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
 void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == attr::storage_scope) {
     const VarNode* buf = op->node.as<VarNode>();
-    storage_scope_[buf] = StorageScope::make(op->value.as<StringImmNode>()->value);
+    storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
     StmtExprVisitor::VisitStmt_(op);
   } else if (op->attr_key == attr::double_buffer_write) {
     CHECK(double_buffer_write_ == nullptr);
@@ -215,11 +215,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
     CHECK(allow_append_);
     const std::string& s = op->args[0].as<StringImmNode>()->value;
     if (s != "warp") {
-      StorageScope scope = StorageScope::make(s);
+      StorageScope scope = StorageScope::Create(s);
       AccessEntry e;
       e.threads = env_threads();
       e.type = kSync;
-      e.scope = StorageScope::make(s);
+      e.scope = StorageScope::Create(s);
       curr_stmt_.access.emplace_back(std::move(e));
     }
   } else {
index 4c3de58..e29d978 100644 (file)
@@ -91,7 +91,7 @@ class StorageFlattener : public StmtExprMutator {
       return body;
     } else if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
-      ThreadScope ts = ThreadScope::make(iv->thread_tag);
+      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
       curr_thread_scope_.push_back(ts);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       curr_thread_scope_.pop_back();
@@ -165,7 +165,7 @@ class StorageFlattener : public StmtExprMutator {
           skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
         }
       } else {
-        skey = StorageScope::make(strkey);
+        skey = StorageScope::Create(strkey);
       }
 
       // use small alignment for small arrays
@@ -200,9 +200,9 @@ class StorageFlattener : public StmtExprMutator {
         strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
       }
 
-      e.buffer = BufferNode::make(Var(op->buffer->data->name_hint, DataType::Handle()),
-                                  op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
-                                  skey.to_string(), align, 0, kDefault);
+      e.buffer =
+          Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape,
+                 strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault);
 
       buf_map_[key] = e;
       Stmt body = this->VisitStmt(op->body);
index 2d09e8b..283ab0f 100644 (file)
@@ -178,7 +178,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       VisitNewScope(op);
     } else if (op->attr_key == attr::storage_scope) {
       const VarNode* buf = op->node.as<VarNode>();
-      alloc_info_[buf].storage_scope = StorageScope::make(op->value.as<StringImmNode>()->value);
+      alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as<StringImmNode>()->value);
       StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
index e5b4bdd..b8575d2 100644 (file)
@@ -251,7 +251,7 @@ class ThreadSyncInserter : public StmtExprMutator {
       return ret;
     } else if (op->attr_key == attr::storage_scope) {
       const VarNode* buf = op->node.as<VarNode>();
-      storage_scope_[buf] = StorageScope::make(op->value.as<StringImmNode>()->value);
+      storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
       return StmtExprMutator::VisitStmt_(op);
     } else {
       return StmtExprMutator::VisitStmt_(op);
@@ -321,7 +321,7 @@ class ThreadSyncInserter : public StmtExprMutator {
       num_work_dim_ = thread_extents_.size();
       for (const AttrStmtNode* attr : thread_extents_) {
         IterVar iv = Downcast<IterVar>(attr->node);
-        runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag);
+        runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
         if (s.rank == 0) {
           num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value);
         } else if (s.rank == 1) {
@@ -353,7 +353,7 @@ class ThreadSyncInserter : public StmtExprMutator {
 };
 
 Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
-  StorageScope sync_scope = StorageScope::make(storage_scope);
+  StorageScope sync_scope = StorageScope::Create(storage_scope);
   ThreadSyncPlanner planner(sync_scope);
   planner(stmt);
   return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
index 8823134..70709b0 100644 (file)
@@ -51,11 +51,11 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue*
 TEST(MicroStandaloneRuntime, BuildModule) {
   using namespace tvm;
   auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32));
-  auto a = relay::VarNode::make("a", tensor_type);
-  auto b = relay::VarNode::make("b", tensor_type);
+  auto a = relay::Var("a", tensor_type);
+  auto b = relay::Var("b", tensor_type);
   auto add_op = relay::Op::Get("add");
   auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {});
-  auto c = relay::VarNode::make("c", tensor_type);
+  auto c = relay::Var("c", tensor_type);
   auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {});
   auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
   auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
index 25b3800..b84fbc7 100644 (file)
@@ -46,8 +46,7 @@ using namespace tvm::te;
 inline Buffer DeclExternBuffer(Array<PrimExpr> shape, DataType dtype, std::string name) {
   auto data = var(name, DataType::Handle());
   auto elem_offset = PrimExpr();
-  return BufferNode::make(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", -1, 0,
-                          kDefault);
+  return Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", -1, 0, kDefault);
 }
 
 /*!
@@ -93,8 +92,7 @@ inline Array<Tensor> make_extern(const Array<Array<PrimExpr> >& out_shapes,
   auto body = fextern(input_placeholders, output_placeholders);
   auto body_stmt = tvm::tir::Evaluate(body);
 
-  auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders,
-                               body_stmt);
+  auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt);
 
   Array<Tensor> outputs;
   for (size_t i = 0; i < output_placeholders.size(); ++i) {
index 9aa4e35..e830e09 100644 (file)
@@ -1194,8 +1194,8 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout,
                                const std::string& dst_layout,
                                const std::string name = "T_layout_trans",
                                const std::string tag = kInjective) {
-  Layout src_layout_struct = LayoutNode::make(src_layout);
-  Layout dst_layout_struct = LayoutNode::make(dst_layout);
+  Layout src_layout_struct(src_layout);
+  Layout dst_layout_struct(dst_layout);
 
   if (src_layout_struct.Equals(dst_layout_struct)) {
     return src;