From c7274fd3b0f693fc6214a450e13d5e99026337ae Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 10 Jun 2020 20:49:57 -0700 Subject: [PATCH] [Relay] Fix for recursive let (#5757) * Make let processing iterative * Try again * Fix pretty printer overflow * cleanup * fix lint * Fix text printer Co-authored-by: Jared Roesch Co-authored-by: Jared Roesch --- python/tvm/relay/transform/memory_plan.py | 11 +++++++++++ src/printer/relay_text_printer.cc | 21 +++++++++++++++------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index e359a9e..8f21af9 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -330,6 +330,17 @@ class LiftConst(ExprMutator): fn.type_params, fn.attrs) + def visit_let(self, let): + bindings = [] + while isinstance(let, expr.Let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + bindings.append((new_var, new_val)) + let = let.body + + new_body = self.visit(let) + return mk_let(bindings, new_body) + @function_pass(opt_level=0) class MemoryPlan: """An explicit pass wrapper around StorageCoalesce.""" diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 981d0c3..a09e24b 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -364,12 +364,21 @@ Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { } Doc RelayTextPrinter::VisitExpr_(const LetNode* op) { - Doc doc; - doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << ";" - << Doc::NewLine(); - // we use a scope here so GNF hoisting doesn't escape too far - // and nested, unique lets are not hoisted - doc << PrintScope(op->body); + int n = 0; + Expr let = GetRef(op); + while (auto let_node = let.as()) { + Doc doc; + doc << "let " << AllocVar(let_node->var) << " = " << Print(let_node->value, false, true) << ";" + << Doc::NewLine(); + doc_stack_.push_back(doc); + let = let_node->body; + ++n; + } + Doc doc = PrintScope(let); + for (int i = 0; i < n; ++i) { + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + } return doc; } -- 2.7.4