fix calibration pass to support multiple functions (#5768)
authorYi-Hsiang (Sean) Lai <seanlatias@users.noreply.github.com>
Fri, 12 Jun 2020 15:33:20 +0000 (11:33 -0400)
committerGitHub <noreply@github.com>
Fri, 12 Jun 2020 15:33:20 +0000 (08:33 -0700)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-43-142.us-east-2.compute.internal>
python/tvm/relay/quantize/_calibrate.py

index 59ee51b..9590e87 100644 (file)
@@ -138,10 +138,14 @@ def _set_params(mod, input_scale_func, weight_scale_func):
             const_params[nclip_min] = _make_const(- (valid_range - 1))
             const_params[nclip_max] = _make_const((valid_range - 1))
 
-    func = mod['main']
-    _analysis.post_order_visit(func, visit_func)
-    func = _expr.bind(func, const_params)
-    return IRModule.from_expr(func)
+    main_func = mod['main']
+    _analysis.post_order_visit(main_func, visit_func)
+    main_func = _expr.bind(main_func, const_params)
+    func_dict = {}
+    for global_var, func in mod.functions.items():
+        if global_var.name_hint != 'main':
+            func_dict[global_var] = func
+    return IRModule.from_expr(main_func, func_dict)
 
 
 # weight scale functions