fix RemoveUnusedFunctions pass
authorZhi Chen <chzhi@amazon.com>
Mon, 13 Jan 2020 22:40:26 +0000 (22:40 +0000)
committerWei Chen <ipondering.weic@gmail.com>
Tue, 14 Jan 2020 02:51:26 +0000 (18:51 -0800)
src/relay/backend/vm/removed_unused_funcs.cc
tests/python/relay/test_pass_remove_unused_functions.py

index 419b09588a7b1301ecaa719957ba71b6a28da6b2..23bcdc373e260d282fd34baf7038c6a5b794c740 100644 (file)
@@ -53,33 +53,20 @@ struct CallTracer : ExprVisitor {
       called_funcs_{},
       visiting_{} {}
 
-  void CheckExpr(const Expr& expr) {
-    if (auto func_node = expr.as<FunctionNode>()) {
-      auto func = GetRef<Function>(func_node);
-      auto it = visiting_.find(func);
-      if (it != visiting_.end()) {
-        return;
-      }
-      visiting_.insert(func);
-      VisitExpr(func);
-    } else if (auto global = expr.as<GlobalVarNode>()) {
-      called_funcs_.insert(global->name_hint);
-      auto func = module_->Lookup(global->name_hint);
-      auto it = visiting_.find(func);
-      if (it != visiting_.end()) {
-        return;
-      }
-      visiting_.insert(func);
-      VisitExpr(func);
-    } else {
-      VisitExpr(expr);
-    }
+  void VisitExpr_(const GlobalVarNode* op) final {
+    called_funcs_.insert(op->name_hint);
+    auto func = module_->Lookup(op->name_hint);
+    VisitExpr(func);
   }
 
-  void VisitExpr_(const CallNode* call_node) final {
-    CheckExpr(call_node->op);
-    for (auto param : call_node->args) {
-      CheckExpr(param);
+  void VisitExpr_(const FunctionNode* func_node) final {
+    auto func = GetRef<Function>(func_node);
+    if (visiting_.find(func) == visiting_.end()) {
+      visiting_.insert(func);
+      for (auto param : func_node->params) {
+        ExprVisitor::VisitExpr(param);
+      }
+      ExprVisitor::VisitExpr(func_node->body);
     }
   }
 
index 97d8646922c08f86f23b970f354f1ef2773f85ce..2a4cbd2579e7f0a630b88da5e9703e5377face4c 100644 (file)
@@ -20,6 +20,7 @@ from tvm import relay
 from tvm.relay import transform
 from tvm.relay.prelude import Prelude
 
+
 def test_remove_all_prelude_functions():
     mod = relay.Module()
     p = Prelude(mod)
@@ -29,6 +30,7 @@ def test_remove_all_prelude_functions():
     l = set([x[0].name_hint for x in mod.functions.items()])
     assert l == set(['main'])
 
+
 def test_remove_all_prelude_functions_but_referenced_functions():
     mod = relay.Module()
     p = Prelude(mod)
@@ -42,6 +44,7 @@ def test_remove_all_prelude_functions_but_referenced_functions():
     l = set([x[0].name_hint for x in mod.functions.items()])
     assert l == set(['id_func', 'main'])
 
+
 def test_keep_only_referenced_prelude_functions():
     mod = relay.Module()
     p = Prelude(mod)
@@ -54,6 +57,7 @@ def test_keep_only_referenced_prelude_functions():
     l = set([x[0].name_hint for x in mod.functions.items()])
     assert l == set(['tl', 'hd', 'main'])
 
+
 def test_multiple_entry_functions():
     mod = relay.Module()
     p = Prelude(mod)
@@ -72,6 +76,7 @@ def test_multiple_entry_functions():
     l = set([x[0].name_hint for x in mod.functions.items()])
     assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
 
+
 def test_globalvar_as_call_arg():
     mod = relay.Module()
     p = Prelude(mod)
@@ -88,5 +93,24 @@ def test_globalvar_as_call_arg():
     l = set([x[0].name_hint for x in mod.functions.items()])
     assert 'tensor_array_int32' in l
 
+
+def test_call_globalvar_without_args():
+    def get_mod():
+        mod = relay.Module({})
+        fn1 = relay.Function([], relay.const(1))
+        fn2 = relay.Function([], relay.const(2))
+        g1 = relay.GlobalVar('g1')
+        g2 = relay.GlobalVar('g2')
+        mod[g1] = fn1
+        mod[g2] = fn2
+        p = relay.var('p', 'bool')
+        mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
+        return mod
+    mod = get_mod()
+    ref_mod = get_mod()
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    assert relay.alpha_equal(mod, ref_mod)
+
+
 if __name__ == '__main__':
     pytest.main()