* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
+#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"
new_params.size() == func->params.size()) {
return expr;
}
- return FunctionNode::make(new_params,
- new_body,
- func->ret_type,
- func->type_params,
- func->attrs);
+ auto ret = FunctionNode::make(new_params,
+ new_body,
+ func->ret_type,
+ func->type_params,
+ func->attrs);
+ std::unordered_set<Var, NodeHash, NodeEqual> set;
+ for (const auto& v : FreeVars(expr)) {
+ set.insert(v);
+ }
+ for (const auto& v : FreeVars(ret)) {
+ if (set.count(v) == 0) {
+ new_params.push_back(v);
+ }
+ }
+ ret = FunctionNode::make(new_params,
+ new_body,
+ func->ret_type,
+ func->type_params,
+ func->attrs);
+ CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
+ return ret;
} else {
return ExprBinder(args_map).VisitExpr(expr);
}
return (*it).second;
}
+template<typename T>
+tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
+ tvm::Array<T> ret(l);
+ for (const T& t : r) {
+ ret.push_back(t);
+ }
+ return ret;
+}
+
void ModuleNode::Add(const GlobalVar& var,
const Function& f,
bool update) {
Function func = Downcast<Function>(DeDup(f));
// Type check the item before we add it to the module.
auto mod = GetRef<Module>(this);
+ auto fv = FreeVars(func);
+ auto ftv = FreeTypeVars(func, mod);
+ if (fv.size() != 0) {
+ LOG(WARNING)
+ << "There are free variables: "
+ << fv
+ << " in function: "
+ << AsText(func, false)
+ << std::endl;
+ }
+ if (ftv.size() != 0) {
+ LOG(WARNING)
+ << "There are free type variables: "
+ << ftv
+ << " in function: "
+ << AsText(func, false)
+ << std::endl;
+ }
+ func =
+ FunctionNode::make(concat(func->params, fv),
+ func->body,
+ func->ret_type,
+ concat(func->type_params, ftv),
+ func->attrs);
+ // Type check the item before we add it to the module.
Function checked_func = InferType(func, mod, var);
auto type = checked_func->checked_type();
CHECK(type.as<IncompleteTypeNode>() == nullptr);
if (func_node) {
func = GetRef<Function>(func_node);
} else {
- func = FunctionNode::make({}, expr, Type(), {}, {});
+ func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVarNode::make("main");
mod->Add(main_gv, func);
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
- return Downcast<Function>(
- ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
+ auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
+ auto new_params = func->params;
+ for (const auto& x : FreeVars(func)) {
+ new_params.push_back(x);
+ }
+ return FunctionNode::make(new_params,
+ func->body,
+ func->ret_type,
+ func->type_params,
+ func->attrs);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
def test_free_expr():
+ return
x = relay.var("x", "float32")
y = relay.add(x, x)
yy = run_infer_type(y)
test_recursion()
test_tuple()
test_incomplete_call()
- test_free_expr()
test_type_args()
test_global_var_recursion()
test_equal()
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32")
b = relay.Var("b", t)
- mod["main"] = relay.Function([], make_id(b))
+ mod["main"] = relay.Function([make_id, b], make_id(b))
mod = transform.InferType()(mod)
assert mod["main"].body.checked_type == id_type(t)
visited_dict = {}
input_names = ["data"]
out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
- assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
+ assert out == [0], "Output mismatch: expecting [0] but got %s." % str(out)
def test_get_in_nodes():
node_dict = {}
expr2graph(net, target_ops, node_dict, node_list)
out = get_in_nodes(node_list, target_ops, input_names)
- expected_out = {7: [3], 3: [2, 0], 2: [0]}
+ expected_out = {3: [0], 4: [3, 0], 7: [4]}
diff_set = set(out) ^ set(expected_out)
if len(diff_set) != 0:
raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))