[Relay] Expose vm OptimizeModule to Python (#4800)
authormasahi <masahi129@gmail.com>
Sun, 2 Feb 2020 02:04:44 +0000 (11:04 +0900)
committerGitHub <noreply@github.com>
Sun, 2 Feb 2020 02:04:44 +0000 (11:04 +0900)
* Expose VM OptimizeModule to python

* added missing imports

* fix import

python/tvm/relay/backend/vm.py
python/tvm/relay/scope_builder.py
src/relay/backend/vm/compiler.cc
tests/python/relay/test_vm.py

index 3100900..f1cdefc 100644 (file)
@@ -23,6 +23,7 @@ Implements a Python interface to compiling and executing on the Relay VM.
 import numpy as np
 
 import tvm
+import tvm.ndarray as _nd
 from tvm import autotvm, container
 from tvm.object import Object
 from tvm.relay import expr as _expr
@@ -409,6 +410,8 @@ class VMCompiler(object):
         self._codegen = self.mod["codegen"]
         self._get_exec = self.mod["get_executable"]
         self._set_params_func = self.mod["set_params"]
+        self._get_params_func = self.mod["get_params"]
+        self._optimize = self.mod["optimize"]
 
     def set_params(self, params):
         """Set constant parameters for the model.
@@ -426,6 +429,14 @@ class VMCompiler(object):
             inputs[name] = _expr.const(param)
         self._set_params_func(inputs)
 
+    def get_params(self):
+        """Return the updated weights."""
+        params = self._get_params_func()
+        ret = {}
+        for key, value in params.items():
+            ret[key] = value.data
+        return ret
+
     def lower(self, mod, target=None, target_host=None):
         """Lower the module to VM bytecode.
 
@@ -458,6 +469,33 @@ class VMCompiler(object):
         """Generate the kernel library."""
         self._codegen()
 
+    def optimize(self, mod, target=None, params=None):
+        """Helper method that optimizes a Relay module via VM.
+
+        Parameters
+        ----------
+        mod : relay.Module
+
+        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
+            device/context name) to str/tvm.target.Target, optional
+
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+
+        Returns
+        -------
+        mod : relay.Module
+            The optimized relay module.
+
+        params : dict
+            The parameters of the final module.
+        """
+        target = self._update_target(target)
+        if params:
+            self.set_params(params)
+        return self._optimize(mod, target), self.get_params()
+
     def get_exec(self):
         """Get the VM executable.
 
index 16044c1..43c6532 100644 (file)
@@ -18,6 +18,7 @@
 """The scope builder interface."""
 from __future__ import absolute_import
 
+from . import ty as _ty
 from . import expr as _expr
 from .._ffi import base as _base
 
index cc5d6bc..8d4f4ad 100644 (file)
@@ -772,6 +772,19 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
         this->SetParam(kv.first, kv.second->data);
       }
     });
+  } else if (name == "get_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      Map<std::string, Constant> ret;
+      for (const auto& kv : params_) {
+        ret.Set(kv.first, ConstantNode::make(kv.second));
+      }
+      *rv = ret;
+    });
+  } else if (name == "optimize") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK_EQ(args.num_args, 2);
+      *rv = this->OptimizeModule(args[0], args[1]);
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
index d4a7a1a..9ea939c 100644 (file)
@@ -22,6 +22,7 @@ from tvm import relay
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.testing.config import ctx_list
 from tvm.relay.prelude import Prelude
+from tvm.relay import testing
 import pytest
 
 def check_result(args, expected_result, mod=None):
@@ -570,6 +571,10 @@ def test_add_op_broadcast():
     mod["main"] = func
     check_result([x_data, y_data], x_data + y_data, mod=mod)
 
+def test_vm_optimize():
+    mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18)
+    comp = relay.backend.vm.VMCompiler()
+    opt_mod, _ = comp.optimize(mod, "llvm", params)
 
 if __name__ == "__main__":
     pytest.main([__file__])