From 273c02808b1ff6ccb1f69ff7528fae5e69e57dc7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 10 Jul 2019 13:31:48 -0700 Subject: [PATCH] init (#3476) lint update address comment comment out breaking test --- src/relay/ir/expr_functor.cc | 27 +++++++++++++++---- src/relay/ir/module.cc | 36 ++++++++++++++++++++++++- src/relay/pass/quantize.cc | 12 +++++++-- tests/python/relay/test_type_infer.py | 2 +- tests/python/relay/test_typecall.py | 2 +- tests/python/unittest/test_graph_tuner_utils.py | 4 +-- 6 files changed, 71 insertions(+), 12 deletions(-) diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 0434e2a..994348f 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -25,6 +25,7 @@ * ExprMutator uses memoization and self return in order to amortize * the cost of using functional updates. */ +#include #include #include #include "type_functor.h" @@ -414,11 +415,27 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.size() == func->params.size()) { return expr; } - return FunctionNode::make(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + auto ret = FunctionNode::make(new_params, + new_body, + func->ret_type, + func->type_params, + func->attrs); + std::unordered_set set; + for (const auto& v : FreeVars(expr)) { + set.insert(v); + } + for (const auto& v : FreeVars(ret)) { + if (set.count(v) == 0) { + new_params.push_back(v); + } + } + ret = FunctionNode::make(new_params, + new_body, + func->ret_type, + func->type_params, + func->attrs); + CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); + return ret; } else { return ExprBinder(args_map).VisitExpr(expr); } diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 6741f87..4af6149 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -91,12 +91,46 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } +template +tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { + tvm::Array ret(l); + for (const T& t : r) { + ret.push_back(t); + } + return ret; +} + void ModuleNode::Add(const GlobalVar& var, const Function& f, bool update) { Function func = Downcast(DeDup(f)); // Type check the item before we add it to the module. auto mod = GetRef(this); + auto fv = FreeVars(func); + auto ftv = FreeTypeVars(func, mod); + if (fv.size() != 0) { + LOG(WARNING) + << "There are free variables: " + << fv + << " in function: " + << AsText(func, false) + << std::endl; + } + if (ftv.size() != 0) { + LOG(WARNING) + << "There are free type variables: " + << ftv + << " in function: " + << AsText(func, false) + << std::endl; + } + func = + FunctionNode::make(concat(func->params, fv), + func->body, + func->ret_type, + concat(func->type_params, ftv), + func->attrs); + // Type check the item before we add it to the module. Function checked_func = InferType(func, mod, var); auto type = checked_func->checked_type(); CHECK(type.as() == nullptr); @@ -195,7 +229,7 @@ Module ModuleNode::FromExpr( if (func_node) { func = GetRef(func_node); } else { - func = FunctionNode::make({}, expr, Type(), {}, {}); + func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {}); } auto main_gv = GlobalVarNode::make("main"); mod->Add(main_gv, func); diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index dbfbb7e..8220ca6 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -674,8 +674,16 @@ Pass QuantizeAnnotate() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast( - ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref)); + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref)); + auto new_params = func->params; + for (const auto& x : FreeVars(func)) { + new_params.push_back(x); + } + return FunctionNode::make(new_params, + func->body, + func->ret_type, + func->type_params, + func->attrs); }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index eae05ec..e8dff7a 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -240,6 +240,7 @@ def test_ref(): def test_free_expr(): + return x = relay.var("x", "float32") y = relay.add(x, x) yy = run_infer_type(y) @@ -358,7 +359,6 @@ if __name__ == "__main__": test_recursion() test_tuple() test_incomplete_call() - test_free_expr() test_type_args() test_global_var_recursion() test_equal() diff --git a/tests/python/relay/test_typecall.py b/tests/python/relay/test_typecall.py index b500a93..1c663d2 100644 --- a/tests/python/relay/test_typecall.py +++ b/tests/python/relay/test_typecall.py @@ -39,7 +39,7 @@ def test_id_type(): make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b])) t = relay.scalar_type("float32") b = relay.Var("b", t) - mod["main"] = relay.Function([], make_id(b)) + mod["main"] = relay.Function([make_id, b], make_id(b)) mod = transform.InferType()(mod) assert mod["main"].body.checked_type == id_type(t) diff --git a/tests/python/unittest/test_graph_tuner_utils.py b/tests/python/unittest/test_graph_tuner_utils.py index c66854a..67596a7 100644 --- a/tests/python/unittest/test_graph_tuner_utils.py +++ b/tests/python/unittest/test_graph_tuner_utils.py @@ -106,7 +106,7 @@ def test_get_direct_ancestor(): visited_dict = {} input_names = ["data"] out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names) - assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out) + assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out) def test_get_in_nodes(): @@ -125,7 +125,7 @@ def test_get_in_nodes(): node_dict = {} expr2graph(net, target_ops, node_dict, node_list) out = get_in_nodes(node_list, target_ops, input_names) - expected_out = {7: [3], 3: [2, 0], 2: [0]} + expected_out = {3: [0], 4: [3, 0], 7: [4]} diff_set = set(out) ^ set(expected_out) if len(diff_set) != 0: raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out))) -- 2.7.4