From: 雾雨魔理沙 Date: Thu, 24 Oct 2019 18:50:25 +0000 (-0700) Subject: [Relay] Fix memory leak in the interpreter (#4155) X-Git-Tag: upstream/0.7.0~1740 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2e0dbaa62aadae094bf4819da8211296cd78fe0b;p=platform%2Fupstream%2Ftvm.git [Relay] Fix memory leak in the interpreter (#4155) * save lint * address reviewer comment --- diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index a0422fa..f0b1e7c 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -119,6 +119,32 @@ class ClosureNode : public ValueNode { RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); +/*! \brief A Relay Recursive Closure. A closure that has a name. */ +class RecClosure; + +/*! \brief The container type of RecClosure. */ +class RecClosureNode : public ValueNode { + public: + /*! \brief The closure. */ + Closure clos; + /*! \brief variable the closure bind to. */ + Var bind; + + RecClosureNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("clos", &clos); + v->Visit("bind", &bind); + } + + TVM_DLL static RecClosure make(Closure clos, Var bind); + + static constexpr const char* _type_key = "relay.RecClosure"; + TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value); + /*! \brief A tuple value. */ class TupleValue; diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ae60b7a..1d53f6a 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -73,6 +73,11 @@ class Closure(Value): @register_relay_node +class RecClosure(Value): + """A recursive closure produced by the interpreter.""" + + +@register_relay_node class ConstructorValue(Value): def __init__(self, tag, fields, constructor): self.__init_handle_by_constructor__( diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 86a4ebb..2703b1c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ClosureNode* node, tvm::IRPrinter* p) { - p->stream << "ClosureNode(" << node->func << ")"; + p->stream << "ClosureNode(" << node->func << ", " << node->env << ")"; }); + +// TODO(@jroesch): this doesn't support mutual letrec +/* Value Implementation */ +RecClosure RecClosureNode::make(Closure clos, Var bind) { + NodePtr n = make_node(); + n->clos = std::move(clos); + n->bind = std::move(bind); + return RecClosure(n); +} + +TVM_REGISTER_API("relay._make.RecClosure") +.set_body_typed(RecClosureNode::make); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const RecClosureNode* node, tvm::IRPrinter* p) { + p->stream << "RecClosureNode(" << node->clos << ")"; + }); + TupleValue TupleValueNode::make(tvm::Array value) { NodePtr n = make_node(); n->fields = value; @@ -281,7 +299,6 @@ class Interpreter : return TupleValueNode::make(values); } - // TODO(@jroesch): this doesn't support mutual letrec inline Value MakeClosure(const Function& func, Var letrec_name = Var()) { tvm::Map captured_mod; Array free_vars = FreeVars(func); @@ -298,10 +315,8 @@ class Interpreter : // We must use mutation here to build a self referential closure. auto closure = ClosureNode::make(captured_mod, func); - auto mut_closure = - static_cast(const_cast(closure.get())); if (letrec_name.defined()) { - mut_closure->env.Set(letrec_name, closure); + return RecClosureNode::make(closure, letrec_name); } return std::move(closure); } @@ -559,7 +574,7 @@ class Interpreter : } // Invoke the closure - Value Invoke(const Closure& closure, const tvm::Array& args) { + Value Invoke(const Closure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. if (closure->func->IsPrimitive()) { return InvokePrimitiveOp(closure->func, args); @@ -575,12 +590,16 @@ class Interpreter : locals.Set(func->params[i], args[i]); } - // Add the var to value mappings from the Closure's modironment. + // Add the var to value mappings from the Closure's environment. for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { CHECK_EQ(locals.count((*it).first), 0); locals.Set((*it).first, (*it).second); } + if (bind.defined()) { + locals.Set(bind, RecClosureNode::make(closure, bind)); + } + return WithFrame(Frame(locals), [&]() { return Eval(func->body); }); } @@ -607,6 +626,8 @@ class Interpreter : if (const ClosureNode* closure_node = fn_val.as()) { auto closure = GetRef(closure_node); return this->Invoke(closure, args); + } else if (const RecClosureNode* closure_node = fn_val.as()) { + return this->Invoke(closure_node->clos, args, closure_node->bind); } else { LOG(FATAL) << "internal error: type error, expected function value in the call " << "position";