RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);
+/*! \brief A Relay Recursive Closure. A closure that has a name. */
+class RecClosure;
+
+/*! \brief The container type of RecClosure. */
+class RecClosureNode : public ValueNode {
+ public:
+ /*! \brief The closure. */
+ Closure clos;
+ /*! \brief variable the closure bind to. */
+ Var bind;
+
+ RecClosureNode() {}
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("clos", &clos);
+ v->Visit("bind", &bind);
+ }
+
+ TVM_DLL static RecClosure make(Closure clos, Var bind);
+
+ static constexpr const char* _type_key = "relay.RecClosure";
+ TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
+};
+
+RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);
+
/*! \brief A tuple value. */
class TupleValue;
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
- p->stream << "ClosureNode(" << node->func << ")";
+ p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
});
+
+// TODO(@jroesch): this doesn't support mutual letrec
+/* Value Implementation */
+RecClosure RecClosureNode::make(Closure clos, Var bind) {
+ NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
+ n->clos = std::move(clos);
+ n->bind = std::move(bind);
+ return RecClosure(n);
+}
+
+TVM_REGISTER_API("relay._make.RecClosure")
+.set_body_typed(RecClosureNode::make);
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
+ p->stream << "RecClosureNode(" << node->clos << ")";
+ });
+
TupleValue TupleValueNode::make(tvm::Array<Value> value) {
NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
n->fields = value;
return TupleValueNode::make(values);
}
- // TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
// We must use mutation here to build a self referential closure.
auto closure = ClosureNode::make(captured_mod, func);
- auto mut_closure =
- static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
if (letrec_name.defined()) {
- mut_closure->env.Set(letrec_name, closure);
+ return RecClosureNode::make(closure, letrec_name);
}
return std::move(closure);
}
}
// Invoke the closure
- Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
+ Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args);
locals.Set(func->params[i], args[i]);
}
- // Add the var to value mappings from the Closure's modironment.
+ // Add the var to value mappings from the Closure's environment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second);
}
+ if (bind.defined()) {
+ locals.Set(bind, RecClosureNode::make(closure, bind));
+ }
+
return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
}
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
+ } else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
+ return this->Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";