[Relay] Fix for recursive let (#5757)
authorHaichen Shen <shenhaichen@gmail.com>
Thu, 11 Jun 2020 03:49:57 +0000 (20:49 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Jun 2020 03:49:57 +0000 (20:49 -0700)
* Make let processing iterative

* Try again

* Fix pretty printer overflow

* cleanup

* fix lint

* Fix text printer

Co-authored-by: Jared Roesch <roeschinc@gmail.com>
Co-authored-by: Jared Roesch <jroesch@octoml.ai>
python/tvm/relay/transform/memory_plan.py
src/printer/relay_text_printer.cc

index e359a9e..8f21af9 100644 (file)
@@ -330,6 +330,17 @@ class LiftConst(ExprMutator):
             fn.type_params,
             fn.attrs)
 
+    def visit_let(self, let):
+        bindings = []
+        while isinstance(let, expr.Let):
+            new_var = self.visit(let.var)
+            new_val = self.visit(let.value)
+            bindings.append((new_var, new_val))
+            let = let.body
+
+        new_body = self.visit(let)
+        return mk_let(bindings, new_body)
+
 @function_pass(opt_level=0)
 class MemoryPlan:
     """An explicit pass wrapper around StorageCoalesce."""
index 981d0c3..a09e24b 100644 (file)
@@ -364,12 +364,21 @@ Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
 }
 
 Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
-  Doc doc;
-  doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << ";"
-      << Doc::NewLine();
-  // we use a scope here so GNF hoisting doesn't escape too far
-  // and nested, unique lets are not hoisted
-  doc << PrintScope(op->body);
+  int n = 0;
+  Expr let = GetRef<Let>(op);
+  while (auto let_node = let.as<LetNode>()) {
+    Doc doc;
+    doc << "let " << AllocVar(let_node->var) << " = " << Print(let_node->value, false, true) << ";"
+        << Doc::NewLine();
+    doc_stack_.push_back(doc);
+    let = let_node->body;
+    ++n;
+  }
+  Doc doc = PrintScope(let);
+  for (int i = 0; i < n; ++i) {
+    doc = doc_stack_.back() << doc;
+    doc_stack_.pop_back();
+  }
   return doc;
 }