called_funcs_{},
visiting_{} {}
- void CheckExpr(const Expr& expr) {
- if (auto func_node = expr.as<FunctionNode>()) {
- auto func = GetRef<Function>(func_node);
- auto it = visiting_.find(func);
- if (it != visiting_.end()) {
- return;
- }
- visiting_.insert(func);
- VisitExpr(func);
- } else if (auto global = expr.as<GlobalVarNode>()) {
- called_funcs_.insert(global->name_hint);
- auto func = module_->Lookup(global->name_hint);
- auto it = visiting_.find(func);
- if (it != visiting_.end()) {
- return;
- }
- visiting_.insert(func);
- VisitExpr(func);
- } else {
- VisitExpr(expr);
- }
+ void VisitExpr_(const GlobalVarNode* op) final {
+ called_funcs_.insert(op->name_hint);
+ auto func = module_->Lookup(op->name_hint);
+ VisitExpr(func);
}
- void VisitExpr_(const CallNode* call_node) final {
- CheckExpr(call_node->op);
- for (auto param : call_node->args) {
- CheckExpr(param);
+ void VisitExpr_(const FunctionNode* func_node) final {
+ auto func = GetRef<Function>(func_node);
+ if (visiting_.find(func) == visiting_.end()) {
+ visiting_.insert(func);
+ for (auto param : func_node->params) {
+ ExprVisitor::VisitExpr(param);
+ }
+ ExprVisitor::VisitExpr(func_node->body);
}
}
from tvm.relay import transform
from tvm.relay.prelude import Prelude
+
def test_remove_all_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['main'])
+
def test_remove_all_prelude_functions_but_referenced_functions():
mod = relay.Module()
p = Prelude(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['id_func', 'main'])
+
def test_keep_only_referenced_prelude_functions():
mod = relay.Module()
p = Prelude(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main'])
+
def test_multiple_entry_functions():
mod = relay.Module()
p = Prelude(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
+
def test_globalvar_as_call_arg():
mod = relay.Module()
p = Prelude(mod)
l = set([x[0].name_hint for x in mod.functions.items()])
assert 'tensor_array_int32' in l
+
+def test_call_globalvar_without_args():
+ def get_mod():
+ mod = relay.Module({})
+ fn1 = relay.Function([], relay.const(1))
+ fn2 = relay.Function([], relay.const(2))
+ g1 = relay.GlobalVar('g1')
+ g2 = relay.GlobalVar('g2')
+ mod[g1] = fn1
+ mod[g2] = fn2
+ p = relay.var('p', 'bool')
+ mod['main'] = relay.Function([p], relay.Call(relay.If(p, g1, g2), []))
+ return mod
+ mod = get_mod()
+ ref_mod = get_mod()
+ mod = relay.transform.RemoveUnusedFunctions()(mod)
+ assert relay.alpha_equal(mod, ref_mod)
+
+
if __name__ == '__main__':
pytest.main()