[Relay][VM] Clean up the VM and VM profiler code (#4391)
authorHaichen Shen <shenhaichen@gmail.com>
Fri, 22 Nov 2019 00:01:01 +0000 (16:01 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 22 Nov 2019 00:01:01 +0000 (16:01 -0800)
* [VM] add a few more API to vm

* [VM][Fix] fix vm convert args

* [VM] a few fixes

* rename fields

* update

* update vm profiler

* x

* add doc

* lint

* fix test

* address comments

include/tvm/runtime/vm.h
python/tvm/relay/backend/profiler_vm.py
python/tvm/relay/backend/vm.py
src/relay/backend/vm/profiler/compiler.cc [deleted file]
src/runtime/vm/executable.cc
src/runtime/vm/profiler/vm.cc
src/runtime/vm/profiler/vm.h
src/runtime/vm/vm.cc
tests/python/relay/test_vm.py
tests/python/relay/test_vm_serialization.py
tests/python/unittest/test_runtime_vm_profiler.py

index 317b535..f7188e4 100644 (file)
@@ -268,125 +268,142 @@ struct Instruction {
     } alloc_storage;
   };
 
-  /*! \brief Construct a return instruction.
-   *  \param return_reg The register containing the return value.
-   *  \return The return instruction.
-   * */
+  /*!
+   * \brief Construct a return instruction.
+   * \param return_reg The register containing the return value.
+   * \return The return instruction.
+   */
   static Instruction Ret(RegName return_reg);
-  /*! \brief Construct a fatal instruction.
-   *  \return The fatal instruction.
-   * */
+  /*!
+   * \brief Construct a fatal instruction.
+   * \return The fatal instruction.
+   */
   static Instruction Fatal();
-  /*! \brief Construct a invoke packed instruction.
-   *  \param packed_index The index of the packed function.
-   *  \param arity The arity of the function.
-   *  \param output_size The number of outputs of the packed function.
-   *  \param args The argument registers.
-   *  \return The invoke packed instruction.
+  /*!
+   * \brief Construct a invoke packed instruction.
+   * \param packed_index The index of the packed function.
+   * \param arity The arity of the function.
+   * \param output_size The number of outputs of the packed function.
+   * \param args The argument registers.
+   * \return The invoke packed instruction.
    */
   static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
                                   const std::vector<RegName>& args);
-  /*! \brief Construct an allocate tensor instruction with constant shape.
-   *  \param storage The storage to allocate out of.
-   *  \param shape The shape of the tensor.
-   *  \param dtype The dtype of the tensor.
-   *  \param dst The destination register.
-   *  \return The allocate tensor instruction.
+  /*!
+   * \brief Construct an allocate tensor instruction with constant shape.
+   * \param storage The storage to allocate out of.
+   * \param shape The shape of the tensor.
+   * \param dtype The dtype of the tensor.
+   * \param dst The destination register.
+   * \return The allocate tensor instruction.
    */
   static Instruction AllocTensor(RegName storage,
                                  const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
-  /*! \brief Construct an allocate tensor instruction with register.
-   *  \param storage The storage to allocate out of.
-   *  \param shape_register The register containing the shape.
-   *  \param dtype The dtype of the tensor.
-   *  \param dst The destination register.
-   *  \return The allocate tensor instruction.
+  /*!
+   * \brief Construct an allocate tensor instruction with register.
+   * \param storage The storage to allocate out of.
+   * \param shape_register The register containing the shape.
+   * \param dtype The dtype of the tensor.
+   * \param dst The destination register.
+   * \return The allocate tensor instruction.
    */
   static Instruction AllocTensorReg(RegName storage,
                                     RegName shape_register, DLDataType dtype, RegName dst);
-  /*! \brief Construct an allocate datatype instruction.
-   *  \param tag The datatype tag.
-   *  \param num_fields The number of fields for the datatype.
-   *  \param fields The registers containing the fields.
-   *  \param dst The register name of the destination.
-   *  \return The allocate instruction tensor.
+  /*!
+   * \brief Construct an allocate datatype instruction.
+   * \param tag The datatype tag.
+   * \param num_fields The number of fields for the datatype.
+   * \param fields The registers containing the fields.
+   * \param dst The register name of the destination.
+   * \return The allocate instruction tensor.
    */
   static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
                               RegName dst);
-  /*! \brief Construct an allocate closure instruction.
-   *  \param func_index The index of the function table.
-   *  \param num_freevar The number of free variables.
-   *  \param free_vars The registers of the free variables.
-   *  \param dst The destination register.
-   *  \return The allocate closure instruction.
+  /*!
+   * \brief Construct an allocate closure instruction.
+   * \param func_index The index of the function table.
+   * \param num_freevar The number of free variables.
+   * \param free_vars The registers of the free variables.
+   * \param dst The destination register.
+   * \return The allocate closure instruction.
    */
   static Instruction AllocClosure(Index func_index, Index num_freevar,
                                   const std::vector<RegName>& free_vars, RegName dst);
-  /*! \brief Construct a get field instruction.
-   *  \param object_reg The register containing the object to project from.
-   *  \param field_index The field to read out of the object.
-   *  \param dst The destination register.
-   *  \return The get field instruction.
+  /*!
+   * \brief Construct a get field instruction.
+   * \param object_reg The register containing the object to project from.
+   * \param field_index The field to read out of the object.
+   * \param dst The destination register.
+   * \return The get field instruction.
    */
   static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
-  /*! \brief Construct a get_tag instruction.
-   *  \param object_reg The register containing the object to project from.
-   *  \param dst The destination register.
-   *  \return The get_tag instruction.
+  /*!
+   * \brief Construct a get_tag instruction.
+   * \param object_reg The register containing the object to project from.
+   * \param dst The destination register.
+   * \return The get_tag instruction.
    */
   static Instruction GetTag(RegName object_reg, RegName dst);
-  /*! \brief Construct an if instruction.
-   *  \param test The register containing the test value.
-   *  \param target The register containing the target value.
-   *  \param true_branch The offset to the true branch.
-   *  \param false_branch The offset to the false branch.
-   *  \return The if instruction.
+  /*!
+   * \brief Construct an if instruction.
+   * \param test The register containing the test value.
+   * \param target The register containing the target value.
+   * \param true_branch The offset to the true branch.
+   * \param false_branch The offset to the false branch.
+   * \return The if instruction.
    */
   static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
-  /*! \brief Construct a goto instruction.
-   *  \param pc_offset The offset from the current pc.
-   *  \return The goto instruction.
+  /*!
+   * \brief Construct a goto instruction.
+   * \param pc_offset The offset from the current pc.
+   * \return The goto instruction.
    */
   static Instruction Goto(Index pc_offset);
-  /*! \brief Construct an invoke instruction.
-   *  \param func_index The index of the function to invoke.
-   *  \param args The registers containing the arguments.
-   *  \param dst The destination register.
-   *  \return The invoke instruction.
+  /*!
+   * \brief Construct an invoke instruction.
+   * \param func_index The index of the function to invoke.
+   * \param args The registers containing the arguments.
+   * \param dst The destination register.
+   * \return The invoke instruction.
    */
   static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst);
-  /*! \brief Construct an invoke closure instruction.
-   *  \param closure The register of the closure to invoke.
-   *  \param args The registers containing the arguments.
-   *  \param dst The destination register.
-   *  \return The invoke closure instruction.
+  /*!
+   * \brief Construct an invoke closure instruction.
+   * \param closure The register of the closure to invoke.
+   * \param args The registers containing the arguments.
+   * \param dst The destination register.
+   * \return The invoke closure instruction.
    */
   static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst);
-  /*! \brief Construct a load constant instruction.
-   *  \param const_index The index of the constant.
-   *  \param dst The destination register.
-   *  \return The load constant instruction.
+  /*!
+   * \brief Construct a load constant instruction.
+   * \param const_index The index of the constant.
+   * \param dst The destination register.
+   * \return The load constant instruction.
    */
   static Instruction LoadConst(Index const_index, RegName dst);
-  /*! \brief Construct a load_constanti instruction.
-   *  \param val The interger constant value.
-   *  \param dst The destination register.
-   *  \return The load_constanti instruction.
+  /*!
+   * \brief Construct a load_constanti instruction.
+   * \param val The interger constant value.
+   * \param dst The destination register.
+   * \return The load_constanti instruction.
    */
   static Instruction LoadConsti(Index val, RegName dst);
-  /*! \brief Construct a move instruction.
-   *  \param src The source register.
-   *  \param dst The destination register.
-   *  \return The move instruction.
+  /*!
+   * \brief Construct a move instruction.
+   * \param src The source register.
+   * \param dst The destination register.
+   * \return The move instruction.
    */
   static Instruction Move(RegName src, RegName dst);
 
-   /*! \brief Allocate a storage block.
-   *  \param size The size of the allocation.
-   *  \param alignment The allocation's alignment.
-   *  \param dtype_hint The data type hint for the allocator.
-   *  \param dst The destination to place the storage.
-   *  \return The alloc storage instruction.
+  /*!
+   * \brief Allocate a storage block.
+   * \param size The size of the allocation.
+   * \param alignment The allocation's alignment.
+   * \param dtype_hint The data type hint for the allocator.
+   * \param dst The destination to place the storage.
+   * \return The alloc storage instruction.
    */
   static Instruction AllocStorage(RegName size, RegName alignment,
                                   DLDataType dtype_hint, RegName dst);
@@ -399,7 +416,8 @@ struct Instruction {
   friend std::ostream& operator<<(std::ostream& os, const Instruction&);
 };
 
-/*! \brief A representation of a Relay function in the VM.
+/*!
+ * \brief A representation of a Relay function in the VM.
  *
  * Contains metadata about the compiled function, as
  * well as the compiled VM instructions.
@@ -427,7 +445,8 @@ struct VMFunction {
   friend std::ostream& operator<<(std::ostream& os, const VMFunction&);
 };
 
-/*! \brief A representation of a stack frame.
+/*!
+ * \brief A representation of a stack frame.
  *
  * A stack frame is a record containing the information needed
  * to restore the caller's virtual machine state after returning
@@ -458,7 +477,8 @@ struct VMFrame {
         caller_return_register(0) {}
 };
 
-/*! \brief The executable emitted by the VM compiler.
+/*!
+ * \brief The executable emitted by the VM compiler.
  *
  * The executable contains information (e.g. data in different memory regions)
  * to run in a virtual machine.
@@ -534,19 +554,35 @@ class Executable : public ModuleNode {
    */
   std::string GetBytecode() const;
 
-/*!
+  /*!
    * \brief Print the detailed statistics of the given code, i.e. number of
    * globls and constants, etc.
    */
   std::string Stats() const;
 
-  /*! \brief Get the `lib` module in an executable. Users have the flexibility to call
+  /*!
+   * \brief Get the `lib` module in an executable. Users have the flexibility to call
    * `export_library` from the frontend to save the library to disk.
    *
    * \return The runtime module that contains the hardwre dependent code.
    */
   runtime::Module GetLib() const { return lib; }
 
+  /*!
+   * \brief Get the arity of the VM Fucntion.
+   * \param func Function name.
+   * \return The number of parameters.
+   */
+  int GetFunctionArity(std::string func) const;
+
+  /*!
+   * \brief Get the parameter name given the function name and parameter index.
+   * \param func Function name.
+   * \param index Parameter index.
+   * \return The parameter name.
+   */
+  std::string GetFunctionParameterName(std::string func, uint32_t index) const;
+
   virtual ~Executable() {}
 
   const char* type_key() const final {
@@ -628,7 +664,8 @@ class Executable : public ModuleNode {
   std::string code_;
 };
 
-/*! \brief The virtual machine.
+/*!
+ * \brief The virtual machine.
  *
  * The virtual machine contains all the current execution state,
  * as well as the executable.
@@ -660,83 +697,72 @@ class VirtualMachine : public runtime::ModuleNode {
   virtual PackedFunc GetFunction(const std::string& name,
                                  const ObjectPtr<Object>& sptr_to_self);
 
-  /*!
-   * \brief Invoke a PackedFunction
-   *
-   * \param packed_index The offset of the PackedFunction in all functions.
-   * \param func The PackedFunction to be invoked.
-   * \param arg_count The number of arguments to the PackedFunction.
-   * \param output_size The number of outputs of the PackedFunction.
-   * \param args Arguments to the PackedFunction.
-   *
-   * \note The return value will be stored in the last output_size slots of args.
-   */
-  virtual void InvokePacked(Index packed_index,
-                            const PackedFunc& func,
-                            Index arg_count,
-                            Index output_size,
-                            const std::vector<ObjectRef>& args);
-
   virtual ~VirtualMachine() {}
 
   const char* type_key() const final {
     return "VirtualMachine";
   }
 
-  VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}
+  VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {}
 
-  /*! \brief load the executable for the virtual machine.
-   *  \param exec The executable.
+  /*!
+   * \brief load the executable for the virtual machine.
+   * \param exec The executable.
    */
-  void LoadExecutable(const Executable* exec);
+  virtual void LoadExecutable(const Executable* exec);
 
  protected:
   /*! \brief The virtual machine's packed function table. */
-  std::vector<PackedFunc> packed_funcs;
+  std::vector<PackedFunc> packed_funcs_;
   /*! \brief The current stack of call frames. */
-  std::vector<VMFrame> frames;
+  std::vector<VMFrame> frames_;
   /*! \brief The fuction table index of the current function. */
-  Index func_index;
+  Index func_index_;
   /*! \brief The current pointer to the code section. */
-  const Instruction* code;
+  const Instruction* code_;
   /*! \brief The virtual machine PC. */
-  Index pc;
-
+  Index pc_;
   /*! \brief The special return register. */
-  ObjectRef return_register;
-
+  ObjectRef return_register_;
   /*! \brief The executable the VM will operate on. */
-  const Executable* exec;
-
+  const Executable* exec_;
+  /*! \brief The function name to inputs mapping. */
+  std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
   /*! \brief The set of TVM contexts the VM is currently executing on. */
-  std::vector<TVMContext> ctxs;
+  std::vector<TVMContext> ctxs_;
 
   /*! \brief Push a call frame on to the call stack. */
   void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
-  /*! \brief Pop a frame off the call stack.
-   *  \return The number of frames left.
+
+  /*!
+   * \brief Pop a frame off the call stack.
+   * \return The number of frames left.
    */
   Index PopFrame();
 
-  /*! \brief Write to a VM register.
-   *  \param reg The register to write to.
-   *  \param obj The object to write to.
+  /*!
+   * \brief Write to a VM register.
+   * \param reg The register to write to.
+   * \param obj The object to write to.
    */
   inline void WriteRegister(RegName reg, const ObjectRef& obj);
 
-  /*! \brief Read a VM register.
-   *  \param reg The register to read from.
-   *  \return The read object.
+  /*!
+   * \brief Read a VM register.
+   * \param reg The register to read from.
+   * \return The read object.
    */
   inline ObjectRef ReadRegister(RegName reg) const;
 
-  /*! \brief Read a VM register and cast it to int32_t
-   *  \param reg The register to read from.
-   *  \return The read scalar.
+  /*!
+   * \brief Read a VM register and cast it to int32_t
+   * \param reg The register to read from.
+   * \return The read scalar.
    */
   int32_t LoadScalarInt(RegName reg) const;
 
-  /*! \brief Invoke a VM function.
+  /*!
+   * \brief Invoke a VM function.
    * \param func The function.
    * \param args The arguments to the function.
    * \return The object representing the result.
@@ -752,29 +778,43 @@ class VirtualMachine : public runtime::ModuleNode {
    */
   ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
 
-  /*! \brief Initialize the virtual machine for a set of contexts.
-   *  \param contexts The set of TVM contexts.
+  /*!
+   * \brief Invoke a PackedFunction
+   *
+   * \param packed_index The offset of the PackedFunction in all functions.
+   * \param func The PackedFunction to be invoked.
+   * \param arg_count The number of arguments to the PackedFunction.
+   * \param output_size The number of outputs of the PackedFunction.
+   * \param args Arguments to the PackedFunction.
+   *
+   * \note The return value will be stored in the last output_size slots of args.
+   */
+  virtual void InvokePacked(Index packed_index,
+                            const PackedFunc& func,
+                            Index arg_count,
+                            Index output_size,
+                            const std::vector<ObjectRef>& args);
+
+  /*!
+   * \brief Initialize the virtual machine for a set of contexts.
+   * \param contexts The set of TVM contexts.
    */
   void Init(const std::vector<TVMContext>& contexts);
 
-  /*! \brief Run VM dispatch loop.
-   */
+  /*! \brief Run VM dispatch loop. */
   void RunLoop();
 
-  /*! \brief Get device context for params.
-   */
+  /*! \brief Get device context for params. */
   TVMContext GetParamsContext() const;
 
  private:
-  /*! \brief Invoke a global setting up the VM state to execute.
+  /*!
+   * \brief Invoke a global setting up the VM state to execute.
    *
    * This does not begin execution of the VM.
    */
   void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
 
-  /*! \brief The parameter name to data mapping. */
-  std::unordered_map<std::string, ObjectRef> params_;
-
   /*!
    * \brief The constant pool for runtime. It caches the device dependent
    * object to avoid rellocation of constants during inference.
index ded5d0d..5ee2d66 100644 (file)
@@ -22,68 +22,24 @@ Provides extra APIs for profiling vm execution.
 """
 from . import vm, _vm
 
-def compile(mod, target=None, target_host=None, params=None):
-    """
-    Parameters
-    ----------
-    mod : relay.Module
-        The Relay module to build.
-
-    target : str, :any:`tvm.target.Target`, or dict of str(i.e.
-        device/context name) to str/tvm.target.Target, optional
-        For heterogeneous compilation, it is a dictionary indicating context
-        to target mapping. For homogeneous compilation, it is a build target.
-
-    target_host : str or :any:`tvm.target.Target`, optional
-        Host compilation target, if target is device.
-        When TVM compiles device specific program such as CUDA,
-        we also need host(CPU) side code to interact with the driver
-        to setup the dimensions and parameters correctly.
-        target_host is used to specify the host side codegen target.
-        By default, llvm is used if it is enabled,
-        otherwise a stackvm intepreter is used.
-
-    params : dict of str to NDArray
-        Input parameters to the graph that do not change
-        during inference time. Used for constant folding.
-
-    Returns
-    -------
-    exec : Executable
-        The executable with profiling code.
-    """
-    compiler = VMCompilerProfiler()
-    target = compiler.update_target(target)
-    target_host = compiler.update_target_host(target, target_host)
-    if params:
-        compiler.set_params(params)
-    tophub_context = compiler.tophub_context(target)
-    with tophub_context:
-        compiler._compile(mod, target, target_host)
-    return vm.Executable(compiler._get_exec())
-
 def enabled():
     """Whether vm profiler is enabled."""
-    return hasattr(_vm, "_VMCompilerProfiler")
-
-class VMCompilerProfiler(vm.VMCompiler):
-    """Build Relay module to run on VM runtime."""
-    def __init__(self):
-        super().__init__()
-        self.mod = _vm._VMCompilerProfiler()
-        self._compile = self.mod["compile"]
-        self._get_exec = self.mod["get_executable"]
-        self._set_params_func = self.mod["set_params"]
+    return hasattr(_vm, "_VirtualMachineDebug")
 
 class VirtualMachineProfiler(vm.VirtualMachine):
     """Relay profile VM runtime."""
     def __init__(self, mod):
-        super().__init__(mod)
+        super(VirtualMachineProfiler, self).__init__(mod)
         m = mod.module if isinstance(mod, vm.Executable) else mod
         self.mod = _vm._VirtualMachineDebug(m)
         self._init = self.mod["init"]
         self._invoke = self.mod["invoke"]
         self._get_stat = self.mod["get_stat"]
+        self._set_input = self.mod["set_input"]
+        self._reset = self.mod["reset"]
 
     def get_stat(self):
         return self._get_stat()
+
+    def reset(self):
+        self._reset()
index 5a4c5f7..bfdf3b6 100644 (file)
@@ -34,7 +34,9 @@ Tensor = _obj.Tensor
 ADT = _obj.ADT
 
 def _convert(arg, cargs):
-    if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
+    if isinstance(arg, _obj.Object):
+        cargs.append(arg)
+    elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
         cargs.append(_obj.Tensor(arg))
     elif isinstance(arg, (tuple, list)):
         field_args = []
@@ -42,7 +44,7 @@ def _convert(arg, cargs):
             _convert(field, field_args)
         cargs.append(_obj.tuple_object(field_args))
     else:
-        raise "unsupported type"
+        raise "Unsupported type: %s" % (type(arg))
 
 
 def convert(args):
@@ -57,10 +59,13 @@ class Executable(object):
     """Relay VM executable"""
     def __init__(self, mod):
         self.mod = mod
+        self._function_params = {}
         self._save = self.mod["save"]
         self._get_lib = self.mod["get_lib"]
         self._get_bytecode = self.mod["get_bytecode"]
         self._get_stats = self.mod["get_stats"]
+        self._get_function_arity = self.mod["get_function_arity"]
+        self._get_function_param_name = self.mod["get_function_param_name"]
 
     def save(self):
         """Save the Relay VM Executable.
@@ -239,6 +244,20 @@ class Executable(object):
         """Return the runtime module contained in a virtual machine executable."""
         return self.mod
 
+    def get_function_params(self, func_name):
+        """Get VM Function parameters"""
+        if func_name in self._function_params:
+            return self._function_params[func_name]
+        arity = self._get_function_arity(func_name)
+        assert arity >= 0
+        params = []
+        for i in range(arity):
+            p = self._get_function_param_name(func_name, i)
+            assert p
+            params.append(p)
+        self._function_params[func_name] = params
+        return params
+
 
 class VirtualMachine(object):
     """Relay VM runtime."""
@@ -248,8 +267,10 @@ class VirtualMachine(object):
                             "tvm.Module, but received {}".format(type(mod)))
         m = mod.module if isinstance(mod, Executable) else mod
         self.mod = _vm._VirtualMachine(m)
+        self._exec = mod
         self._init = self.mod["init"]
         self._invoke = self.mod["invoke"]
+        self._set_input = self.mod["set_input"]
 
     def init(self, ctx):
         """Initialize the context in the VM.
@@ -262,7 +283,37 @@ class VirtualMachine(object):
         args = [ctx.device_type, ctx.device_id]
         self._init(*args)
 
-    def invoke(self, func_name, *args):
+    def set_input(self, func_name, *args, **kwargs):
+        """Set the input to a function.
+
+        Parameters
+        ----------
+        func_name : str
+            The name of the function.
+
+        args : list[NDArray] or list[np.ndarray]
+            The arguments to the function.
+
+        kwargs: dict of str to NDArray or np.ndarray
+            Named arguments to the function.
+        """
+        if kwargs:
+            func_params = self._exec.get_function_params(func_name)
+            new_args = [None] * len(func_params)
+            assert len(args) + len(kwargs) == len(func_params)
+            for k in kwargs:
+                idx = func_params.index(k)
+                new_args[idx] = kwargs[k]
+            idx = 0
+            for i, arg in enumerate(new_args):
+                if arg is None:
+                    new_args[i] = args[idx]
+                    idx += 1
+            args = new_args
+        cargs = convert(args)
+        self._set_input(func_name, *cargs)
+
+    def invoke(self, func_name, *args, **kwargs):
         """Invoke a function.
 
         Parameters
@@ -273,15 +324,19 @@ class VirtualMachine(object):
         args : list[NDArray] or list[np.ndarray]
             The arguments to the function.
 
+        kwargs: dict of str to NDArray or np.ndarray
+            Named arguments to the function.
+
         Returns
         -------
         result : Object
             The output.
         """
-        cargs = convert(args)
-        return self._invoke(func_name, *cargs)
+        if args or kwargs:
+            self.set_input(func_name, *args, **kwargs)
+        return self._invoke(func_name)
 
-    def run(self, *args):
+    def run(self, *args, **kwargs):
         """Run the main function.
 
         Parameters
@@ -289,12 +344,15 @@ class VirtualMachine(object):
         args : list[NDArray] or list[np.ndarray]
             The arguments to the function.
 
+        kwargs: dict of str to NDArray or np.ndarray
+            Named arguments to the function.
+
         Returns
         -------
         result : Object
             The output.
         """
-        return self.invoke("main", *args)
+        return self.invoke("main", *args, **kwargs)
 
 
 def compile(mod, target=None, target_host=None, params=None):
diff --git a/src/relay/backend/vm/profiler/compiler.cc b/src/relay/backend/vm/profiler/compiler.cc
deleted file mode 100644 (file)
index 4727f15..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/relay/backend/vm/profiler/compiler.cc
- * \brief A compiler from relay::Module to the VM byte code.
- */
-
-#include "../../../../runtime/vm/profiler/vm.h"
-#include "../compiler.h"
-
-namespace tvm {
-namespace relay {
-namespace vm {
-
-class VMCompilerDebug : public VMCompiler {
- public:
-  VMCompilerDebug() {}
-  virtual ~VMCompilerDebug() {}
-};
-
-runtime::Module CreateVMCompilerDebug() {
-  auto exec = make_object<VMCompilerDebug>();
-  return runtime::Module(exec);
-}
-
-TVM_REGISTER_GLOBAL("relay._vm._VMCompilerProfiler")
-    .set_body([](TVMArgs args, TVMRetValue* rv) {
-      *rv = CreateVMCompilerDebug();
-    });
-
-}  // namespace vm
-}  // namespace relay
-}  // namespace tvm
index 2aeecc5..f02fadb 100644 (file)
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include <memory>
 #include <iostream>
+#include <iomanip>
 #include <sstream>
 #include <utility>
 #include <vector>
@@ -67,44 +68,76 @@ PackedFunc Executable::GetFunction(const std::string& name,
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       *rv = this->Save();
     });
+  } else if (name == "get_function_arity") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::string func_name = args[0];
+      *rv = this->GetFunctionArity(func_name);
+    });
+  } else if (name == "get_function_param_name") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      std::string func_name = args[0];
+      int index = args[1];
+      *rv = this->GetFunctionParameterName(func_name, index);
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc(nullptr);
   }
 }
 
+int Executable::GetFunctionArity(std::string func_name) const {
+  auto it = global_map.find(func_name);
+  if (it == global_map.end()) {
+    LOG(ERROR) << "Cannot find function " << func_name << " in executable";
+    return -1;
+  }
+  const auto& func = functions[it->second];
+  return func.params.size();
+}
+
+std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
+  auto it = global_map.find(func_name);
+  if (it == global_map.end()) {
+    LOG(ERROR) << "Cannot find function " << func_name << " in executable";
+    return "";
+  }
+  const auto& func = functions[it->second];
+  if (index > func.params.size()) {
+    LOG(ERROR) << "Invalid parameter index";
+    return "";
+  }
+  return func.params[index];
+}
+
 std::string Executable::GetBytecode() const {
   std::ostringstream oss;
 
-  for (const auto& func : functions) {
+  for (size_t i = 0; i < functions.size(); ++i) {
+    const auto& func = functions[i];
     // Print the header of the function format.
-    oss << "# func name, reg file size, param count, inst count:"
-        << std::endl;
-    oss << func.name << " "
-        << func.register_file_size << " "
-        << func.params.size() << " "
-        << func.instructions.size() << std::endl;
-
-    // Print pramams of a `VMFunction`.
-    oss << "# Parameters: "<< std::endl;
+    oss << "VM Function[" << i << "]: " << func.name << "(";
     for (const auto& param : func.params) {
-      oss << param << " ";
+      oss << param << ", ";
     }
-    oss << std::endl;
+    oss.seekp(-2, std::ios_base::end);
+    oss << ")" << std::endl;
+    oss << "# reg file size = " << func.register_file_size << std::endl;
+    oss << "# instruction count = " << func.instructions.size() << std::endl;
 
     // Print the instructions of a `VMFunction`.
     // The part after ";" is the instruction in text format.
-    oss << "hash, opcode, fields # inst(text):"<< std::endl;
-    for (const auto& instr : func.instructions) {
+    oss << "opcode, fields # inst(text):" << std::endl;
+    for (size_t idx = 0; idx < func.instructions.size(); ++idx) {
+      const auto& instr = func.instructions[idx];
       const auto& serialized_instr = SerializeInstruction(instr);
-      oss << std::hex << "0x" << serialized_instr.Hash() << " "
-          << std::dec << serialized_instr.opcode << " ";
+      oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
       for (auto it : serialized_instr.fields) {
         oss << it << " ";
       }
       oss << "  # " << instr;
       if (oss.str().back() != '\n') oss << std::endl;
     }
+    oss << std::endl;
   }
 
   return oss.str();
index ed6cddb..b004f67 100644 (file)
@@ -50,15 +50,15 @@ PackedFunc VirtualMachineDebug::GetFunction(
          << "\t"
          << "#Duration(us): Sum/Mean/Min/Max" << std::endl;
 
-      for (auto kv : op_durations) {
-        auto vals = op_durations[kv.first];
+      for (auto kv : op_durations_) {
+        auto vals = op_durations_[kv.first];
         auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
         auto mean = sum / static_cast<double>(vals.size());
         auto min_value = *std::min_element(vals.begin(), vals.end());
         auto max_value = *std::max_element(vals.begin(), vals.end());
 
-        os << std::setw(30) << std::left << packed_index_map[kv.first] << "\t"
-           << std::setw(10) << std::left << op_invokes[kv.first] << "\t"
+        os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t"
+           << std::setw(10) << std::left << op_invokes_[kv.first] << "\t"
            <<  sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;
 
         total_duration += sum;
@@ -66,18 +66,10 @@ PackedFunc VirtualMachineDebug::GetFunction(
       os << "Total Duration " << total_duration << " us" << std::endl;
       *rv = os.str();
     });
-  } else if (name == "init") {
+  } else if (name == "reset") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      CHECK_EQ(args.size() % 2, 0);
-      std::vector<TVMContext> contexts;
-      for (int i = 0; i < args.size() / 2; ++i) {
-        TVMContext ctx;
-        int device_type = args[i * 2];
-        ctx.device_type = DLDeviceType(device_type);
-        ctx.device_id = args[i * 2 + 1];
-        contexts.push_back(ctx);
-      }
-      this->Init(contexts);
+      op_durations_.clear();
+      op_invokes_.clear();
     });
   } else {
     return VirtualMachine::GetFunction(name, sptr_to_self);
@@ -86,31 +78,25 @@ PackedFunc VirtualMachineDebug::GetFunction(
 
 void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
   VirtualMachine::LoadExecutable(exec);
-  CHECK(this->exec);
-  for (auto kv : this->exec->primitive_map) {
-    packed_index_map[kv.second] = kv.first;
-    op_invokes[kv.second] = 0;
+  CHECK(exec_);
+  for (auto kv : exec_->primitive_map) {
+    packed_index_map_[kv.second] = kv.first;
+    op_invokes_[kv.second] = 0;
   }
 }
 
-void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
-  VirtualMachine::Init(ctxs);
-}
-
 void VirtualMachineDebug::InvokePacked(Index packed_index,
                                        const PackedFunc& func, Index arg_count,
                                        Index output_size,
                                        const std::vector<ObjectRef>& args) {
-  CHECK(this->exec);
+  CHECK(exec_);
   auto ctx = this->GetParamsContext();
   // warmup
-  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
-                               args);
+  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args);
   TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
 
   auto op_begin = std::chrono::high_resolution_clock::now();
-  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
-                               args);
+  VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args);
   TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
   auto op_end = std::chrono::high_resolution_clock::now();
   double op_duration =
@@ -118,8 +104,8 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
                                                                  op_begin)
           .count();
 
-  op_durations[packed_index].push_back(op_duration * 1e6);
-  op_invokes[packed_index] += 1;
+  op_durations_[packed_index].push_back(op_duration * 1e6);
+  op_invokes_[packed_index] += 1;
 }
 
 runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
index 2e95a07..f0a407f 100644 (file)
@@ -43,19 +43,17 @@ class VirtualMachineDebug : public VirtualMachine {
   PackedFunc GetFunction(const std::string& name,
                          const ObjectPtr<Object>& sptr_to_self) final;
 
-  void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
-                    Index output_size, const std::vector<ObjectRef>& args) final;
-
-  void LoadExecutable(const Executable* exec);
+  void LoadExecutable(const Executable* exec) final;
 
   ~VirtualMachineDebug() {}
 
  private:
-  void Init(const std::vector<TVMContext>& ctxs);
+  void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
+                    Index output_size, const std::vector<ObjectRef>& args) final;
 
-  std::unordered_map<Index, std::string> packed_index_map;
-  std::unordered_map<Index, std::vector<double>> op_durations;
-  std::unordered_map<Index, int> op_invokes;
+  std::unordered_map<Index, std::string> packed_index_map_;
+  std::unordered_map<Index, std::vector<double>> op_durations_;
+  std::unordered_map<Index, int> op_invokes_;
 };
 
 }  // namespace vm
index 463c575..cc16391 100644 (file)
@@ -544,7 +544,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       break;
     }
     case Opcode::If: {
-      os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
+      os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " "
          << instr.if_op.true_offset << " " << instr.if_op.false_offset;
       break;
     }
@@ -565,7 +565,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       break;
     }
     case Opcode::LoadConsti: {
-      os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
+      os << "load_consti $" << instr.dst << " " << instr.load_consti.val;
       break;
     }
     case Opcode::GetField: {
@@ -630,35 +630,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
                                        const ObjectPtr<Object>& sptr_to_self) {
   if (name == "invoke") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      CHECK(exec) << "The executable is not created yet.";
+      CHECK(exec_) << "The executable is not created yet.";
       std::string func_name = args[0];
-      auto gvit = exec->global_map.find(func_name);
-      CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name;
-      auto func_index = gvit->second;
-      const auto& vm_func = exec->functions[func_index];
-      const auto& param_names = vm_func.params;
-      auto ctx = this->GetParamsContext();
-
-      // Prepare the func args
-      std::vector<ObjectRef> func_args(param_names.size());
-      std::vector<size_t> empty_slots;
-
-      for (size_t i = 0; i < param_names.size(); ++i) {
-        const auto& pit = params_.find(param_names[i]);
-        if (pit != params_.end()) {
-          func_args[i] = pit->second;
-        } else {
-          empty_slots.push_back(i);
-        }
-      }
-      CHECK_EQ(empty_slots.size(), args.size() - 1)
-          << "The number of provided parameters doesn't match the number of arguments";
-      for (int i = 1; i < args.size(); ++i) {
-        ObjectRef obj = CopyTo(args[i], ctx);
-        func_args[empty_slots[i - 1]] = obj;
+      auto git = exec_->global_map.find(func_name);
+      CHECK(git != exec_->global_map.end())
+        << "Cannot find function " << func_name << " in the executable";
+      auto func = exec_->functions[git->second];
+      if (func.params.empty()) {
+        *rv = Invoke(func, {});
+      } else {
+        auto it = inputs_.find(func_name);
+        CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name;
+        const std::vector<ObjectRef> &func_args = it->second;
+        *rv = Invoke(func, func_args);
       }
-
-      *rv = this->Invoke(vm_func, func_args);
     });
   } else if (name == "init") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -673,6 +658,27 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
       }
       this->Init(contexts);
     });
+  } else if (name == "set_input") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      CHECK(exec_) << "The executable is not created yet.";
+      std::string func_name = args[0];
+      auto gvit = exec_->global_map.find(func_name);
+      CHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
+      auto func_index = gvit->second;
+      const auto& vm_func = exec_->functions[func_index];
+      const auto& param_names = vm_func.params;
+      // TODO(icemelon9): For heterogeneous execution, get input device information
+      TVMContext ctx = ctxs_[0];
+      CHECK_EQ(args.size() - 1, param_names.size()) <<
+          "The number of provided parameters doesn't match the number of arguments";
+      std::vector<ObjectRef> func_args(param_names.size());
+      for (int i = 1; i < args.size(); ++i) {
+        ObjectRef obj = CopyTo(args[i], ctx);
+        func_args[i - 1] = obj;
+      }
+      inputs_.erase(func_name);
+      inputs_.emplace(func_name, func_args);
+    });
   } else {
     LOG(FATAL) << "Unknown packed function: " << name;
     return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
@@ -680,47 +686,46 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
 }
 
 TVMContext VirtualMachine::GetParamsContext() const {
-  CHECK(!ctxs.empty()) << "Context has not been initialized yet."
-                       << "\n";
+  CHECK(!ctxs_.empty()) << "Context has not been initialized yet.";
 
   // Use the fallback device if no device index is available.
-  int fallback_device_type = static_cast<int>(ctxs[0].device_type);
+  int fallback_device_type = static_cast<int>(ctxs_[0].device_type);
   // TODO(wweic): For heterogeneous execution, get device information from byte
 
   const auto& cit =
-      std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
+      std::find_if(ctxs_.begin(), ctxs_.end(), [&fallback_device_type](const TVMContext& c) {
         return fallback_device_type == static_cast<int>(c.device_type);
       });
-  return (cit == ctxs.end() ? ctxs[0] : *cit);
+  return (cit == ctxs_.end() ? ctxs_[0] : *cit);
 }
 
 void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
-  auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size);
-  frames.push_back(frame);
+  auto frame = VMFrame(ret_pc, func_index_, arg_count, code_, vm_func.register_file_size);
+  frames_.push_back(frame);
 }
 
 Index VirtualMachine::PopFrame() {
-  CHECK_GT(frames.size(), 0);
-  const VMFrame& fr = frames.back();
-  func_index = fr.func_index;
-  code = fr.code;
-  pc = fr.pc;
-  auto call_stack_size = frames.size();
-  frames.pop_back();
+  CHECK_GT(frames_.size(), 0);
+  const VMFrame& fr = frames_.back();
+  func_index_ = fr.func_index;
+  code_ = fr.code;
+  pc_ = fr.pc;
+  auto call_stack_size = frames_.size();
+  frames_.pop_back();
   return call_stack_size;
 }
 
 void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) {
   DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
 
-  PushFrame(func.params.size(), this->pc + 1, func);
+  PushFrame(func.params.size(), this->pc_ + 1, func);
   for (size_t i = 0; i < args.size(); ++i) {
     WriteRegister(i, args[i]);
   }
   DLOG(INFO) << "func.params= " << func.params.size();
 
-  code = func.instructions.data();
-  pc = 0;
+  code_ = func.instructions.data();
+  pc_ = 0;
 }
 
 ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
@@ -729,16 +734,19 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec
   InvokeGlobal(func, args);
   RunLoop();
   // TODO(wweic) ctx could be obtained from the ctxs list.
-  auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]);
+  auto alloc = MemoryManager::Global()->GetAllocator(ctxs_[0]);
   DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
-  return return_register;
+  return return_register_;
 }
 
 ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
-  CHECK(exec) << "The executable has not been created yet.";
-  auto func_index = exec->global_map.at(name);
-  DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
-  return Invoke(exec->functions[func_index], args);
+  CHECK(exec_) << "The executable has not been created yet.";
+  auto it = exec_->global_map.find(name);
+  CHECK(it != exec_->global_map.end())
+    << "Cannot find function " << name << " in the executable";
+  auto func_index_ = it->second;
+  DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_;
+  return Invoke(exec_->functions[func_index_], args);
 }
 
 void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
@@ -777,34 +785,34 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
 
 void VirtualMachine::LoadExecutable(const Executable* exec) {
   CHECK(exec) << "The executable is not created yet.";
-  this->exec = exec;
+  exec_ = exec;
 
-  runtime::Module lib = this->exec->lib;
+  runtime::Module lib = exec_->lib;
   // Get the list of packed functions.
   CHECK(exec->primitive_map.empty() || lib.operator->())
       << "runtime module should have been built for primitive functions"
       << "\n";
-  for (const auto& it : this->exec->primitive_map) {
+  for (const auto& it : exec_->primitive_map) {
     const auto& packed_name = it.first;
     auto packed_index = static_cast<size_t>(it.second);
-    if (packed_funcs.size() <= packed_index) {
-      packed_funcs.resize(packed_index + 1);
+    if (packed_funcs_.size() <= packed_index) {
+      packed_funcs_.resize(packed_index + 1);
     }
-    packed_funcs[packed_index] = lib.GetFunction(packed_name);
+    packed_funcs_[packed_index] = lib.GetFunction(packed_name);
   }
 }
 
 
 void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
-  this->ctxs = ctxs;
+  ctxs_ = ctxs;
 }
 
 inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
-  frames.back().register_file[r] = val;
+  frames_.back().register_file[r] = val;
 }
 
 inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
-  return frames.back().register_file[r];
+  return frames_.back().register_file[r];
 }
 
 inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
@@ -825,14 +833,14 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
 }
 
 void VirtualMachine::RunLoop() {
-  CHECK(this->code);
-  CHECK(this->exec);
-  this->pc = 0;
-  Index frame_start = frames.size();
+  CHECK(this->exec_);
+  CHECK(this->code_);
+  pc_ = 0;
+  Index frame_start = frames_.size();
   while (true) {
   main_loop:
-    auto const& instr = this->code[this->pc];
-    DLOG(INFO) << "Executing(" << pc << "): " << instr;
+    auto const& instr = code_[this->pc_];
+    DLOG(INFO) << "Executing(" << pc_ << "): " << instr;
 #if USE_RELAY_DEBUG
     InstructionPrint(std::cout, instr);
 #endif  // USE_RELAY_DEBUG
@@ -842,14 +850,14 @@ void VirtualMachine::RunLoop() {
         ObjectRef from_obj;
         from_obj = ReadRegister(instr.from);
         WriteRegister(instr.dst, from_obj);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::Fatal: {
         throw std::runtime_error("VM encountered fatal error");
       }
       case Opcode::LoadConst: {
-        auto constant_obj = exec->constants[instr.const_index];
+        auto constant_obj = exec_->constants[instr.const_index];
         // We cache the allocated object in the constant pool. To measure, the
         // first iteration will set the pool up. The other iterations will
         // directly reuse the allocated objects.
@@ -859,17 +867,17 @@ void VirtualMachine::RunLoop() {
 
         if (!const_pool_[instr.const_index].defined()) {
           // TODO(wweic) ctx could be obtained from the ctxs list.
-          const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs[0]);
+          const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs_[0]);
         }
         WriteRegister(instr.dst, const_pool_[instr.const_index]);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::LoadConsti: {
         auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
         reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
         WriteRegister(instr.dst, Tensor(tensor));
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::Invoke: {
@@ -877,14 +885,13 @@ void VirtualMachine::RunLoop() {
         for (Index i = 0; i < instr.num_args; ++i) {
           args.push_back(ReadRegister(instr.invoke_args_registers[i]));
         }
-        InvokeGlobal(exec->functions[instr.func_index], args);
-        frames.back().caller_return_register = instr.dst;
+        InvokeGlobal(exec_->functions[instr.func_index], args);
+        frames_.back().caller_return_register = instr.dst;
         goto main_loop;
       }
       case Opcode::InvokePacked: {
-        DLOG(INFO) << "InvokedPacked "
-          << "arity=" << instr.arity;
-        const auto& func = packed_funcs[instr.packed_index];
+        DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity;
+        const auto& func = packed_funcs_[instr.packed_index];
         const auto& arity = instr.arity;
         std::vector<ObjectRef> args;
         for (Index i = 0; i < arity; ++i) {
@@ -897,7 +904,7 @@ void VirtualMachine::RunLoop() {
         // We no longer need to write the registers back, we write directly
         // through the registers mutably.
         InvokePacked(instr.packed_index, func, arity, instr.output_size, args);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::InvokeClosure: {
@@ -911,8 +918,8 @@ void VirtualMachine::RunLoop() {
         for (Index i = 0; i < instr.num_closure_args; ++i) {
           args.push_back(ReadRegister(instr.closure_args[i]));
         }
-        InvokeGlobal(exec->functions[closure->func_index], args);
-        frames.back().caller_return_register = instr.dst;
+        InvokeGlobal(exec_->functions[closure->func_index], args);
+        frames_.back().caller_return_register = instr.dst;
         goto main_loop;
       }
       case Opcode::GetField: {
@@ -923,7 +930,7 @@ void VirtualMachine::RunLoop() {
             << object->type_index();
         auto field = tuple->fields[instr.field_index];
         WriteRegister(instr.dst, field);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::GetTag: {
@@ -937,11 +944,11 @@ void VirtualMachine::RunLoop() {
         auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
         reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
         WriteRegister(instr.dst, Tensor(tag_tensor));
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::Goto: {
-        pc += instr.pc_offset;
+        pc_ += instr.pc_offset;
         goto main_loop;
       }
       case Opcode::If: {
@@ -950,10 +957,10 @@ void VirtualMachine::RunLoop() {
 
         if (test_val == target_val) {
           CHECK_NE(instr.if_op.true_offset, 0);
-          pc += instr.if_op.true_offset;
+          pc_ += instr.if_op.true_offset;
         } else {
           CHECK_NE(instr.if_op.false_offset, 0);
-          pc += instr.if_op.false_offset;
+          pc_ += instr.if_op.false_offset;
         }
 
         goto main_loop;
@@ -971,7 +978,7 @@ void VirtualMachine::RunLoop() {
 
         auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::AllocTensorReg: {
@@ -996,7 +1003,7 @@ void VirtualMachine::RunLoop() {
 
         auto obj = Tensor(data);
         WriteRegister(instr.dst, obj);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::AllocADT: {
@@ -1006,7 +1013,7 @@ void VirtualMachine::RunLoop() {
         }
         ObjectRef obj = ADT(instr.constructor_tag, fields);
         WriteRegister(instr.dst, obj);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::AllocClosure: {
@@ -1015,7 +1022,7 @@ void VirtualMachine::RunLoop() {
           free_vars.push_back(ReadRegister(instr.free_vars[i]));
         }
         WriteRegister(instr.dst, Closure(instr.func_index, free_vars));
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::AllocStorage: {
@@ -1027,23 +1034,23 @@ void VirtualMachine::RunLoop() {
           "alignment=" << alignment <<
           "dtype_hint=" << TVMType2String(instr.alloc_storage.dtype_hint);
 
-        auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs[0]);
+        auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]);
         WriteRegister(instr.dst, storage);
-        pc++;
+        pc_++;
         goto main_loop;
       }
       case Opcode::Ret: {
         // If we have hit the point from which we started
         // running, we should return to the caller breaking
         // the dispatch loop.
-        return_register = ReadRegister(instr.result);
-        auto caller_return_register = frames.back().caller_return_register;
+        return_register_ = ReadRegister(instr.result);
+        auto caller_return_register = frames_.back().caller_return_register;
 
         if (PopFrame() == frame_start) {
           return;
           // Otherwise we are just returning from a local call.
         } else {
-          WriteRegister(caller_return_register, return_register);
+          WriteRegister(caller_return_register, return_register_);
           goto main_loop;
         }
       }
@@ -1061,8 +1068,7 @@ TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   runtime::Module mod = args[0];
   const auto* exec = dynamic_cast<Executable*>(mod.operator->());
-  CHECK(exec) << "The virtual machine executable has not been defined yet."
-              << "\n";
+  CHECK(exec) << "The virtual machine executable has not been defined yet.";
   *rv = CreateVirtualMachine(exec);
 });
 
index a3b251c..a4c5b7d 100644 (file)
@@ -47,18 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
     if isinstance(f, relay.Expr):
         mod = relay.Module()
         mod["main"] = f
-        exe = relay.vm.compile(mod, target)
-        vm = relay.vm.VirtualMachine(exe)
-        vm.init(ctx)
-        return vm.invoke("main", *args)
     else:
         assert isinstance(f, relay.Module), "expected expression or module"
         mod = f
-        exe = relay.vm.compile(mod, target)
-        vm = relay.vm.VirtualMachine(exe)
-        vm.init(ctx)
-        ret = vm.invoke("main", *args)
-        return ret
+    exe = relay.vm.compile(mod, target)
+    vm = relay.vm.VirtualMachine(exe)
+    vm.init(ctx)
+    return vm.invoke("main", *args)
 
 def vmobj_to_list(o):
     if isinstance(o, tvm.relay.backend.vm.Tensor):
@@ -577,35 +572,4 @@ def test_add_op_broadcast():
 
 
 if __name__ == "__main__":
-    test_id()
-    test_op()
-    test_cond()
-    test_simple_if()
-    test_simple_call()
-    test_count_loop()
-    test_sum_loop()
-    test_tuple_fst()
-    test_tuple_second()
-    test_let_scalar()
-    test_let_tensor()
-    test_split()
-    test_split_no_fuse()
-    test_list_constructor()
-    test_let_tensor()
-    test_let_scalar()
-    test_compose()
-    test_list_hd()
-    test_list_tl_empty_list()
-    test_list_tl()
-    test_list_nth()
-    test_list_update()
-    test_list_length()
-    test_list_map()
-    test_list_foldl()
-    test_list_foldr()
-    test_list_sum()
-    test_list_filter()
-    test_closure()
-    test_add_op_scalar()
-    test_add_op_tensor()
-    test_add_op_broadcast()
+    pytest.main()
index 0327c14..b31fce7 100644 (file)
@@ -107,9 +107,9 @@ def test_serializer():
     assert any(item.startswith('fused_multiply') for item in prim_ops)
 
     code = exe.bytecode
-    assert "main 8 2 8" in code
-    assert "f1 5 1 6" in code
-    assert "f2 5 1 6" in code
+    assert "main(x1, y1)" in code
+    assert "f1(x)" in code
+    assert "f2(y)" in code
 
     code, lib = exe.save()
     assert isinstance(code, bytearray)
index 35f5905..6cfe6e8 100644 (file)
@@ -28,7 +28,7 @@ def test_basic():
     ctx = tvm.cpu()
     if not relay.profiler_vm.enabled():
         return
-    exe = relay.profiler_vm.compile(mod, target, params=params)
+    exe = relay.vm.compile(mod, target, params=params)
     vm = relay.profiler_vm.VirtualMachineProfiler(exe)
     vm.init(ctx)