[Relay][VM] Fix constant folding issue in VM compiler (#4077)
authorWei Chen <ipondering.weic@gmail.com>
Thu, 10 Oct 2019 00:47:04 +0000 (17:47 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Thu, 10 Oct 2019 00:47:04 +0000 (17:47 -0700)
* [Relay][VM] Fix constant folding issue in VM compiler

1. allow pass params when compile a module
2. enhance profiler robustness

* remove dead code

* fix lint

* add get_params

* fix test

* don't pass params back

* remove get_params

* docs

* move compile function to api

* compile clashes with builtin name

* fix compilation error

* remove dead code

python/tvm/relay/backend/profiler_vm.py
python/tvm/relay/backend/vm.py
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h
src/runtime/vm/profiler/vm.cc
tests/python/relay/test_vm.py
tests/python/relay/test_vm_serialization.py
tests/python/unittest/test_runtime_vm_profiler.py

index 3adbeca..8ae3161 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
 """
 The Relay Virtual Machine profiler.
 
 Provides extra APIs for profiling vm execution.
 """
-import tvm
 from . import vm, _vm
 
-def _update_target(target):
-    target = target if target else tvm.target.current_target()
-    if target is None:
-        raise ValueError("Target is not set in env or passed as argument.")
+def compile(mod, target=None, target_host=None, params=None):
+    """
+    Parameters
+    ----------
+    mod : relay.Module
+        The Relay module to build.
 
-    tgts = {}
-    if isinstance(target, (str, tvm.target.Target)):
-        dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
-        tgts[dev_type] = tvm.target.create(target)
-    elif isinstance(target, dict):
-        for dev, tgt in target.items():
-            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
-            tgts[dev_type] = tvm.target.create(tgt)
-    else:
-        raise TypeError("target is expected to be str, tvm.target.Target, " +
-                        "or dict of str to str/tvm.target.Target, but received " +
-                        "{}".format(type(target)))
-    return tgts
+    target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+        device/context name) to str/tvm.target.Target, optional
+        For heterogeneous compilation, it is a dictionary indicating context
+        to target mapping. For homogeneous compilation, it is a build target.
+
+    target_host : str or :any:`tvm.target.Target`, optional
+        Host compilation target, if target is device.
+        When TVM compiles device specific program such as CUDA,
+        we also need host(CPU) side code to interact with the driver
+        to setup the dimensions and parameters correctly.
+        target_host is used to specify the host side codegen target.
+        By default, llvm is used if it is enabled,
+        otherwise a stackvm intepreter is used.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    vm : VirtualMachineProfiler
+        The profile VM runtime.
+    """
+    compiler = VMCompilerProfiler()
+    target = compiler.update_target(target)
+    target_host = compiler.update_target_host(target, target_host)
+    if params:
+        compiler.set_params(params)
+    tophub_context = compiler.tophub_context(target)
+    with tophub_context:
+        compiler._compile(mod, target, target_host)
+    return VirtualMachineProfiler(compiler._get_vm())
 
 class VMCompilerProfiler(vm.VMCompiler):
     """Build Relay module to run on VM runtime."""
@@ -49,36 +69,7 @@ class VMCompilerProfiler(vm.VMCompiler):
         self.mod = _vm._VMCompilerProfiler()
         self._compile = self.mod["compile"]
         self._get_vm = self.mod["get_vm"]
-
-    def compile(self, mod, target=None, target_host=None):
-        """
-        Parameters
-        ----------
-        mod : relay.Module
-            The Relay module to build.
-
-        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
-            device/context name) to str/tvm.target.Target, optional
-            For heterogeneous compilation, it is a dictionary indicating context
-            to target mapping. For homogeneous compilation, it is a build target.
-
-        target_host : str or :any:`tvm.target.Target`, optional
-            Host compilation target, if target is device.
-            When TVM compiles device specific program such as CUDA,
-            we also need host(CPU) side code to interact with the driver
-            to setup the dimensions and parameters correctly.
-            target_host is used to specify the host side codegen target.
-            By default, llvm is used if it is enabled,
-            otherwise a stackvm intepreter is used.
-
-        Returns
-        -------
-        vm : VirtualMachineProfiler
-            The profile VM runtime.
-        """
-        target = _update_target(target)
-        self._compile(mod, target, target_host)
-        return VirtualMachineProfiler(self._get_vm())
+        self._set_params_func = self.mod["set_params"]
 
 class VirtualMachineProfiler(vm.VirtualMachine):
     """Relay profile VM runtime."""
index a6cb91c..e54629d 100644 (file)
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name, redefined-builtin
 """
 The Relay Virtual Machine.
 
@@ -25,30 +25,11 @@ import numpy as np
 import tvm
 from tvm import autotvm
 from tvm._ffi.runtime_ctypes import TVMByteArray
+from tvm.relay import expr as _expr
 from . import _vm
 from . import vmobj as _obj
 from .interpreter import Executor
 
-
-def _update_target(target):
-    target = target if target else tvm.target.current_target()
-    if target is None:
-        raise ValueError("Target is not set in env or passed as argument.")
-
-    tgts = {}
-    if isinstance(target, (str, tvm.target.Target)):
-        dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
-        tgts[dev_type] = tvm.target.create(target)
-    elif isinstance(target, dict):
-        for dev, tgt in target.items():
-            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
-            tgts[dev_type] = tvm.target.create(tgt)
-    else:
-        raise TypeError("target is expected to be str, tvm.target.Target, " +
-                        "or dict of str to str/tvm.target.Target, but received " +
-                        "{}".format(type(target)))
-    return tgts
-
 def _convert(arg, cargs):
     if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
         cargs.append(_obj.tensor_object(arg))
@@ -144,40 +125,85 @@ class VirtualMachine(object):
         return self.mod
 
 
+def compile(mod, target=None, target_host=None, params=None):
+    """
+    Parameters
+    ----------
+    mod : relay.Module
+        The Relay module to build.
+
+    target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+        device/context name) to str/tvm.target.Target, optional
+        For heterogeneous compilation, it is a dictionary indicating context
+        to target mapping. For homogeneous compilation, it is a build target.
+
+    target_host : str or :any:`tvm.target.Target`, optional
+        Host compilation target, if target is device.
+        When TVM compiles device specific program such as CUDA,
+        we also need host(CPU) side code to interact with the driver
+        to setup the dimensions and parameters correctly.
+        target_host is used to specify the host side codegen target.
+        By default, llvm is used if it is enabled,
+        otherwise a stackvm intepreter is used.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    vm : VirtualMachine
+        The VM runtime.
+    """
+    compiler = VMCompiler()
+
+    target = compiler.update_target(target)
+    target_host = compiler.update_target_host(target, target_host)
+    if params:
+        compiler.set_params(params)
+    tophub_context = compiler.tophub_context(target)
+    with tophub_context:
+        compiler._compile(mod, target, target_host)
+    return VirtualMachine(compiler._get_vm())
+
 class VMCompiler(object):
     """Build Relay module to run on VM runtime."""
     def __init__(self):
         self.mod = _vm._VMCompiler()
         self._compile = self.mod["compile"]
         self._get_vm = self.mod["get_vm"]
+        self._set_params_func = self.mod["set_params"]
+
+    def set_params(self, params):
+        """Set constant parameters for the model"""
+        inputs = {}
+        for name, param in params.items():
+            if isinstance(param, np.ndarray):
+                param = _nd.array(param)
+            inputs[name] = _expr.const(param)
+        self._set_params_func(inputs)
+
+    def update_target(self, target):
+        """Update target"""
+        target = target if target else tvm.target.current_target()
+        if target is None:
+            raise ValueError("Target is not set in env or passed as argument.")
+        tgts = {}
+        if isinstance(target, (str, tvm.target.Target)):
+            dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
+            tgts[dev_type] = tvm.target.create(target)
+        elif isinstance(target, dict):
+            for dev, tgt in target.items():
+                dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
+                tgts[dev_type] = tvm.target.create(tgt)
+        else:
+            raise TypeError("target is expected to be str, tvm.target.Target, " +
+                            "or dict of str to str/tvm.target.Target, but received " +
+                            "{}".format(type(target)))
+        return tgts
 
-    def compile(self, mod, target=None, target_host=None):
-        """
-        Parameters
-        ----------
-        mod : relay.Module
-            The Relay module to build.
-
-        target : str, :any:`tvm.target.Target`, or dict of str(i.e.
-            device/context name) to str/tvm.target.Target, optional
-            For heterogeneous compilation, it is a dictionary indicating context
-            to target mapping. For homogeneous compilation, it is a build target.
-
-        target_host : str or :any:`tvm.target.Target`, optional
-            Host compilation target, if target is device.
-            When TVM compiles device specific program such as CUDA,
-            we also need host(CPU) side code to interact with the driver
-            to setup the dimensions and parameters correctly.
-            target_host is used to specify the host side codegen target.
-            By default, llvm is used if it is enabled,
-            otherwise a stackvm intepreter is used.
-
-        Returns
-        -------
-        vm : VirtualMachine
-            The VM runtime.
-        """
-        target = _update_target(target)
+    def update_target_host(self, target, target_host):
+        """Update target host"""
         target_host = None if target_host == "" else target_host
         if not target_host:
             for device_type, tgt in target.items():
@@ -186,19 +212,16 @@ class VMCompiler(object):
                     break
         if not target_host:
             target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
-        target_host = tvm.target.create(target_host)
+        return tvm.target.create(target_host)
 
+    def tophub_context(self, target):
         # If current dispatch context is fallback context (the default root context),
         # then load pre-tuned parameters from TopHub
         if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
             tophub_context = autotvm.tophub.context(list(target.values()))
         else:
             tophub_context = autotvm.util.EmptyContext()
-
-        with tophub_context:
-            self._compile(mod, target, target_host)
-        return VirtualMachine(self._get_vm())
-
+        return tophub_context
 
 class VMExecutor(Executor):
     """
@@ -226,8 +249,7 @@ class VMExecutor(Executor):
         self.mod = mod
         self.ctx = ctx
         self.target = target
-        compiler = VMCompiler()
-        self.vm = compiler.compile(mod, target)
+        self.vm = compile(mod, target)
         self.vm.init(ctx)
 
     def _make_executor(self, expr=None):
index 013d8b0..00d4fb4 100644 (file)
@@ -780,23 +780,73 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
   if (name == "compile") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       CHECK_EQ(args.num_args, 3);
-      this->Compile(args[0], args[1], args[2]);
+      Module mod = args[0];
+      this->Compile(mod, args[1], args[2]);
     });
   } else if (name == "get_vm") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       *rv = runtime::Module(vm_);
     });
+  } else if (name == "set_params") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      Map<std::string, Constant> params = args[0];
+      for (const auto& kv : params) {
+        this->SetParam(kv.first, kv.second->data);
+      }
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
   }
 }
 
-void VMCompiler::Compile(const Module& mod_ref,
+void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
+  params_[name] = data_in;
+}
+
+relay::Function VMCompiler::BindParamsByName(
+    relay::Function func,
+    const std::unordered_map<std::string, runtime::NDArray>& params) {
+  std::unordered_map<std::string, relay::Var> name_dict;
+  std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto &name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(arg);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+  std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
+  for (auto &kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = ConstantNode::make(kv.second);
+  }
+  Expr bound_expr = relay::Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  CHECK(ret.defined())
+      << "The returning type is expected to be a Relay Function."
+      << "\n";
+  return ret;
+}
+
+
+void VMCompiler::Compile(Module mod,
                          const TargetsMap& targets,
                          const tvm::Target& target_host) {
   CHECK_EQ(targets.size(), 1)
     << "Currently VM compiler doesn't support heterogeneous compilation";
+  if (params_.size()) {
+    auto f = BindParamsByName(mod->Lookup("main"), params_);
+    auto gvar = mod->GetGlobalVar("main");
+    mod->Add(gvar, f);
+  }
 
   InitVM();
   targets_ = targets;
@@ -804,7 +854,7 @@ void VMCompiler::Compile(const Module& mod_ref,
 
   // Run some optimizations first, this code should
   // be moved to pass manager.
-  context_.module = OptimizeModule(mod_ref, targets_);
+  context_.module = OptimizeModule(mod, targets_);
 
   // Populate the global map.
   //
index 14a5035..dff1ef7 100644 (file)
@@ -100,11 +100,37 @@ class VMCompiler : public runtime::ModuleNode {
     vm_ = std::make_shared<VirtualMachine>();
   }
 
-  void Compile(const Module& mod_ref,
+  /*!
+   * \brief Set the parameters
+   *
+   * \param name name of parameter
+   * \param data_in input DLTensor
+   */
+  void SetParam(const std::string& name, runtime::NDArray data_in);
+
+  /*!
+   * \brief Compile functions in a Module
+   *
+   * \param mod Relay Module
+   * \param targets For heterogeneous compilation, it is a dictionary indicating context
+                    to target mapping. For homogeneous compilation, it is a build target.
+   * \param target_host Host compilation target, if target is device.
+   */
+  void Compile(Module mod,
                const TargetsMap& targets,
                const tvm::Target& target_host);
 
  protected:
+  /*!
+   * \brief Bind params to function by using name
+   * \param func Relay function
+   * \param params params dict
+   * \return relay::Function
+   */
+  relay::Function BindParamsByName(
+      relay::Function func,
+      const std::unordered_map<std::string, runtime::NDArray>& params);
+
   Module OptimizeModule(const Module& mod, const TargetsMap& targets);
 
   void PopulateGlobalMap();
@@ -120,6 +146,8 @@ class VMCompiler : public runtime::ModuleNode {
   VMCompilerContext context_;
   /*! \brief Compiled virtual machine. */
   std::shared_ptr<VirtualMachine> vm_;
+  /*! \brief parameters */
+  std::unordered_map<std::string, runtime::NDArray> params_;
 };
 
 }  // namespace vm
index 1d3ac83..5f59f6e 100644 (file)
@@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
                                        Index output_size,
                                        const std::vector<Object>& args) {
   auto ctx = VirtualMachine::GetParamsContext();
+  // warmup
+  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
+                               args);
+  TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
+
   auto op_begin = std::chrono::high_resolution_clock::now();
   VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
                                args);
index f60c533..f643f8a 100644 (file)
@@ -47,15 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
     if isinstance(f, relay.Expr):
         mod = relay.Module()
         mod["main"] = f
-        compiler = relay.vm.VMCompiler()
-        vm = compiler.compile(mod, target)
+        vm = relay.vm.compile(mod, target)
         vm.init(tvm.cpu())
         return vm.invoke("main", *args)
     else:
         assert isinstance(f, relay.Module), "expected expression or module"
         mod = f
-        compiler = relay.vm.VMCompiler()
-        vm = compiler.compile(mod, target)
+        vm = relay.vm.compile(mod, target)
         vm.init(tvm.cpu())
         ret = vm.invoke("main", *args)
         return ret
@@ -582,8 +580,7 @@ def test_set_params():
     b = relay.var('b', shape=(6,))
     y = relay.nn.bias_add(relay.nn.dense(x, w), b)
     mod["main"] = relay.Function([x, w, b], y)
-    compiler = relay.vm.VMCompiler()
-    vm = compiler.compile(mod, 'llvm')
+    vm = relay.vm.compile(mod, 'llvm')
     vm.init(tvm.cpu())
     
     x_np = np.random.uniform(size=(10, 5)).astype('float32')
index a32ec27..3a317fc 100644 (file)
@@ -28,18 +28,16 @@ from tvm.relay.prelude import Prelude
 from tvm.contrib import util
 from tvm.relay import testing
 
-def create_vm(f, ctx=tvm.cpu(), target="llvm"):
+def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None):
     if isinstance(f, relay.Expr):
         mod = relay.Module()
         mod["main"] = f
-        compiler = relay.vm.VMCompiler()
-        vm = compiler.compile(mod, target)
+        vm = _vm.compile(mod, target=target, params=params)
         vm.init(ctx)
         return vm
     else:
         assert isinstance(f, relay.Module), "expected mod as relay.Module"
-        compiler = relay.vm.VMCompiler()
-        vm = compiler.compile(f, target)
+        vm = _vm.compile(f, target=target, params=params)
         vm.init(ctx)
         return vm
 
@@ -61,7 +59,7 @@ def run_network(mod,
         return result.asnumpy().astype(dtype)
 
     def get_serialized_output(mod, data, params, target, ctx, dtype='float32'):
-        vm = create_vm(mod, ctx, target)
+        vm = create_vm(mod, ctx, target, params=params)
         ser = serializer.Serializer(vm)
         code, lib = ser.serialize()
         deser = deserializer.Deserializer(code, lib)
index 4281ccc..b5ce0ec 100644 (file)
@@ -22,13 +22,11 @@ import pytest
 from tvm import relay
 from tvm.relay.testing import resnet
 
-@pytest.mark.skip
 def test_basic():
     mod, params = resnet.get_workload()
-    compiler = relay.profiler_vm.VMCompilerProfiler()
     target = 'llvm'
     ctx = tvm.cpu()
-    vm = compiler.compile(mod, target)
+    vm = relay.profiler_vm.compile(mod, target)
     vm.init(ctx)
     vm.load_params(params)