[Relay] Fix memory leak in the interpreter (#4155)
author雾雨魔理沙 <lolisa@marisa.moe>
Thu, 24 Oct 2019 18:50:25 +0000 (11:50 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Thu, 24 Oct 2019 18:50:25 +0000 (11:50 -0700)
* save

lint

* address reviewer comment

include/tvm/relay/interpreter.h
python/tvm/relay/backend/interpreter.py
src/relay/backend/interpreter.cc

index a0422fa..f0b1e7c 100644 (file)
@@ -119,6 +119,32 @@ class ClosureNode : public ValueNode {
 
 RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
 
+/*! \brief A Relay Recursive Closure. A closure that has a name. */
+class RecClosure;
+
+/*! \brief The container type of RecClosure. */
+class RecClosureNode : public ValueNode {
+ public:
+  /*! \brief The closure. */
+  Closure clos;
+  /*! \brief variable the closure bind to. */
+  Var bind;
+
+  RecClosureNode() {}
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("clos", &clos);
+    v->Visit("bind", &bind);
+  }
+
+  TVM_DLL static RecClosure make(Closure clos, Var bind);
+
+  static constexpr const char* _type_key = "relay.RecClosure";
+  TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
+};
+
+RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);
+
 /*! \brief A tuple value. */
 class TupleValue;
 
index ae60b7a..1d53f6a 100644 (file)
@@ -73,6 +73,11 @@ class Closure(Value):
 
 
 @register_relay_node
+class RecClosure(Value):
+    """A recursive closure produced by the interpreter."""
+
+
+@register_relay_node
 class ConstructorValue(Value):
     def __init__(self, tag, fields, constructor):
         self.__init_handle_by_constructor__(
index 86a4ebb..2703b1c 100644 (file)
@@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure")
 
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
-    p->stream << "ClosureNode(" << node->func << ")";
+    p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
   });
 
+
+// TODO(@jroesch): this doesn't support mutual letrec
+/* Value Implementation */
+RecClosure RecClosureNode::make(Closure clos, Var bind) {
+  NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
+  n->clos = std::move(clos);
+  n->bind = std::move(bind);
+  return RecClosure(n);
+}
+
+TVM_REGISTER_API("relay._make.RecClosure")
+.set_body_typed(RecClosureNode::make);
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
+                                p->stream << "RecClosureNode(" << node->clos << ")";
+                              });
+
 TupleValue TupleValueNode::make(tvm::Array<Value> value) {
   NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
   n->fields = value;
@@ -281,7 +299,6 @@ class Interpreter :
     return TupleValueNode::make(values);
   }
 
-  // TODO(@jroesch): this doesn't support mutual letrec
   inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
     tvm::Map<Var, Value> captured_mod;
     Array<Var> free_vars = FreeVars(func);
@@ -298,10 +315,8 @@ class Interpreter :
 
     // We must use mutation here to build a self referential closure.
     auto closure = ClosureNode::make(captured_mod, func);
-    auto mut_closure =
-        static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
     if (letrec_name.defined()) {
-      mut_closure->env.Set(letrec_name, closure);
+      return RecClosureNode::make(closure, letrec_name);
     }
     return std::move(closure);
   }
@@ -559,7 +574,7 @@ class Interpreter :
   }
 
   // Invoke the closure
-  Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
+  Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
     // Get a reference to the function inside the closure.
     if (closure->func->IsPrimitive()) {
       return InvokePrimitiveOp(closure->func, args);
@@ -575,12 +590,16 @@ class Interpreter :
       locals.Set(func->params[i], args[i]);
     }
 
-    // Add the var to value mappings from the Closure's modironment.
+    // Add the var to value mappings from the Closure's environment.
     for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
       CHECK_EQ(locals.count((*it).first), 0);
       locals.Set((*it).first, (*it).second);
     }
 
+    if (bind.defined()) {
+      locals.Set(bind, RecClosureNode::make(closure, bind));
+    }
+
     return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
   }
 
@@ -607,6 +626,8 @@ class Interpreter :
     if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
       auto closure = GetRef<Closure>(closure_node);
       return this->Invoke(closure, args);
+    } else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
+      return this->Invoke(closure_node->clos, args, closure_node->bind);
     } else {
       LOG(FATAL) << "internal error: type error, expected function value in the call "
                  << "position";