fix relay.build to not change the module argument in place (#5822)
authorThomas Viehmann <tv.code@beamnet.de>
Tue, 16 Jun 2020 21:13:40 +0000 (23:13 +0200)
committerGitHub <noreply@github.com>
Tue, 16 Jun 2020 21:13:40 +0000 (14:13 -0700)
src/relay/backend/build_module.cc
tests/python/relay/test_cpp_build_module.py

index f9ce24d..dea923d 100644 (file)
@@ -244,7 +244,8 @@ class RelayBuildModule : public runtime::ModuleNode {
       GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
       Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
       auto new_main = BindParamsByName(main_func, params);
-      relay_module->Update(main_glb_var, new_main);
+      IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
+      relay_module_ptr->Update(main_glb_var, new_main);
     }
 
     Array<Pass> pass_seqs;
index 8d54384..fa56eb0 100644 (file)
@@ -44,7 +44,12 @@ def test_basic_build():
     targets = {
         tvm.tir.IntImm("int32", ctx.device_type): tgt
     }
-    g_json, mmod, params = relay.build(tvm.IRModule.from_expr(func), targets, "llvm", params=params)
+    mod = tvm.IRModule.from_expr(func)
+    func_in_mod = mod["main"]
+    assert mod["main"] == func_in_mod, "cannot compare function to itself"
+
+    g_json, mmod, params = relay.build(mod, targets, "llvm", params=params)
+    assert mod["main"] == func_in_mod, "relay.build changed module in-place"
 
     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)