[Bugfix] Fix the issue that function pass modifies original module (#3712)
authorHaichen Shen <shenhaichen@gmail.com>
Tue, 6 Aug 2019 19:25:59 +0000 (12:25 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Tue, 6 Aug 2019 19:25:59 +0000 (12:25 -0700)
* fix

* fix interpreter

python/tvm/relay/backend/interpreter.py
src/relay/pass/pass_manager.cc
tests/python/relay/test_pass_fuse_ops.py

index 64024d6..491720d 100644 (file)
@@ -269,7 +269,6 @@ class Interpreter(Executor):
         self.mod = mod
         self.ctx = ctx
         self.target = target
-        self._intrp = _backend.CreateInterpreter(mod, ctx, target)
 
     def optimize(self):
         """Optimize functions in a module.
@@ -313,5 +312,6 @@ class Interpreter(Executor):
 
             mod = self.optimize()
             opt_expr = Call(mod["main"], relay_args)
-            return self._intrp(opt_expr)
+            _intrp = _backend.CreateInterpreter(mod, self.ctx, self.target)
+            return _intrp(opt_expr)
         return _interp_wrapper
index d63d912..cef8e72 100644 (file)
@@ -314,11 +314,10 @@ Module FunctionPassNode::operator()(const Module& mod,
              << " with opt level: "
              << pass_info->opt_level;
 
-  Module updated_mod = mod;
   // Execute the pass function and return a new module.
+  Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions);
   std::vector<std::pair<GlobalVar, Function> > updates;
-  auto original = mod->functions;
-  for (const auto& it : original) {
+  for (const auto& it : updated_mod->functions) {
     auto updated_func = SkipFunction(it.second)
                             ? it.second
                             : pass_func(it.second, updated_mod, pass_ctx);
index 8bcde88..4c03840 100644 (file)
@@ -512,6 +512,35 @@ def test_fuse_parallel_injective():
     assert relay.analysis.alpha_equal(zz, after)
 
 
+def test_immutable():
+    """Verify the fusion pass won't change original module."""
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        w = relay.squeeze(z)
+        mod = relay.module.Module()
+        mod["main"] = relay.Function([x], w)
+        return mod
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        w = relay.squeeze(z)
+        f1 = relay.Function([x], w)
+        x = relay.var("x", shape=(10, 20))
+        y = relay.Call(f1, [x])
+        mod = relay.module.Module()
+        mod["main"] = relay.Function([x], y)
+        return mod
+
+    mod = before()
+    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
+    assert relay.analysis.alpha_equal(mod, before())
+    assert relay.analysis.alpha_equal(new_mod, expected())
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -525,3 +554,4 @@ if __name__ == "__main__":
     test_tuple_consecutive()
     test_inception_like()
     test_fuse_parallel_injective()
+    test_immutable()