[TIR][REFACTOR] Deprecate FreeStmt (#5890)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 24 Jun 2020 21:48:12 +0000 (14:48 -0700)
committerGitHub <noreply@github.com>
Wed, 24 Jun 2020 21:48:12 +0000 (14:48 -0700)
Currently FreeStmt is not being used.
While it can be useful to have an early free hint
we can always use an intrinsic instead of a first class statement.

include/tvm/tir/stmt.h
include/tvm/tir/stmt_functor.h
python/tvm/tir/__init__.py
python/tvm/tir/stmt.py
src/printer/text_printer.h
src/printer/tir_text_printer.cc
src/tir/ir/stmt.cc
src/tir/ir/stmt_functor.cc
tests/python/unittest/test_tir_constructor.py

index b928aec..16800d5 100644 (file)
@@ -545,35 +545,6 @@ class Allocate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
 
-/*! \brief Free the resources in the buffer before the scope ends. */
-class FreeNode : public StmtNode {
- public:
-  /*! \brief The buffer variable. */
-  Var buffer_var;
-
-  void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); }
-
-  bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
-    return equal(buffer_var, other->buffer_var);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
-
-  static constexpr const char* _type_key = "tir.Free";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
-};
-
-/*!
- * \brief Managed reference to FreeNode.
- * \sa FreeNode
- */
-class Free : public Stmt {
- public:
-  TVM_DLL Free(Var buffer_var);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode);
-};
-
 /*!
  * \brief The container of seq statement.
  *        Represent a sequence of statements.
index f037de7..0f4238d 100644 (file)
@@ -90,7 +90,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
-  virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -112,7 +111,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(ForNode);
     IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
     IR_STMT_FUNCTOR_DISPATCH(StoreNode);
-    IR_STMT_FUNCTOR_DISPATCH(FreeNode);
     IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
     IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
     IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
@@ -154,7 +152,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
   void VisitStmt_(const StoreNode* op) override;
   void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const BufferRealizeNode* op) override;
-  void VisitStmt_(const FreeNode* op) override;
   void VisitStmt_(const AssertStmtNode* op) override;
   void VisitStmt_(const ProducerStoreNode* op) override;
   void VisitStmt_(const ProducerRealizeNode* op) override;
@@ -246,7 +243,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
   Stmt VisitStmt_(const StoreNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
   Stmt VisitStmt_(const BufferRealizeNode* op) override;
-  Stmt VisitStmt_(const FreeNode* op) override;
   Stmt VisitStmt_(const AssertStmtNode* op) override;
   Stmt VisitStmt_(const ProducerStoreNode* op) override;
   Stmt VisitStmt_(const ProducerRealizeNode* op) override;
index 982b31c..90ccde4 100644 (file)
@@ -29,7 +29,7 @@ from .expr import IterVar, Any
 
 from .stmt import Stmt, LetStmt, AssertStmt, For
 from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
-from .stmt import Free, ProducerRealize, SeqStmt
+from .stmt import ProducerRealize, SeqStmt
 from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
 
 from .function import PrimFunc
index 4536580..757b2ac 100644 (file)
@@ -258,20 +258,6 @@ class AttrStmt(Stmt):
             _ffi_api.AttrStmt, node, attr_key, value, body)
 
 
-@tvm._ffi.register_object("tir.Free")
-class Free(Stmt):
-    """Free node.
-
-    Parameters
-    ----------
-    buffer_var : Var
-        The buffer variable.
-    """
-    def __init__(self, buffer_var):
-        self.__init_handle_by_constructor__(
-            _ffi_api.Free, buffer_var)
-
-
 @tvm._ffi.register_object("tir.ProducerRealize")
 class ProducerRealize(Stmt):
     """ProducerRealize node.
index c7b2b31..65b4a81 100644 (file)
@@ -303,7 +303,6 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitStmt_(const BufferStoreNode* op) override;
   Doc VisitStmt_(const BufferRealizeNode* op) override;
   Doc VisitStmt_(const AllocateNode* op) override;
-  Doc VisitStmt_(const FreeNode* op) override;
   Doc VisitStmt_(const IfThenElseNode* op) override;
   Doc VisitStmt_(const SeqStmtNode* op) override;
   Doc VisitStmt_(const EvaluateNode* op) override;
index 233a739..a11de01 100644 (file)
@@ -438,12 +438,6 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
   return doc;
 }
 
-Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) {
-  Doc doc;
-  doc << "free(" << Print(op->buffer_var) << ")";
-  return doc;
-}
-
 Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
   Doc doc;
   doc << "if " << Print(op->condition) << PrintBody(op->then_case);
index c3ddb66..7b4ac7e 100644 (file)
@@ -374,25 +374,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "}\n";
     });
 
-// Free
-Free::Free(Var buffer_var) {
-  ObjectPtr<FreeNode> node = make_object<FreeNode>();
-  node->buffer_var = buffer_var;
-  data_ = std::move(node);
-}
-
-TVM_REGISTER_GLOBAL("tir.Free").set_body_typed([](Var buffer_var) { return Free(buffer_var); });
-
-TVM_REGISTER_NODE_TYPE(FreeNode);
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-    .set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const FreeNode*>(node.get());
-      p->PrintIndent();
-      p->stream << "free " << op->buffer_var;
-      p->stream << '\n';
-    });
-
 // Prefetch
 Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
   data_ = make_object<PrefetchNode>(buffer, bounds);
index 67329aa..abf6438 100644 (file)
@@ -79,8 +79,6 @@ void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
   }
 }
 
-void StmtVisitor::VisitStmt_(const FreeNode* op) {}
-
 void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
   this->VisitExpr(op->condition);
   this->VisitExpr(op->message);
@@ -381,8 +379,6 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
   }
 }
 
-Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef<Stmt>(op); }
-
 // Implementations of IRTransform, PostOrderVisit and Substitute
 class IRApplyVisit : public StmtExprVisitor {
  public:
index 0f8a023..d2c504b 100644 (file)
@@ -171,10 +171,6 @@ def test_stmt_constructor():
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.tir.Free(buffer_var)
-    assert isinstance(x, tvm.tir.Free)
-    assert x.buffer_var == buffer_var
-
     x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"),
                             tvm.tir.Evaluate(11),
                             nop)