[FIX][VM] Fix relay vm optimize (#6322)
authorZhi <5145158+zhiics@users.noreply.github.com>
Tue, 25 Aug 2020 16:07:02 +0000 (09:07 -0700)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 16:07:02 +0000 (09:07 -0700)
* [FIX][VM] Fix relay vm optimize

* retrigger ci

python/tvm/relay/backend/vm.py
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h
tests/python/relay/test_vm.py

index cb7761b..73b0d22 100644 (file)
@@ -139,7 +139,7 @@ class VMCompiler(object):
         """Generate the kernel library."""
         self._codegen()
 
-    def optimize(self, mod, target=None, params=None):
+    def optimize(self, mod, target=None, target_host=None, params=None):
         """Helper method that optimizes a Relay module via VM.
 
         Parameters
@@ -149,6 +149,11 @@ class VMCompiler(object):
         target : str, :any:`tvm.target.Target`, or dict of str (i.e.
             device/context name) to str/tvm.target.Target, optional
 
+        target_host : str or :any:`tvm.target.Target`, optional
+            The compilation target for host.
+            By default, llvm is used if it is enabled,
+            otherwise a stackvm intepreter is used.
+
         params : dict of str to NDArray
             Input parameters to the graph that do not change
             during inference time. Used for constant folding.
@@ -162,9 +167,10 @@ class VMCompiler(object):
             The parameters of the final module.
         """
         target = self._update_target(target)
+        target_host = self._update_target_host(target, target_host)
         if params:
             self.set_params(params)
-        return self._optimize(mod, target), self.get_params()
+        return self._optimize(mod, target, target_host), self.get_params()
 
     def get_exec(self):
         """Get the VM executable.
index a98f1ef..33854f7 100644 (file)
@@ -806,8 +806,8 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Obje
     });
   } else if (name == "optimize") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      CHECK_EQ(args.num_args, 2);
-      *rv = this->OptimizeModule(args[0], args[1]);
+      CHECK_EQ(args.num_args, 3);
+      *rv = this->OptimizeModule(args[0], args[1], args[2]);
     });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
@@ -835,7 +835,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe
   target_host_ = target_host;
 
   // Run the optimizations necessary to target the VM.
-  context_.module = OptimizeModule(mod, targets_);
+  context_.module = OptimizeModule(mod, targets_, target_host_);
 
   // Populate the global map.
   //
@@ -923,7 +923,8 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
   return transform::Sequential(pass_seqs);
 }
 
-IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
+IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
+                                    const Target& target_host) {
   Array<Pass> pass_seqs;
   Array<runtime::String> entry_functions{"main"};
   pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
@@ -988,7 +989,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   // external codegen.
   pass_seqs.push_back(transform::Inline());
 
-  pass_seqs.push_back(MemoryOpt(this->target_host_));
+  pass_seqs.push_back(MemoryOpt(target_host));
 
   transform::Sequential seq(pass_seqs);
   transform::PassContext pass_ctx = PassContext::Current();
index d1e1f7e..b4b86d3 100644 (file)
@@ -112,7 +112,8 @@ class VMCompiler : public runtime::ModuleNode {
   void Codegen();
 
  protected:
-  IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);
+  IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets,
+                          const Target& target_host);
 
   void PopulateGlobalMap();
 
index e96d362..a69f928 100644 (file)
@@ -593,10 +593,20 @@ def test_add_op_broadcast():
     mod["main"] = func
     check_result([x_data, y_data], x_data + y_data, mod=mod)
 
+def test_vm_optimize_dynamic():
+    dtype = 'float32'
+    x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype=dtype)
+    y = relay.var('y', shape=(relay.Any(), relay.Any()), dtype=dtype)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([x, y], relay.add(x, y))
+    comp = relay.vm.VMCompiler()
+    opt_mod, _ = comp.optimize(mod, target="llvm")
+    assert "shape_func" in opt_mod.astext(False)
+
 def test_vm_optimize():
     mod, params = testing.synthetic.get_workload()
     comp = relay.vm.VMCompiler()
-    opt_mod, _ = comp.optimize(mod, "llvm", params)
+    opt_mod, _ = comp.optimize(mod, target="llvm", params=params)
 
 def test_loop_free_var():
     x = relay.var('x', shape=(), dtype='int32')