[Relay][Pass] Add pass to remove unused functions in relay module (#4334)
authorWei Chen <ipondering.weic@gmail.com>
Fri, 15 Nov 2019 01:52:01 +0000 (17:52 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 15 Nov 2019 01:52:01 +0000 (17:52 -0800)
* [Relay][Pass] Add pass to remove unused functions in relay module

* Add tests

* Fix lint

* Fix visit order

* Add pass argument

* Fix

python/tvm/relay/transform.py
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/removed_unused_funcs.cc [new file with mode: 0644]
tests/python/relay/test_pass_remove_unused_functions.py [new file with mode: 0644]

index d3509da..0a7512a 100644 (file)
@@ -297,6 +297,22 @@ def BackwardFoldScaleAxis():
     """
     return _transform.BackwardFoldScaleAxis()
 
+def RemoveUnusedFunctions(entry_functions=None):
+    """Remove unused global relay functions in a relay module.
+
+    Parameters
+    ----------
+    entry_functions: list[string]
+        The set of entry functions to start from.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass to remove unused functions.
+    """
+    if entry_functions is None:
+        entry_functions = ['main']
+    return _transform.RemoveUnusedFunctions(entry_functions)
 
 def ForwardFoldScaleAxis():
     """Fold the scaling of axis into weights of conv2d/dense.
index 7f828c4..06705b4 100644 (file)
@@ -54,6 +54,7 @@ namespace transform {
 
 Pass LambdaLift();
 Pass InlinePrimitives();
+Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);
 
 Pass ManifestAlloc(Target target_host) {
   auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
@@ -863,6 +864,8 @@ void VMCompiler::Compile(Module mod,
 
 Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
   Array<Pass> pass_seqs;
+  Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
+  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
   // Run all dialect legalization passes.
   pass_seqs.push_back(relay::qnn::transform::Legalize());
 
diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc
new file mode 100644 (file)
index 0000000..a012040
--- /dev/null
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tvm/relay/backend/vm/remove_unused_funcs.cc
+ * \brief Remove unused global relay functions in a relay module.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/logging.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+/**
+ * \brief Detects all the functions that can be possibly called by entry function.
+ */
+struct CallTracer : ExprVisitor {
+  Module module_;
+
+  // Record the names of all encountered functions
+  std::unordered_set<std::string> called_funcs_;
+
+  // Record the expressions that are being visited
+  std::unordered_set<Expr, NodeHash, NodeEqual> visiting_;
+
+  explicit CallTracer(const Module& module)
+    : module_{module},
+      called_funcs_{},
+      visiting_{} {}
+
+  void VisitExpr_(const CallNode* call_node) final {
+    Expr op = call_node->op;
+    for (auto param : call_node->args) {
+      VisitExpr(param);
+    }
+    if (auto func_node = op.as<FunctionNode>()) {
+      auto func = GetRef<Function>(func_node);
+      auto it = visiting_.find(func);
+      if (it != visiting_.end()) {
+        return;
+      }
+      visiting_.insert(func);
+      VisitExpr(func);
+    } else if (auto global = op.as<GlobalVarNode>()) {
+      called_funcs_.insert(global->name_hint);
+      auto func = module_->Lookup(global->name_hint);
+      auto it = visiting_.find(func);
+      if (it != visiting_.end()) {
+        return;
+      }
+      visiting_.insert(func);
+      VisitExpr(func);
+    }
+  }
+
+  std::unordered_set<std::string> Trace(const std::string& entry) {
+    called_funcs_.insert(entry);
+    auto main_func = module_->Lookup(entry);
+    VisitExpr(main_func);
+    return called_funcs_;
+  }
+};
+
+/*!
+ * \brief Remove functions that are not used.
+ *
+ * \param module The Relay module.
+ * \param entry_funcs The set of functions that can be entry function.
+ * 
+ * \return The module with dead functions removed.
+ */
+Module RemoveUnusedFunctions(const Module& module,
+                             Array<tvm::Expr> entry_funcs) {
+  std::unordered_set<std::string> called_funcs{};
+  for (auto entry : entry_funcs) {
+    auto* str_name = entry.as<ir::StringImm>();
+    auto funcs = CallTracer(module).Trace(str_name->value);
+    called_funcs.insert(funcs.cbegin(), funcs.cend());
+  }
+  auto existing_functions = module->functions;
+  for (auto f : existing_functions) {
+    auto it = called_funcs.find(f.first->name_hint);
+    if (it == called_funcs.end()) {
+      module->Remove(f.first);
+    }
+  }
+  return module;
+}
+
+}  // namespace vm
+
+namespace transform {
+
+Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions) {
+  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
+    [=](Module m, PassContext pc) {
+    return relay::vm::RemoveUnusedFunctions(m, entry_functions);
+  };
+  return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
+}
+
+TVM_REGISTER_API("relay._transform.RemoveUnusedFunctions")
+.set_body_typed(RemoveUnusedFunctions);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py
new file mode 100644 (file)
index 0000000..c4a0c41
--- /dev/null
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.prelude import Prelude
+
+def test_remove_all_prelude_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    x = relay.var("x", shape=(1, 16))
+    mod["main"] = relay.Function([x], x)
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['main'])
+
+def test_remove_all_prelude_functions_but_referenced_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    x = relay.var("x", shape=(1, 16))
+    id_func = relay.Function([x], x)
+    id_name = relay.GlobalVar('id_func')
+    mod[id_name] = id_func
+
+    mod["main"] = relay.Function([x], id_name(x))
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['id_func', 'main'])
+
+def test_keep_only_referenced_prelude_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    l = p.nil()
+    for i in [4, 3, 2, 1, 0]:
+        l = p.cons(relay.const(i), l)
+    body = p.hd(p.tl(p.tl(l)))
+    mod["main"] = relay.Function([], body)
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['tl', 'hd', 'main'])
+
+def test_multiple_entry_functions():
+    mod = relay.Module()
+    p = Prelude(mod)
+    l = p.nil()
+    for i in [4, 3, 2, 1, 0]:
+        l = p.cons(relay.const(i), l)
+    body = p.hd(p.tl(p.tl(l)))
+    mod["main1"] = relay.Function([], body)
+
+    x = relay.var("x", shape=(1, 16))
+    id_func = relay.Function([x], x)
+    id_name = relay.GlobalVar('id_func')
+    mod[id_name] = id_func
+    mod["main2"] = relay.Function([x], id_name(x))
+    mod = relay.transform.RemoveUnusedFunctions(['main1', 'main2'])(mod)
+    l = set([x[0].name_hint for x in mod.functions.items()])
+    assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
+
+if __name__ == '__main__':
+    pytest.main()