[TensorExpr] Switch Exprs and Stmt from kernel-arena to shared_ptr. (#63216)
authorMikhail Zolotukhin <mvz@fb.com>
Tue, 24 Aug 2021 07:29:22 +0000 (00:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 07:32:11 +0000 (00:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63216

Currently there are three classes managed by KernelArena: Expr, Stmt,
and Tensor (and derived classes). KernelArena has been a long standing
painpoint for NNC devs and we're moving away from that memory management
model to ref-count based memory model (using shared_ptr). This commit
switches Expr and Stmt to shared_ptr and is the biggest change in this
transition. Later commits will detach Tensor from KernelArena and kill
the arena + scope altogether.

Differential Revision:
D30353195
D30353195

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: 9575225ada3d0fb65087ae40435f3dfea4792cae

18 files changed:
test/test_tensorexpr_pybind.py
torch/csrc/jit/tensorexpr/eval.cpp
torch/csrc/jit/tensorexpr/expr.h
torch/csrc/jit/tensorexpr/fwd_decls.h
torch/csrc/jit/tensorexpr/hash_provider.cpp
torch/csrc/jit/tensorexpr/hash_provider.h
torch/csrc/jit/tensorexpr/ir.h
torch/csrc/jit/tensorexpr/ir_cloner.cpp
torch/csrc/jit/tensorexpr/ir_mutator.cpp
torch/csrc/jit/tensorexpr/ir_printer.cpp
torch/csrc/jit/tensorexpr/ir_simplifier.cpp
torch/csrc/jit/tensorexpr/ir_simplifier.h
torch/csrc/jit/tensorexpr/ir_verifier.cpp
torch/csrc/jit/tensorexpr/ir_visitor.cpp
torch/csrc/jit/tensorexpr/llvm_codegen.cpp
torch/csrc/jit/tensorexpr/loopnest.cpp
torch/csrc/jit/tensorexpr/stmt.h
torch/csrc/jit/tensorexpr/tensorexpr_init.cpp

index d838892975c0c785c90002528f04d6c859c5f35b..0ae59e1c564840b6551c7eeb26c5a19b2254f893 100644 (file)
@@ -394,9 +394,6 @@ graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
         np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
         np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
 
-    def test_forgot_kernel_arena(self):
-        self.assertRaises(RuntimeError, lambda: torch._C._te.VarHandle("n", torch._C._te.Dtype.Int))
-
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_alloc_in_loop(self):
         with kernel_arena_scope():
index c7a28bdbb23ac8906a0b2fd6d9366fbe0e2ddba8..05c3ff82452215a68de0696396f6acdb5e3dca73 100644 (file)
@@ -281,8 +281,12 @@ class SimpleIREvaluatorImpl : public IRVisitor {
     return Value(result_v);
   }
 
-  template <typename Op>
-  void visit_binary_op(BinaryOpNode<Op>* v, bool option = false) {
+  template <
+      typename D,
+      typename std::enable_if<std::is_same<
+          decltype(detail::bin_op_deducer(std::declval<D>())),
+          void>::value>::type* = nullptr>
+  void visit_binary_op(NodePtr<D> v, bool option = false) {
     v->lhs()->accept(this);
     Value lhs_v = value_;
     v->rhs()->accept(this);
index fae24ec34be2895b700e327fb93594b9df9393fc..1b942eaf353fc47dedf087ed350888536ffc9970 100644 (file)
@@ -36,10 +36,11 @@ enum IRNodeType {
 };
 
 // The common base between all expression node.
-class TORCH_API Expr : public KernelScopedObject {
+class TORCH_API Expr : public std::enable_shared_from_this<Expr> {
  public:
   explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
       : dtype_(dtype), expr_type_(expr_type) {}
+  virtual ~Expr() = default;
   Dtype dtype() const {
     return dtype_;
   }
@@ -66,6 +67,11 @@ class TORCH_API Expr : public KernelScopedObject {
    */
   static ExprPtr clone(ExprPtr s);
 
+ protected:
+  std::shared_ptr<Expr> getptr() {
+    return shared_from_this();
+  }
+
  private:
   Dtype dtype_;
   IRNodeType expr_type_;
@@ -78,7 +84,7 @@ class ExprNode : public Base {
  public:
   using ExprNodeBase = ExprNode<Op>;
   void accept(IRVisitor* visitor) override {
-    visitor->visit(static_to<Op>(this));
+    visitor->visit(static_to<Op>(Base::getptr()));
   }
   ExprPtr accept_mutator(IRMutator* mutator) override;
   // pass the constructor to the base class
@@ -335,7 +341,7 @@ class TORCH_API VarHandle : public ExprHandle {
 
 template <class Op, class Base>
 ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
-  return mutator->mutate(static_to<Op>(this));
+  return mutator->mutate(static_to<Op>(Base::getptr()));
 }
 
 inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
index 01a767067f62081a21fabf20cb4f464a038505b5..1b3dde560b427000d960ed95057167de79696fe5 100644 (file)
@@ -1,26 +1,27 @@
 #pragma once
 #include <c10/core/ScalarType.h>
+#include <memory>
 
 namespace torch {
 namespace jit {
 namespace tensorexpr {
 
 template <typename Node>
-using NodePtr = Node*;
+using NodePtr = std::shared_ptr<Node>;
 
 template <typename To, typename From>
 NodePtr<To> to(NodePtr<From> x) {
-  return dynamic_cast<NodePtr<To>>(x);
+  return std::dynamic_pointer_cast<To>(x);
 }
 
 template <typename To, typename From>
 NodePtr<To> static_to(NodePtr<From> x) {
-  return static_cast<NodePtr<To>>(x);
+  return std::static_pointer_cast<To>(x);
 }
 
 template <typename Node, typename... Args>
 NodePtr<Node> alloc(Args&&... args) {
-  return new Node(std::forward<Args>(args)...);
+  return std::make_shared<Node>(std::forward<Args>(args)...);
 }
 
 class Buf;
index fbc257d1988df1db26d1b86506529cc68a6c760d..dce25669bf32387cd9d732539af035708d4f98c1 100644 (file)
@@ -63,6 +63,13 @@ void HashProvider::visit(ModPtr v) {
   putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
 }
 
+void HashProvider::visit(RoundOffPtr v) {
+  CACHE_GUARD();
+  v->lhs()->accept(this);
+  v->rhs()->accept(this);
+  putHash(v, hash_combine(hashOf(v->lhs()), "rof", hashOf(v->rhs())));
+}
+
 void HashProvider::visit(MaxPtr v) {
   CACHE_GUARD();
   v->lhs()->accept(this);
index 5a33f048fec84fbb3b92d73156e79c7b3d778bb3..91ce269edeb5cfe01808bbcc2cc8a558a76b4520 100644 (file)
@@ -59,12 +59,16 @@ class TORCH_API HashProvider : public IRVisitor {
     return hashOf(e);
   }
 
-  bool cachedHash(const KernelScopedObject* e) {
+  bool cachedHash(ExprPtr e) {
     return exprToHash_.find(e) != exprToHash_.end();
   }
+  bool cachedHash(StmtPtr s) {
+    return stmtToHash_.find(s) != stmtToHash_.end();
+  }
 
   void clearCache() {
     exprToHash_.clear();
+    stmtToHash_.clear();
   }
 
   void visit(AddPtr v) override;
@@ -72,6 +76,7 @@ class TORCH_API HashProvider : public IRVisitor {
   void visit(MulPtr v) override;
   void visit(DivPtr v) override;
   void visit(ModPtr v) override;
+  void visit(RoundOffPtr v) override;
   void visit(MaxPtr v) override;
   void visit(MinPtr v) override;
   void visit(AndPtr v) override;
@@ -133,8 +138,8 @@ class TORCH_API HashProvider : public IRVisitor {
   }
 
   SimplifierHashType hashOf(StmtPtr s) {
-    auto it = exprToHash_.find(s);
-    if (it != exprToHash_.end()) {
+    auto it = stmtToHash_.find(s);
+    if (it != stmtToHash_.end()) {
       return it->second;
     }
 
@@ -182,15 +187,23 @@ class TORCH_API HashProvider : public IRVisitor {
     _hash_combine(seed, args...);
   }
 
-  void putHash(const KernelScopedObject* e, SimplifierHashType h) {
+  void putHash(ExprPtr e, SimplifierHashType h) {
     auto res = exprToHash_.emplace(e, h);
     if (res.second == false) {
       // This is always a logic bug since we should check the cache first.
       throw std::runtime_error("hash collision");
     }
   }
+  void putHash(StmtPtr s, SimplifierHashType h) {
+    auto res = stmtToHash_.emplace(s, h);
+    if (res.second == false) {
+      // This is always a logic bug since we should check the cache first.
+      throw std::runtime_error("hash collision");
+    }
+  }
 
-  std::unordered_map<const KernelScopedObject*, SimplifierHashType> exprToHash_;
+  std::unordered_map<ExprPtr, SimplifierHashType> exprToHash_;
+  std::unordered_map<StmtPtr, SimplifierHashType> stmtToHash_;
   UniqueNameManager name_manager_;
 
   size_t te_hash(SimplifierHashType val) {
index 761b233fe83755fd6c4d3c42e304a7eb8a2e766e..f9fc7dcfc4246d0718682331d956bbef9a382fed 100644 (file)
@@ -178,6 +178,12 @@ class BinaryOpNode : public ExprNode<Op> {
   ExprPtr rhs_;
 };
 
+namespace detail {
+template <typename T>
+void bin_op_deducer(BinaryOpNode<T>);
+bool bin_op_deducer(...);
+} // namespace detail
+
 class TORCH_API Add : public BinaryOpNode<Add> {
  public:
   Add(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
index f724f2cbeb16f65fad465ed6d7915f11d07c257d..e225826df66e29bfddfa87fe6bd3d16b8cbe284f 100644 (file)
@@ -10,9 +10,13 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-template <typename Op>
+template <
+    typename Op,
+    typename std::enable_if<std::is_same<
+        decltype(detail::bin_op_deducer(std::declval<Op>())),
+        void>::value>::type* = nullptr>
 static ExprPtr mutate_binary_op(
-    NodePtr<BinaryOpNode<Op>> v,
+    NodePtr<Op> v,
     IRCloner* cloner,
     bool option = false) {
   ExprPtr lhs_new = v->lhs()->accept_mutator(cloner);
index 96635acab8c90534d38fac8b32307d054dceb516..45121581eebf06687e8b32a1c19e5804edb8d178 100644 (file)
@@ -11,9 +11,13 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-template <typename Op>
+template <
+    typename Op,
+    typename std::enable_if<std::is_same<
+        decltype(detail::bin_op_deducer(std::declval<Op>())),
+        void>::value>::type* = nullptr>
 static ExprPtr mutate_binary_op(
-    BinaryOpNode<Op>* v,
+    NodePtr<Op> v,
     IRMutator* mutator,
     bool option = false) {
   ExprPtr lhs = v->lhs();
index 23466f39160c83fbd14a4331a546b5a9fd3226a9..f885246e24d2b34247dee8f2af8d1d06448bcf62 100644 (file)
@@ -28,9 +28,13 @@ void IRPrinter::print(Stmt& stmt) {
 
 // TODO: change whether to include the parenthesis to the parent expression,
 // we need to look at the operator precedence to make the output simpler.
-template <typename Op>
+template <
+    typename Op,
+    typename std::enable_if<std::is_same<
+        decltype(detail::bin_op_deducer(std::declval<Op>())),
+        void>::value>::type* = nullptr>
 void visitBinaryOp(
-    BinaryOpNode<Op>* v,
+    NodePtr<Op> v,
     const std::string& op_str,
     IRPrinter* printer,
     bool parens = true) {
index cb731d2525e7183eec18d3cf7c7db8d6fc63bcb9..23216dd4002f73e0200f0f47c8af98b14cc5549e 100644 (file)
@@ -6,6 +6,70 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
+// Creates a new Expr of the given type with the provided lhs and rhs.
+inline ExprPtr newBinaryOpOfType(
+    IRNodeType expr_type,
+    ExprPtr lhs,
+    ExprPtr rhs,
+    bool option) {
+  switch (expr_type) {
+    // NOLINTNEXTLINE(bugprone-branch-clone)
+    case IRNodeType::kAdd:
+      return alloc<Add>(lhs, rhs);
+    case IRNodeType::kSub:
+      return alloc<Sub>(lhs, rhs);
+    case IRNodeType::kMul:
+      return alloc<Mul>(lhs, rhs);
+    case IRNodeType::kDiv:
+      return alloc<Div>(lhs, rhs);
+    case IRNodeType::kMod:
+      return alloc<Mod>(lhs, rhs);
+    case IRNodeType::kMax:
+      return alloc<Max>(lhs, rhs, option);
+    case IRNodeType::kMin:
+      return alloc<Min>(lhs, rhs, option);
+    case IRNodeType::kAnd:
+      return alloc<And>(lhs, rhs);
+    case IRNodeType::kXor:
+      return alloc<Xor>(lhs, rhs);
+    case IRNodeType::kLshift:
+      return alloc<Lshift>(lhs, rhs);
+    case IRNodeType::kRshift:
+      return alloc<Rshift>(lhs, rhs);
+    default:
+      LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
+      return nullptr;
+  }
+}
+
+template <
+    typename Op,
+    typename std::enable_if<std::is_same<
+        decltype(detail::bin_op_deducer(std::declval<Op>())),
+        void>::value>::type* = nullptr>
+static ExprPtr mutateBinaryOp(
+    NodePtr<Op> v,
+    IRMutator* mutator,
+    bool option = false) {
+  ExprPtr lhs = v->lhs();
+  ExprPtr rhs = v->rhs();
+  ExprPtr lhs_new = lhs->accept_mutator(mutator);
+  ExprPtr rhs_new = rhs->accept_mutator(mutator);
+
+  ExprPtr node = v;
+
+  if (lhs != lhs_new || rhs != rhs_new) {
+    node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
+  }
+
+  // Can only fold if both sides are constant.
+  if (!lhs_new->isConstant() || !rhs_new->isConstant()) {
+    return node;
+  }
+
+  return evaluateOp(node);
+}
+
 // Simple recursive GCD.
 template <typename T>
 T gcd(T a, T b) {
@@ -1499,6 +1563,22 @@ ExprPtr PolynomialTransformer::mutate(IfThenElsePtr v) {
   return alloc<IfThenElse>(condition_new, true_value_new, false_value_new);
 }
 
+ExprPtr PolynomialTransformer::mutate(AndPtr v) {
+  return mutateBinaryOp(v, this);
+}
+
+ExprPtr PolynomialTransformer::mutate(XorPtr v) {
+  return mutateBinaryOp(v, this);
+}
+
+ExprPtr PolynomialTransformer::mutate(LshiftPtr v) {
+  return mutateBinaryOp(v, this);
+}
+
+ExprPtr PolynomialTransformer::mutate(RshiftPtr v) {
+  return mutateBinaryOp(v, this);
+}
+
 StmtPtr PolynomialBase::mutate(CondPtr v) {
   ExprPtr cond_old = v->condition();
   StmtPtr true_old = v->true_stmt();
@@ -1904,6 +1984,7 @@ c10::optional<class ModRound*> isModRound(TermPtr e) {
     scalar = getImmediateByType(multiplier->dtype(), 1);
   }
 
+  // TODO: this leaks memory!
   return new ModRound(scalar, denom, divisor, mod_divisor);
 }
 
index 87c476242e8dec6a9d7b1f779734e051007d571e..1df8b5d8f3501a0910238a3a708f8bca0172d192 100644 (file)
@@ -55,7 +55,7 @@ Dtype promoteTypesVec(std::vector<ExprType>& v) {
 template <class ExprType>
 Dtype promoteTypesMap(
     ExprPtr s,
-    std::unordered_map<SimplifierHashType, ExprType*>& m) {
+    std::unordered_map<SimplifierHashType, ExprType>& m) {
   Dtype t = s->dtype();
   bool first = true;
   for (auto& e : m) {
@@ -69,12 +69,12 @@ Dtype promoteTypesMap(
 }
 
 template <class ExprType>
-Dtype promoteTypesVar(ExprType* e) {
+Dtype promoteTypesVar(ExprType e) {
   return e->dtype();
 }
 
 template <class ExprType, class... Args>
-Dtype promoteTypesVar(ExprType* e, Args... es) {
+Dtype promoteTypesVar(ExprType e, Args... es) {
   Dtype lhs = e->dtype();
   Dtype rhs = promoteTypesVar(es...);
   if (e->isConstant()) {
@@ -84,42 +84,6 @@ Dtype promoteTypesVar(ExprType* e, Args... es) {
   return promoteTypes(lhs, rhs);
 }
 
-// Creates a new Expr of the given type with the provided lhs and rhs.
-inline ExprPtr newBinaryOpOfType(
-    IRNodeType expr_type,
-    ExprPtr lhs,
-    ExprPtr rhs,
-    bool option) {
-  switch (expr_type) {
-    // NOLINTNEXTLINE(bugprone-branch-clone)
-    case IRNodeType::kAdd:
-      return alloc<Add>(lhs, rhs);
-    case IRNodeType::kSub:
-      return alloc<Sub>(lhs, rhs);
-    case IRNodeType::kMul:
-      return alloc<Mul>(lhs, rhs);
-    case IRNodeType::kDiv:
-      return alloc<Div>(lhs, rhs);
-    case IRNodeType::kMod:
-      return alloc<Mod>(lhs, rhs);
-    case IRNodeType::kMax:
-      return alloc<Max>(lhs, rhs, option);
-    case IRNodeType::kMin:
-      return alloc<Min>(lhs, rhs, option);
-    case IRNodeType::kAnd:
-      return alloc<And>(lhs, rhs);
-    case IRNodeType::kXor:
-      return alloc<Xor>(lhs, rhs);
-    case IRNodeType::kLshift:
-      return alloc<Lshift>(lhs, rhs);
-    case IRNodeType::kRshift:
-      return alloc<Rshift>(lhs, rhs);
-    default:
-      LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
-      return nullptr;
-  }
-}
-
 // Uses the evaluator to fold an Expression with constant terms.
 // E.g. evaluateOp(Add(3, 4)) => 7.
 // Expr v must not have any unbound Vars.
@@ -498,21 +462,13 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
 
   ExprPtr mutate(ModPtr v) override;
 
-  ExprPtr mutate(AndPtr v) override {
-    return mutateBinaryOp(v, this);
-  }
+  ExprPtr mutate(AndPtr v) override;
 
-  ExprPtr mutate(XorPtr v) override {
-    return mutateBinaryOp(v, this);
-  }
+  ExprPtr mutate(XorPtr v) override;
 
-  ExprPtr mutate(LshiftPtr v) override {
-    return mutateBinaryOp(v, this);
-  }
+  ExprPtr mutate(LshiftPtr v) override;
 
-  ExprPtr mutate(RshiftPtr v) override {
-    return mutateBinaryOp(v, this);
-  }
+  ExprPtr mutate(RshiftPtr v) override;
 
   ExprPtr mutate(MaxPtr v) override;
 
@@ -526,30 +482,6 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
 
   ExprPtr mutate(IfThenElsePtr v) override;
 
-  template <typename Op>
-  static ExprPtr mutateBinaryOp(
-      BinaryOpNode<Op>* v,
-      IRMutator* mutator,
-      bool option = false) {
-    ExprPtr lhs = v->lhs();
-    ExprPtr rhs = v->rhs();
-    ExprPtr lhs_new = lhs->accept_mutator(mutator);
-    ExprPtr rhs_new = rhs->accept_mutator(mutator);
-
-    ExprPtr node = v;
-
-    if (lhs != lhs_new || rhs != rhs_new) {
-      node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
-    }
-
-    // Can only fold if both sides are constant.
-    if (!lhs_new->isConstant() || !rhs_new->isConstant()) {
-      return node;
-    }
-
-    return evaluateOp(node);
-  }
-
   static ExprPtr simplify(ExprPtr e);
   static ExprHandle simplify(const ExprHandle& e);
   static StmtPtr simplify(StmtPtr e);
index c88e92c9a7a8258782971fc44c12d6c60607f10c..f7adbdee939929d39c11d12a939698b44d62c6fa 100644 (file)
@@ -9,8 +9,19 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-template <typename Op>
-void verifyBitwiseOp(const BitwiseOpNode<Op>* v, IRVerifier* verifier) {
+namespace detail {
+template <typename T>
+void deducer(BinaryOpNode<T>);
+
+bool deducer(...);
+} // namespace detail
+
+template <
+    typename D,
+    typename std::enable_if<std::is_same<
+        decltype(detail::deducer(std::declval<D>())),
+        void>::value>::type* = nullptr>
+void verifyBitwiseOp(NodePtr<D> v, IRVerifier* verifier) {
   if (!v->lhs()->dtype().is_integral()) {
     throw unsupported_dtype();
   }
index 9066544bd229150ae5365060e8f032e033dccad0..eb2a4280c4f88de8be2f086bb83005175b82a649 100644 (file)
@@ -11,8 +11,12 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-template <typename Op>
-static void visit_binary_op(BinaryOpNode<Op>* v, IRVisitor* visitor) {
+template <
+    typename Op,
+    typename std::enable_if<std::is_same<
+        decltype(detail::bin_op_deducer(std::declval<Op>())),
+        void>::value>::type* = nullptr>
+static void visit_binary_op(NodePtr<Op> v, IRVisitor* visitor) {
   v->lhs()->accept(visitor);
   v->rhs()->accept(visitor);
 }
index eac1f82f25c4bbd2f74eef2d554ac39534665a53..4ab2d53cc4942c32f3f84eea2c0902823ae93ac0 100644 (file)
@@ -488,12 +488,13 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander {
     if (v->op_type() == kTanh) {
       ScalarType stype = v->dtype().scalar_type();
       if (stype == ScalarType::Float) {
-        return fast_tanh(v->param(0)->accept_mutator(this)).node();
+        return fast_tanh(ExprHandle(v->param(0)->accept_mutator(this))).node();
       }
     } else if (v->op_type() == kSigmoid) {
       ScalarType stype = v->dtype().scalar_type();
       if (stype == ScalarType::Float) {
-        return fast_sigmoid(v->param(0)->accept_mutator(this)).node();
+        return fast_sigmoid(ExprHandle(v->param(0)->accept_mutator(this)))
+            .node();
       }
     }
     // TODO: fast exp
index a296d8c7af79b3053dc937fbf2c7f4a276797ad7..d9d20736057fb38a8de451bfa3cd22baf151f4a0 100644 (file)
@@ -2380,7 +2380,7 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) {
 
 void LoopNest::compressAllBuffers(StmtPtr stmt) {
   for (auto buf : BufFinder::find(stmt)) {
-    compressBuffer(const_cast<BufPtr>(buf), stmt);
+    compressBuffer(buf, stmt);
   }
 }
 
index 0b4a2e4c5361c44b9ce05f1404cd5e12a9d3c417..7e4914fbc4aa7ad6f33b87f9a60dedc4cd829ca9 100644 (file)
@@ -14,14 +14,15 @@ namespace tensorexpr {
 class Placeholder;
 
 // The common base between all statement node.
-class TORCH_API Stmt : public KernelScopedObject {
+class TORCH_API Stmt : public std::enable_shared_from_this<Stmt> {
  public:
   Stmt() = default;
+  virtual ~Stmt() = default;
   virtual void accept(IRVisitor* visitor) = 0;
   virtual StmtPtr accept_mutator(IRMutator* mutator) = 0;
 
   StmtPtr get_parent() const {
-    return parent_;
+    return parent_ ? parent_->getptr() : nullptr;
   }
 
   /*
@@ -34,12 +35,15 @@ class TORCH_API Stmt : public KernelScopedObject {
   static StmtPtr clone(StmtPtr s);
 
  protected:
-  static void set_parent(StmtPtr s, StmtPtr new_parent) {
+  static void set_parent(StmtPtr s, Stmt* new_parent) {
     s->parent_ = new_parent;
   }
+  std::shared_ptr<Stmt> getptr() {
+    return shared_from_this();
+  }
 
  private:
-  StmtPtr parent_ = nullptr;
+  Stmt* parent_ = nullptr;
 };
 
 template <class Op>
@@ -47,7 +51,7 @@ class StmtNode : public Stmt {
  public:
   using StmtNodeBase = StmtNode<Op>;
   void accept(IRVisitor* visitor) override {
-    visitor->visit(static_to<Op>(this));
+    visitor->visit(static_to<Op>(getptr()));
   }
   StmtPtr accept_mutator(IRMutator* mutator) override;
   StmtNode() = default;
@@ -55,7 +59,7 @@ class StmtNode : public Stmt {
 
 template <class Op>
 StmtPtr StmtNode<Op>::accept_mutator(IRMutator* mutator) {
-  return mutator->mutate(static_to<Op>(this));
+  return mutator->mutate(static_to<Op>(getptr()));
 }
 
 // Concrete Stmt classes
@@ -193,7 +197,7 @@ class TORCH_API Block : public StmtNode<Block> {
   }
 
   void clear() {
-    for (auto* s : stmts_) {
+    for (auto s : stmts_) {
       set_parent(s, nullptr);
     }
     stmts_.clear();
@@ -281,7 +285,7 @@ class TORCH_API Block : public StmtNode<Block> {
 
   // returns the immediate child containing statement s.
   StmtPtr getEnclosedRoot(StmtPtr s) const {
-    while (s && s->get_parent() != this) {
+    while (s && s->get_parent().get() != this) {
       s = s->get_parent();
     }
     return s;
index 304a317076c05ed073d4f1b772dac185dfee1543..4e1618a8745d74cfbc3288420e14e1251080e1e7 100644 (file)
@@ -184,10 +184,7 @@ void initTensorExprBindings(PyObject* module) {
           [](Placeholder& self,
              const std::vector<ExprHandle>& args,
              const ExprHandle& val) { return self.store(args, val); })
-      .def(
-          "data",
-          [](Placeholder& self) { return BufHandle(self.data()); },
-          py::return_value_policy::reference);
+      .def("data", [](Placeholder& self) { return BufHandle(self.data()); });
   py::class_<Tensor, std::unique_ptr<Tensor, py::nodelete>>(te, "Tensor")
       .def(py::init(
           [](BufHandle& b, StmtPtr s) { return new Tensor(b.node(), s); }))
@@ -197,8 +194,9 @@ void initTensorExprBindings(PyObject* module) {
             return self.load(v);
           })
       .def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
-      .def("stmt", &Tensor::stmt, py::return_value_policy::reference);
-  py::class_<Cast>(te, "Cast").def_static("make", &Cast::make);
+      .def("stmt", &Tensor::stmt);
+  py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
+      .def_static("make", &Cast::make);
 
   py::class_<DimArg>(te, "DimArg")
       .def(py::init<const ExprHandle&>())
@@ -321,7 +319,7 @@ void initTensorExprBindings(PyObject* module) {
       },
       py::return_value_policy::reference);
 
-  py::class_<Stmt, std::unique_ptr<Stmt, py::nodelete>>(te, "Stmt")
+  py::class_<Stmt, std::shared_ptr<Stmt>>(te, "Stmt")
       .def(py::init([](const std::vector<StmtPtr>& stmts) {
         return tensorexpr::Block::make(stmts);
       }))
@@ -330,22 +328,18 @@ void initTensorExprBindings(PyObject* module) {
         ss << self;
         return ss.str();
       });
-  py::class_<Store, Stmt, std::unique_ptr<Store, py::nodelete>>(te, "Store")
+  py::class_<Store, Stmt, std::shared_ptr<Store>>(te, "Store")
       .def_static(
           "make",
           [](const BufHandle& buf,
              std::vector<ExprHandle>& indices,
              const ExprHandle& value) {
             return Store::make(buf, indices, value);
-          },
-          py::return_value_policy::reference);
+          });
 
-  py::class_<For, Stmt, std::unique_ptr<For, py::nodelete>>(te, "For")
-      .def(
-          "index_var",
-          [](For& self) { return VarHandle(self.var()); },
-          py::return_value_policy::reference)
-      .def("body", &For::body, py::return_value_policy::reference)
+  py::class_<For, Stmt, std::shared_ptr<For>>(te, "For")
+      .def("index_var", [](For& self) { return VarHandle(self.var()); })
+      .def("body", &For::body)
       .def("set_parallel", &For::set_parallel)
       .def(
           "set_gpu_block_index",
@@ -362,35 +356,28 @@ void initTensorExprBindings(PyObject* module) {
           [](const VarHandle& var,
              const ExprHandle& start,
              const ExprHandle& stop,
-             StmtPtr body) { return For::make(var, start, stop, body); },
-          py::return_value_policy::reference);
+             StmtPtr body) { return For::make(var, start, stop, body); });
 
-  py::class_<Cond, Stmt, std::unique_ptr<Cond, py::nodelete>>(te, "Cond")
+  py::class_<Cond, Stmt, std::shared_ptr<Cond>>(te, "Cond")
       .def_static(
           "make",
           [](const ExprHandle& condition,
              StmtPtr true_stmt,
              StmtPtr false_stmt) {
-            return alloc<Cond>(condition.node(), true_stmt, false_stmt);
-          },
-          py::return_value_policy::reference)
-      .def("true_stmt", &Cond::true_stmt, py::return_value_policy::reference)
-      .def("false_stmt", &Cond::false_stmt, py::return_value_policy::reference);
+            return Cond::make(condition, true_stmt, false_stmt);
+          })
+      .def("true_stmt", &Cond::true_stmt)
+      .def("false_stmt", &Cond::false_stmt);
 
-  py::class_<
-      tensorexpr::Block,
-      Stmt,
-      std::unique_ptr<tensorexpr::Block, py::nodelete>>(te, "Block")
+  py::class_<tensorexpr::Block, Stmt, std::shared_ptr<tensorexpr::Block>>(
+      te, "Block")
       .def(py::init([](const std::vector<StmtPtr>& stmts) {
         return tensorexpr::Block::make(stmts);
       }))
-      .def(
-          "stmts",
-          &tensorexpr::Block::stmts,
-          py::return_value_policy::reference);
-  py::class_<ExternalCall, Stmt, std::unique_ptr<ExternalCall, py::nodelete>>(
+      .def("stmts", &tensorexpr::Block::stmts);
+  py::class_<ExternalCall, Stmt, std::shared_ptr<ExternalCall>>(
       te, "ExternalCall")
-      .def(py::init(&ExternalCall::make), py::return_value_policy::reference);
+      .def(py::init(&ExternalCall::make));
 
   py::class_<LoopNest>(te, "LoopNest")
       .def(py::init<const std::vector<Tensor*>&>())