[Relay][Runtime] Add VM compiler. (#3139)
authorJared Roesch <roeschinc@gmail.com>
Sat, 11 May 2019 22:08:13 +0000 (18:08 -0400)
committerGitHub <noreply@github.com>
Sat, 11 May 2019 22:08:13 +0000 (18:08 -0400)
* Implement the VM compiler

* Fix issues

* Fix ASF headers

* Fix test issue

* Apply typo fixes.

* Update src/relay/backend/vm/compiler.cc

Co-Authored-By: 雾雨魔理沙 <lolisa@marisa.moe>
* Refactor compiler

* Fix

* Fix

* Fix in benchmark

* Fix

* Address comments

include/tvm/relay/pass.h
include/tvm/runtime/vm.h
src/relay/backend/vm/compiler.cc [new file with mode: 0644]
src/relay/backend/vm/inline_primitives.cc [new file with mode: 0644]
src/relay/backend/vm/lambda_lift.cc [new file with mode: 0644]
src/relay/backend/vm/vm.cc [new file with mode: 0644]
src/relay/op/tensor/reduce.cc
src/relay/pass/dead_code.cc
src/runtime/vm/vm.cc
tests/python/relay/benchmarking/benchmark_vm.py [new file with mode: 0644]
tests/python/relay/test_vm.py [new file with mode: 0644]

index 43831fc..3106792 100644 (file)
@@ -65,6 +65,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/adt.h>
+#include <tvm/runtime/vm.h>
 #include <string>
 #include <vector>
 
@@ -593,6 +594,18 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
  * As a side effect, code size will explode.
  */
 Expr PartialEval(const Expr& e);
+
+namespace vm {
+
+/*! \brief Compile a module, and construct the virtual machine.
+ *
+ * \param mod The module to compile.
+ * \return The constructed virtual machine.
+ */
+runtime::vm::VirtualMachine CompileModule(const Module& mod);
+
+}  // namespace vm
+
 }  // namespace relay
 }  // namespace tvm
 
index 0a0a4de..8911ad4 100644 (file)
@@ -265,7 +265,7 @@ struct Instruction {
 
   Instruction();
   Instruction(const Instruction& instr);
-  Instruction& operator=(const Instruction& instr) = delete;
+  Instruction& operator=(const Instruction& instr);
   ~Instruction();
 
   friend std::ostream& operator<<(std::ostream& os, const Instruction&);
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
new file mode 100644 (file)
index 0000000..97f03c6
--- /dev/null
@@ -0,0 +1,616 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/backend/vm/compiler.cc
+ * \brief A compiler from relay::Module to the VM byte code.
+ */
+
+#include <tvm/relay/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/interpreter.h>
+#include <tvm/logging.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "../../../runtime/vm/naive_allocator.h"
+#include "../../backend/compile_engine.h"
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::vm;
+
+// (@jroesch): VM passes, eventually declare as passes.
+bool IsClosure(const Function& func);
+Module LambdaLift(const Module& module);
+Module InlinePrimitives(const Module& module);
+
+template <typename T, typename U>
+using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
+using TagMap = NodeMap<tvm::relay::Constructor, Index>;
+using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
+using GlobalMap = NodeMap<GlobalVar, Index>;
+using ConstMap = NodeMap<Constant, Index>;
+using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
+
+struct VMCompilerContext {
+  // The module context for the compilation
+  Module module;
+  // Error reporter
+  ErrorReporter err_reporter;
+  // Map from a unique integer to ADT constructor tag
+  TagNameMap tag_index_map;
+  // Map from ADT constructor tag to a unique integer
+  TagMap tag_map;
+  // Map from global var to a unique integer
+  GlobalMap global_map;
+  // Map from Const object to its index in const pool
+  ConstMap const_map;
+  // Map from Const tensor shape to its index in const pool
+  ConstTensorShapeMap const_tensor_shape_map;
+  // List of lowered functions
+  std::vector<LoweredFunc> lowered_funcs;
+};
+
+// Compute the constant pool, i.e a mapping from Constant node to constant index.
+struct ConstantPool : ExprVisitor {
+  std::set<GlobalVar> visited;
+  Module module;
+  ConstMap const_map;
+  ConstTensorShapeMap const_tensor_shape_map;
+
+  size_t index;
+
+  explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}
+
+  void VisitExpr_(const GlobalVarNode* var_node) {
+    auto gvar = GetRef<GlobalVar>(var_node);
+    if (visited.find(gvar) == visited.end()) {
+      visited.insert(gvar);
+      this->VisitExpr(this->module->Lookup(gvar));
+    }
+  }
+
+  void AddConstantTensorShape(TensorType expr, NDArray value) {
+    auto it = this->const_tensor_shape_map.find(expr);
+    if (it == this->const_tensor_shape_map.end()) {
+      this->const_tensor_shape_map.insert({expr, std::make_pair(index++, value)});
+    }
+  }
+
+  void VisitExpr_(const ConstantNode* const_node) {
+    auto konst = GetRef<Constant>(const_node);
+    auto it = this->const_map.find(konst);
+    if (it == this->const_map.end()) {
+      this->const_map.insert({konst, index++});
+    }
+  }
+
+  NDArray GetTensorConstant(const TensorTypeNode* ttype) {
+    std::vector<int64_t> shapes;
+    for (auto sh : ttype->shape) {
+      shapes.push_back(Downcast<tvm::Integer>(sh)->value);
+    }
+    int64_t s = shapes.size();
+    DLContext cpu_ctx;
+    cpu_ctx.device_type = kDLCPU;
+    cpu_ctx.device_id = 0;
+    auto shape_tensor = NDArray::Empty({s}, Type2TVMType(Int(64)), cpu_ctx);
+    int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
+    for (size_t i = 0; i < shapes.size(); ++i) {
+      dims[i] = shapes[i];
+    }
+    return shape_tensor;
+  }
+
+  void VisitExpr_(const CallNode* call_node) {
+    for (auto arg : call_node->args) {
+      this->VisitExpr(arg);
+    }
+
+    Expr op = call_node->op;
+    auto func_node = op.as<FunctionNode>();
+    if (func_node) {
+      auto ret_type = call_node->checked_type();
+      if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
+        auto shape = GetTensorConstant(ttype);
+        auto tensor_type = GetRef<TensorType>(ttype);
+        AddConstantTensorShape(tensor_type, shape);
+      } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
+        for (size_t i = 0; i < ttype->fields.size(); ++i) {
+          auto f = ttype->fields[i];
+          auto f_type = f.as<TensorTypeNode>();
+          auto shape = GetTensorConstant(f_type);
+          auto tensor_type = GetRef<TensorType>(f_type);
+          AddConstantTensorShape(tensor_type, shape);
+        }
+      }
+    }
+  }
+};
+
+std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
+  auto cp = ConstantPool(module);
+  for (auto& func : module->functions) {
+    cp.VisitExpr(func.first);
+  }
+  return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
+}
+
+void InstructionPrint(std::ostream& os, const Instruction& instr);
+
+struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
+  /*! \brief Store the expression a variable points to. */
+  std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map;
+
+  std::vector<Instruction> instructions;
+
+  // var -> register num
+  std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map;
+
+  size_t last_register;
+
+  // Total number of virtual registers allocated
+  size_t registers_num;
+  CompileEngine engine;
+
+  /*! \brief The functions that have been lowered. */
+  std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
+
+  /*! \brief Global shared meta data */
+  VMCompilerContext* context;
+
+  VMCompiler(VMCompilerContext* context)
+      : instructions(),
+        var_register_map(),
+        last_register(0),
+        registers_num(0),
+        engine(CompileEngine::Global()),
+        context(context)
+        {}
+
+  size_t NewRegister() { return registers_num++; }
+
+  inline void Emit(const Instruction& instr) {
+    DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
+    CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
+    switch (instr.op) {
+      case Opcode::AllocDatatype:
+      case Opcode::AllocTensor:
+      case Opcode::GetField:
+      case Opcode::LoadConst:
+      case Opcode::Select:
+      case Opcode::Invoke:
+      case Opcode::AllocClosure:
+      case Opcode::Move:
+      case Opcode::InvokeClosure:
+        last_register = instr.dst;
+        break;
+      case Opcode::InvokePacked:
+        last_register = instr.packed_args[instr.arity - 1];
+        break;
+      case Opcode::If:
+      case Opcode::Ret:
+      case Opcode::Goto:
+        break;
+    }
+    instructions.push_back(instr);
+  }
+
+  void VisitExpr_(const ConstantNode* const_node) {
+    auto rconst = GetRef<Constant>(const_node);
+    auto it = this->context->const_map.find(rconst);
+    CHECK(it != this->context->const_map.end());
+    Emit(Instruction::LoadConst(it->second, NewRegister()));
+  }
+
+  void VisitExpr_(const VarNode* var_node) {
+    auto var = GetRef<Var>(var_node);
+    auto reg_it = this->var_register_map.find(var);
+    CHECK(reg_it != this->var_register_map.end());
+    last_register = reg_it->second;
+  }
+
+  void VisitExpr_(const TupleNode* tuple_node) {
+    auto tuple = GetRef<Tuple>(tuple_node);
+    std::vector<Index> fields_registers;
+
+    for (auto& field : tuple->fields) {
+      this->VisitExpr(field);
+      fields_registers.push_back(last_register);
+    }
+
+    // TODO(@jroesch): use correct tag
+    Emit(Instruction::AllocDatatype(
+      0,
+      tuple->fields.size(),
+      fields_registers,
+      NewRegister()));
+  }
+
+  void VisitExpr_(const MatchNode* match_node) {
+    auto match = GetRef<Match>(match_node);
+    LOG(FATAL) << "translation of match nodes to the VM is"
+               << "currently unsupported" << std::endl;
+  }
+
+  void VisitExpr_(const LetNode* let_node) {
+    DLOG(INFO) << let_node->value << std::endl;
+    this->VisitExpr(let_node->value);
+    DLOG(INFO) << this->last_register << std::endl;
+    var_register_map.insert({let_node->var, this->last_register});
+    this->VisitExpr(let_node->body);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* get_node) {
+    auto get = GetRef<TupleGetItem>(get_node);
+    this->VisitExpr(get->tuple);
+    auto tuple_register = last_register;
+    Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
+  }
+
+  void VisitExpr_(const GlobalVarNode* gvar) {
+    LOG(FATAL) << "Global variables should only appear in the call position";
+  }
+
+  void VisitExpr_(const IfNode* if_node) {
+    this->VisitExpr(if_node->cond);
+
+    size_t cond_register = last_register;
+
+    auto after_cond = this->instructions.size();
+
+    this->Emit(Instruction::If(cond_register, 0, 0));
+    this->VisitExpr(if_node->true_branch);
+
+    size_t true_register = last_register;
+
+    Emit(Instruction::Goto(0));
+
+    // Finally store how many instructions there are in the
+    // true branch.
+    auto after_true = this->instructions.size();
+
+    this->VisitExpr(if_node->false_branch);
+
+    size_t false_register = last_register;
+
+    // Compute the total number of instructions
+    // after generating false.
+    auto after_false = this->instructions.size();
+
+    // Now we will compute the jump targets in order
+    // to properly patch the instruction with the
+    // the requiste targets.
+
+    // After we emit the true body, and false body,
+    // we patch up the if instruction, and goto.
+    auto true_offset = 1;
+    auto false_offset = after_true - after_cond;
+    this->instructions[after_cond].true_offset = true_offset;
+    this->instructions[after_cond].false_offset = false_offset;
+
+    // Patch the Goto.
+    this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1;
+
+    Emit(Instruction::Select(cond_register, true_register, false_register, NewRegister()));
+  }
+
+  Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
+    DataType dtype = ttype->dtype;
+    TVMType dltype = Type2TVMType(dtype);
+
+    auto tensor_type = GetRef<TensorType>(ttype);
+    auto it = this->context->const_tensor_shape_map.find(tensor_type);
+    if (it == this->context->const_tensor_shape_map.end()) {
+      DLOG(INFO) << "Can not find constant shape for " << tensor_type;
+    } else {
+      Emit(Instruction::LoadConst(it->second.first, NewRegister()));
+    }
+
+    return Instruction::AllocTensor(last_register, dltype, NewRegister());
+  }
+
+  void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
+                           const Type& ret_type) {
+    std::vector<Instruction> allocs;
+    size_t return_num = 0;
+    if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
+      // Allocate space for the return tensor.
+      auto alloc = AllocTensorFromType(ttype);
+      allocs.push_back(alloc);
+      return_num = 1;
+    } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
+      std::vector<Index> fields_registers;
+
+      for (size_t i = 0; i < ttype->fields.size(); ++i) {
+        auto f = ttype->fields[i];
+        auto f_type = f.as<TensorTypeNode>();
+        allocs.push_back(AllocTensorFromType(f_type));
+        fields_registers.push_back(allocs.back().dst);
+      }
+      return_num = ttype->fields.size();
+    } else {
+      LOG(FATAL) << "Unsupported return value type";
+    }
+
+    for (auto& alloc : allocs) {
+      Emit(alloc);
+      args_registers.push_back(alloc.dst);
+    }
+
+    // Next generate the invoke instruction.
+    CHECK(func->IsPrimitive());
+    auto target = Target::create("llvm");
+    auto key = CCacheKeyNode::make(func, target);
+    auto cfunc = engine->Lower(key);
+    // TODO(jroesch): support lowered funcs for multiple targets
+    CHECK_EQ(cfunc->funcs.size(), 1);
+    auto op_index = -1;
+    if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
+      op_index = this->context->lowered_funcs.size();
+      this->context->lowered_funcs.push_back(cfunc->funcs[0]);
+      seen_funcs[cfunc->funcs[0]] = op_index;
+    } else {
+      op_index = seen_funcs[cfunc->funcs[0]];
+    }
+
+    // If Tensor, 1
+    // If Tuple, size of tuple
+    size_t arity = func->params.size() + return_num;
+    Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
+    if (return_num > 1) {
+      // return value is a tuple, we need to create a tuple
+      std::vector<Index> fields_registers;
+      for (size_t i = func->params.size(); i < arity; ++i) {
+        fields_registers.push_back(args_registers[i]);
+      }
+      Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
+    }
+  }
+
+  void VisitExpr_(const CallNode* call_node) {
+    std::vector<Index> args_registers;
+
+    for (auto arg : call_node->args) {
+      CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
+      this->VisitExpr(arg);
+      args_registers.push_back(last_register);
+    }
+
+    Expr op = call_node->op;
+
+    if (auto func_node = op.as<FunctionNode>()) {
+      CHECK(func_node->IsPrimitive());
+      EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
+    } else if (auto global_node = op.as<GlobalVarNode>()) {
+      auto global = GetRef<GlobalVar>(global_node);
+      auto it = this->context->global_map.find(global);
+      CHECK(it != this->context->global_map.end());
+      DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
+                      << " with func_index=" << it->second;
+
+      auto func = this->context->module->Lookup(global);
+      if (IsClosure(func)) {
+        auto arity = func->params.size();
+        std::vector<Index> free_var_registers;
+        for (size_t i = 0; i < arity; ++i) {
+          free_var_registers.push_back(var_register_map.at(func->params[i]));
+        }
+        Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
+      } else {
+        Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
+      }
+    } else if (auto constructor_node = op.as<ConstructorNode>()) {
+      auto constructor = GetRef<Constructor>(constructor_node);
+      auto tag = GetConstructorTag(constructor);
+      Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
+    } else if (auto var_node = op.as<VarNode>()) {
+      VisitExpr(GetRef<Var>(var_node));
+      Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
+    } else {
+      LOG(FATAL) << "unsupported case in vm compiler: " << op;
+    }
+  }
+
+  size_t GetConstructorTag(tvm::relay::Constructor constructor) {
+    auto it = this->context->tag_map.find(constructor);
+    if (it != this->context->tag_map.end()) {
+      return it->second;
+    } else {
+      auto tag = this->context->tag_map.size();
+      this->context->tag_map[constructor] = tag;
+      this->context->tag_index_map[tag] = constructor;
+      return tag;
+    }
+  }
+
+  void VisitExpr_(const FunctionNode* func_node) {
+    if (!func_node->IsPrimitive()) {
+      LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
+                 << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
+                 << "AST: " << GetRef<Function>(func_node);
+    }
+  }
+
+  void CompileClosure(const Function& func) {
+    // We first layout the function arguments.
+    auto inner_func = Downcast<Function>(func->body);
+
+    size_t i = 0;
+    for (auto param : inner_func->params) {
+      auto arg_register = NewRegister();
+      CHECK_EQ(i, arg_register);
+      var_register_map.insert({param, arg_register});
+      i++;
+    }
+
+    // We then assign register num to the free variables
+    for (auto param : func->params) {
+      auto arg_register = NewRegister();
+      CHECK_EQ(i, arg_register);
+      var_register_map.insert({param, arg_register});
+      i++;
+    }
+
+    // We will now process the body like normal.
+    this->VisitExpr(inner_func->body);
+  }
+
+  void Compile(const Function& func) {
+    // We need to generate code specially for lifted closures.
+    if (IsClosure(func)) {
+      CompileClosure(func);
+      return;
+    }
+
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      auto arg_register = NewRegister();
+      CHECK_EQ(arg_register, i);
+      var_register_map.insert({func->params[i], arg_register});
+    }
+
+    this->VisitExpr(func->body);
+  }
+};
+
+void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
+                           std::vector<PackedFunc>* packed_funcs) {
+  runtime::Module mod;
+  if (lowered_funcs.size() > 0) {
+    // TODO(@jroesch): we need to read target from build config
+    Target target = Target::create("llvm");
+    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+      mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
+    } else {
+      LOG(FATAL) << "relay.backend.build is not registered";
+    }
+    CHECK(mod.operator->());
+    for (auto lfunc : lowered_funcs) {
+      packed_funcs->push_back(mod.GetFunction(lfunc->name));
+    }
+  }
+}
+
+VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
+  DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
+  size_t params = func->params.size();
+  VMCompiler compiler(context);
+  compiler.Compile(func);
+  // return the last evaluated expression
+  compiler.instructions.push_back(Instruction::Ret(compiler.last_register));
+
+  // Would like to refactor this so we only check if closure once.
+  if (IsClosure(func)) {
+    auto inner_params = Downcast<Function>(func->body)->params.size();
+    return VMFunction(var->name_hint, params + inner_params, compiler.instructions,
+                      compiler.registers_num);
+  } else {
+    return VMFunction(var->name_hint, params, compiler.instructions, compiler.registers_num);
+  }
+}
+
+Module OptimizeModule(const Module& mod) {
+  ToANormalForm(mod->entry_func, mod);
+  InlinePrimitives(mod);
+  LambdaLift(mod);
+  return InlinePrimitives(mod);
+}
+
+void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
+  // First we populate global map.
+  size_t global_index = 0;
+  for (auto named_func : mod->functions) {
+    auto gvar = named_func.first;
+    global_map->insert({gvar, global_index++});
+  }
+}
+
+VirtualMachine CompileModule(const Module& mod_ref) {
+  Module mod = mod_ref;
+
+  // Run some optimizations first, this code should
+  // be moved to pass manager.
+  mod = OptimizeModule(mod);
+
+  VirtualMachine vm;
+
+  VMCompilerContext context;
+  context.module = mod;
+
+  // Populate the global map.
+  //
+  // This maps global variables to a global index
+  // in the VMFunction table.
+  PopulateGlobalMap(&context.global_map, mod);
+
+  // Next we populate constant map.
+  auto constant_analysis_result = LayoutConstantPool(mod);
+  context.const_map = std::get<0>(constant_analysis_result);
+  context.const_tensor_shape_map = std::get<1>(constant_analysis_result);
+
+  // Next we get ready by allocating space for
+  // the global state.
+  vm.functions.resize(mod->functions.size());
+  vm.constants.resize(context.const_map.size() + context.const_tensor_shape_map.size());
+
+  for (auto pair : context.const_map) {
+    vm.constants[pair.second] = Object::Tensor(pair.first->data);
+  }
+
+  for (auto pair : context.const_tensor_shape_map) {
+    vm.constants[pair.second.first] = Object::Tensor(pair.second.second);
+  }
+
+  for (auto named_func : mod->functions) {
+    auto gvar = named_func.first;
+    auto func = named_func.second;
+    auto vm_func = CompileFunc(&context, gvar, func);
+
+    size_t func_index = context.global_map.at(gvar);
+    CHECK(func_index < vm.functions.size());
+    vm.functions[func_index] = vm_func;
+  }
+
+#ifdef USE_RELAY_DEBUG
+  for (auto vm_func : vm.functions) {
+    std::cout << "Function: " << vm_func.name << std::endl
+              << vm_func << "-------------" << std::endl;
+  }
+#endif  // USE_RELAY_DEBUG
+
+  PopulatePackedFuncMap(context.lowered_funcs, &vm.packed_funcs);
+
+  for (auto gv : context.global_map) {
+    vm.global_map_.insert({gv.first->name_hint, gv.second});
+  }
+
+  return vm;
+}
+
+}  // namespace vm
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
new file mode 100644 (file)
index 0000000..b033a37
--- /dev/null
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tvm/relay/backend/vm/inline_primitives.cc
+ * \brief Ensure that primitives only appear in the call position.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/logging.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <vector>
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+struct PrimitiveInliner : ExprMutator {
+  Module module_;
+  std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
+
+  explicit PrimitiveInliner(const Module& module) : module_(module) {}
+
+  Expr VisitExpr_(const LetNode* let_node) {
+    var_map.insert({let_node->var, VisitExpr(let_node->value)});
+    return ExprMutator::VisitExpr_(let_node);
+  }
+
+  Expr VisitExpr_(const CallNode* call) {
+    Expr op = call->op;
+    // For now just collapse the chain of variables to see if
+    // they point to a primitive function.
+    const VarNode* var_node;
+
+    // Collapse a chain of let bindings
+    //
+    // let x = fn (..) { .. };
+    // let y = x
+    // let w = y
+    // in w(...)
+    while ((var_node = op.as<VarNode>())) {
+      auto var = GetRef<Var>(var_node);
+      DLOG(INFO) << "Var: " << var << std::endl;
+      auto it = var_map.find(GetRef<Var>(var_node));
+      if (it != var_map.end()) {
+        op = it->second;
+      } else {
+        return ExprMutator::VisitExpr_(call);
+      }
+    }
+
+    if (auto func = op.as<FunctionNode>()) {
+      if (func->IsPrimitive()) {
+        return CallNode::make(GetRef<Function>(func), call->args, call->attrs, call->type_args);
+      }
+    }
+
+    if (auto global = op.as<GlobalVarNode>()) {
+      return CallNode::make(GetRef<GlobalVar>(global), call->args, call->attrs, call->type_args);
+    }
+
+    return ExprMutator::VisitExpr_(call);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) {
+    if (func->IsPrimitive()) {
+      return GetRef<Function>(func);
+    } else {
+      return ExprMutator::VisitExpr_(func);
+    }
+  }
+
+  Function Inline(const Function& func) {
+    DLOG(INFO) << "Before inlining primitives: " << std::endl
+                    << "func= " << AsText(func, false) << std::endl;
+
+    auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
+                                      func->type_params, func->attrs);
+
+    inlined = Downcast<Function>(DeadCodeElimination(inlined));
+
+    DLOG(INFO) << "After inlining primitives" << std::endl
+                    << "after_func= " << AsText(inlined, false) << std::endl;
+    return inlined;
+  }
+};
+
+// TODO(@jroesch): write verifier
+
+/* This pass will eliminate primitives which have been lifted by the ANF
+ * transform inlining them directly into call sites.
+ *
+ * This makes VM related code generation easier as the call target is always
+ * a primitive function.
+ *
+ * let prim = fn(...) { ... };
+ * prim(...)
+ *
+ * will become:
+ *
+ * (fn(...) { ... })(...)
+ */
+Module InlinePrimitives(const Module& module) {
+  PrimitiveInliner inliner(module);
+
+  tvm::Map<GlobalVar, Function> updates;
+
+  // There is an ordering bug here.
+  for (auto pair : module->functions) {
+    auto global = pair.first;
+    auto func = pair.second;
+    updates.Set(global, inliner.Inline(func));
+  }
+
+  for (auto pair : updates) {
+    module->Add(pair.first, pair.second, true);
+  }
+
+  return module;
+}
+
+}  // namespace vm
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
new file mode 100644 (file)
index 0000000..13d8112
--- /dev/null
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tvm/relay/backend/vm/lambda_lift.cc
+ * \brief Lift all local functions into global functions.
+ */
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/logging.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <vector>
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+static const char* kIsClosure = "IsClosure";
+
+inline std::string GenerateName(const Function& func) {
+  size_t hash = StructuralHash()(func);
+  return std::string("lifted_name") + std::to_string(hash);
+}
+
+bool IsClosure(const Function& func) {
+  NodeRef res = FunctionGetAttr(func, kIsClosure);
+  const ir::IntImm* pval = res.as<ir::IntImm>();
+  return pval && pval->value != 0;
+}
+
+Function MarkClosure(const Function& func) {
+  return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
+}
+
+struct LambdaLifter : ExprMutator {
+  Module module_;
+  std::vector<std::pair<GlobalVar, Function>> lifted_;
+  explicit LambdaLifter(const Module& module) : module_(module) {}
+
+  Expr VisitExpr_(const FunctionNode* func_node) final {
+    auto func = GetRef<Function>(func_node);
+
+    // We should not transform primitive functions.
+    if (func->IsPrimitive()) {
+      return std::move(func);
+    }
+
+    auto free_vars = FreeVars(func);
+    auto free_type_vars = FreeTypeVars(func, module_);
+    auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
+
+    // When performing this optimization there are two
+    // cases.
+    //
+    // The first case in which we have no free variables
+    // we can just lift the function into the global
+    // environment without needing to allocate a closure.
+    //
+    //
+    // The second case requires that we generate a special
+    // function with makes a distinction between allocating
+    // a closure, and then the code for the closure.
+    //
+    // We represent a closure allocation by lifting the
+    // closure to a global function which takes its
+    // captured arguments and then directly returns
+    // the function representing the closure's code.
+    //
+    // When we generate code later on a call to the "outer"
+    // function marked as a closure is used to emit allocation
+    // code for the closure's environment.
+    //
+    // The "inner" function is should be used to generate the
+    // code for the closure.
+    Function lifted_func;
+    if (free_vars.size() == 0) {
+      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
+    } else {
+      lifted_func =
+          FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
+
+      lifted_func = MarkClosure(lifted_func);
+    }
+
+    CHECK(lifted_func.defined());
+
+    auto name = GenerateName(lifted_func);
+    auto global = this->module_->GetGlobalVar(name);
+
+    lifted_.push_back({global, lifted_func});
+
+    if (free_vars.size() == 0) {
+      return std::move(global);
+    } else {
+      // If we need to allocate a closure
+      // we pass the variables in its environment
+      // here.
+      Array<Expr> fvs;
+      for (auto fv : free_vars) {
+        fvs.push_back(fv);
+      }
+      return CallNode::make(global, fvs);
+    }
+  }
+
+  Function Lift(const Function& func) {
+    DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
+    return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
+                              func->type_params, func->attrs);
+  }
+};
+
+/* The goal of this pass is to lift out any nested functions into top-level
+ * functions.
+ *
+ * We will lift the functions out into globals which take the set of the free vars
+ * and then return a function whcih has b
+ */
+Module LambdaLift(const Module& module) {
+  LambdaLifter lifter(module);
+
+  tvm::Map<GlobalVar, Function> updates;
+
+  // There is an ordering bug here.
+  for (auto pair : module->functions) {
+    auto global = pair.first;
+    auto func = pair.second;
+    updates.Set(global, lifter.Lift(func));
+  }
+
+  for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) {
+    module->Add(i->first, i->second);
+  }
+
+  for (auto pair : updates) {
+    module->Add(pair.first, pair.second, true);
+  }
+
+  return module;
+}
+
+}  // namespace vm
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc
new file mode 100644 (file)
index 0000000..34d067b
--- /dev/null
@@ -0,0 +1,159 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/backend/vm/vm.cc
+ * \brief The Relay virtual machine.
+ */
+
+#include <tvm/relay/interpreter.h>
+#include <tvm/logging.h>
+#include <tvm/relay/module.h>
+#include <tvm/runtime/vm.h>
+#include <tvm/relay/pass.h>
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+using tvm::runtime::Object;
+using tvm::runtime::ObjectTag;
+using tvm::runtime::vm::VirtualMachine;
+
+
+VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
+  auto vm = CompileModule(module);
+  vm.Init(ctxs);
+  return vm;
+}
+
+Object EvaluateModule(const Module& module, const std::vector<TVMContext> ctxs,
+                      const std::vector<Object>& vm_args) {
+  VirtualMachine vm = FromModule(module, ctxs);
+  // TODO(zhiics): This measurement is for temporary usage. Remove it later. We
+  // need to introduce a better profiling method.
+#if ENABLE_PROFILING
+  DLOG(INFO) << "Entry function is " << module->entry_func << std::endl;
+  auto start = std::chrono::high_resolution_clock::now();
+#endif  // ENABLE_PROFILING
+  Object res = vm.Invoke(module->entry_func->name_hint, vm_args);
+#if ENABLE_PROFILING
+  auto end = std::chrono::high_resolution_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+  LOG(INFO) << "Inference time: " << duration << "ms\n";
+#endif  // ENABLE_PROFILING
+  return res;
+}
+
+Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
+  CHECK(module.defined() && type.defined());
+  switch (obj->tag) {
+    case ObjectTag::kTensor: {
+      CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
+      return TensorValueNode::make(ToNDArray(obj));
+    }
+    case ObjectTag::kDatatype: {
+      // const auto* tuple_type
+      // const auto& data_type = obj.AsDatatype();
+
+      // tvm::Array<Value> fields;
+      // for (size_t i = 0; i < data_type->fields.size(); ++i) {
+      //   fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
+      // }
+
+      // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
+      LOG(FATAL) << "fix me";
+    }
+    default:
+      LOG(FATAL) << "unsupported return value of type: " << obj->tag;
+      return Value();
+  }
+}
+
+TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = Object::Tensor(args[0]);
+});
+
+TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
+  std::vector<Object> fields;
+  for (auto i = 0; i < args.size(); i++) {
+    fields.push_back(args[i]);
+  }
+  *ret = Object::Tuple(fields);
+});
+
+template <typename T>
+std::string ToString(const T& t) {
+  std::stringstream s;
+  s << t;
+  return s.str();
+}
+
+TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
+  Object obj = args[0];
+  *ret = ToString(obj->tag);
+});
+
+TVM_REGISTER_API("relay._vm._Datatype")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    int itag = args[0];
+    size_t tag = static_cast<size_t>(itag);
+    std::vector<Object> fields;
+    for (int i = 1; i < args.size(); i++) {
+      fields.push_back(args[i]);
+    }
+
+    *ret = Object::Datatype(tag, fields);
+});
+
+TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
+  NodeRef to_compile = args[0];
+  TVMContext ctx;
+  int dev_type = args[1];
+  ctx.device_type = static_cast<DLDeviceType>(dev_type);
+  ctx.device_id = args[2];
+
+  Module module;
+  if (to_compile.as<FunctionNode>()) {
+    Function to_compile = args[0];
+    module = ModuleNode::FromExpr(to_compile);
+  } else if (to_compile.as<ModuleNode>()) {
+    module = args[0];
+  } else {
+    LOG(FATAL) << "expected function or module";
+  }
+
+  auto return_type = module->Lookup(module->entry_func)->ret_type;
+
+  std::vector<Object> vm_args;
+  for (auto i = 3; i < args.size(); i++) {
+    Object obj = args[i];
+    vm_args.push_back(obj);
+  }
+
+  auto result = EvaluateModule(module, {ctx}, vm_args);
+  DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
+  *ret = VMToValue(module, return_type, result);
+});
+
+}  // namespace vm
+}  // namespace relay
+}  // namespace tvm
index b889b6c..a4ebd1e 100644 (file)
@@ -154,6 +154,9 @@ Array<Tensor> ReduceCompute(const Attrs& attrs,
                             F f) {
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
   CHECK(param != nullptr);
+  if (inputs[0]->shape.size() == 0) {
+    return { topi::identity(inputs[0]) };
+  }
   auto axes = param->axis;
   if (param->exclude) {
     axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
@@ -251,7 +254,6 @@ bool ReduceRel(const Array<Type>& types,
   CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
-  CHECK(static_cast<int>(data->shape.size()) != 0);
   std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
 
   const ReduceAttrs* param = attrs.as<ReduceAttrs>();
index c5c4f33..533c214 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -124,7 +124,8 @@ class CalcDep : private ExprVisitor {
     friend CalcDep;
 
     bool HasLet(const Var& v) {
-      return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
+      // TODO(@jroesch): MK fix me
+      return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
     }
 
     Expr VisitExpr_(const VarNode* op) final {
index d7ea53e..b2d326e 100644 (file)
@@ -118,6 +118,86 @@ Instruction::Instruction(const Instruction& instr) {
   }
 }
 
+template<typename T>
+static inline void FreeIf(T* t) {
+  if (t != nullptr) {
+    delete t;
+  }
+}
+
+Instruction& Instruction::operator=(const Instruction& instr) {
+  this->op = instr.op;
+  this->dst = instr.dst;
+
+  switch (instr.op) {
+    case Opcode::Move:
+      this->from = instr.from;
+      return *this;
+    case Opcode::Select:
+      this->select_cond = instr.select_cond;
+      this->select_op1 = instr.select_op1;
+      this->select_op2 = instr.select_op2;
+      return *this;
+    case Opcode::Ret:
+      this->result = instr.result;
+      return *this;
+    case Opcode::AllocTensor:
+      this->shape_register = instr.shape_register;
+      this->dtype = instr.dtype;
+      return *this;
+    case Opcode::AllocDatatype:
+      this->constructor_tag = instr.constructor_tag;
+      this->num_fields = instr.num_fields;
+      FreeIf(this->datatype_fields);
+      this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
+      return *this;
+    case Opcode::AllocClosure:
+      this->clo_index = instr.clo_index;
+      this->num_freevar = instr.num_freevar;
+      FreeIf(this->free_vars);
+      this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
+      return *this;
+    case Opcode::InvokePacked:
+      this->packed_index = instr.packed_index;
+      this->arity = instr.arity;
+      this->output_size = instr.output_size;
+      FreeIf(this->packed_args);
+      this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
+      return *this;
+    case Opcode::InvokeClosure:
+      this->closure = instr.closure;
+      this->closure_args_num = instr.closure_args_num;
+      FreeIf(this->closure_args);
+      this->closure_args = Duplicate<RegName>(instr.closure_args, instr.closure_args_num);
+      return *this;
+    case Opcode::Invoke:
+      this->func_index = instr.func_index;
+      this->num_args = instr.num_args;
+      FreeIf(this->invoke_args_registers);
+      this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
+      return *this;
+    case Opcode::If:
+      this->if_cond = instr.if_cond;
+      this->true_offset = instr.true_offset;
+      this->false_offset = instr.false_offset;
+      return *this;
+    case Opcode::LoadConst:
+      this->const_index = instr.const_index;
+      return *this;
+    case Opcode::GetField:
+      this->object = instr.object;
+      this->field_index = instr.field_index;
+      return *this;
+    case Opcode::Goto:
+      this->pc_offset = instr.pc_offset;
+      return *this;
+    default:
+      std::ostringstream out;
+      out << "Invalid instruction " << static_cast<int>(instr.op);
+      throw std::runtime_error(out.str());
+  }
+}
+
 Instruction::~Instruction() {
   switch (this->op) {
     case Opcode::Move:
diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py
new file mode 100644 (file)
index 0000000..e359ade
--- /dev/null
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmarking Relay VM using models from MXNet."""
+import numpy as np
+
+import tvm
+from tvm.contrib import graph_runtime
+from tvm import relay
+from tvm.relay import testing
+
+
+def benchmark_execution(net,
+                        params,
+                        measure=False,
+                        data_shape=(1, 3, 224, 224),
+                        out_shape=(1, 1000),
+                        dtype='float32'):
+    def get_tvm_output(net, data, params, target, ctx, dtype='float32'):
+        with relay.build_config(opt_level=1):
+            graph, lib, params = relay.build(net, target, params=params)
+
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input("data", data)
+        m.set_input(**params)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
+
+        if measure:
+            print("Evaluate graph runtime inference time cost...")
+            ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20)
+            # Measure in millisecond.
+            prof_res = np.array(ftimer().results) * 1000
+            print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
+                  (np.mean(prof_res), np.std(prof_res)))
+
+        return out.asnumpy()
+
+    def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'):
+        ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
+        result = ex.evaluate(net)(data, **params)
+        return result.asnumpy().astype(dtype)
+
+    # random input
+    data = np.random.uniform(size=data_shape).astype(dtype)
+    target = "llvm"
+    ctx = tvm.cpu(0)
+
+    tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params,
+                             target, ctx, dtype)
+    vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params,
+                               target, ctx, dtype)
+    tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_mlp():
+    image_shape = (1, 28, 28)
+    net, params = testing.mlp.get_workload(1)
+    benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10))
+
+
+def test_vgg():
+    for n in [11, 16]:
+        net, params = testing.vgg.get_workload(1, num_layers=n)
+        benchmark_execution(net, params)
+
+
+def test_resnet():
+    for n in [18, 50]:
+        net, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
+        benchmark_execution(net, params, True)
+
+
+def test_squeezenet():
+    for version in ['1.0', '1.1']:
+        net, params = testing.squeezenet.get_workload(version=version)
+        benchmark_execution(net, params)
+
+
+def test_inception_v3():
+    image_shape = (3, 299, 299)
+    net, params = testing.inception_v3.get_workload(image_shape=image_shape)
+    benchmark_execution(net, params, data_shape=image_shape)
+
+
+def test_dqn():
+    image_shape = (4, 84, 84)
+    net, params = testing.dqn.get_workload(
+        batch_size=1, image_shape=image_shape)
+    benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18))
+
+
+def test_dcgan():
+    image_shape = (1, 100)
+    net, params = testing.dcgan.get_workload(batch_size=1)
+    benchmark_execution(net, params, data_shape=image_shape)
+
+
+def test_mobilenet():
+    net, params = testing.mobilenet.get_workload(batch_size=1)
+    benchmark_execution(net, params)
+
+
+def test_densenet():
+    net, params = testing.densenet.get_workload(batch_size=1)
+    benchmark_execution(net, params)
+
+
+if __name__ == '__main__':
+    test_resnet()
+    test_vgg()
+    test_squeezenet()
+    test_mobilenet()
+    test_densenet()
+    # The following networks fail
+    # test_inception_v3()
+    # test_mlp()
+    # test_dqn()
+    # test_dcgan()
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
new file mode 100644 (file)
index 0000000..bc99418
--- /dev/null
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from nose.tools import nottest
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay.scope_builder import ScopeBuilder
+from tvm.relay.prelude import Prelude
+
+def veval(f, *args, ctx=tvm.cpu()):
+    if isinstance(f, relay.Expr):
+        ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
+        if len(args) == 0:
+            return ex.evaluate(f)
+        else:
+            return ex.evaluate(f)(*args)
+    else:
+        assert isinstance(f, relay.Module), "expected expression or module"
+        mod = f
+        ex = relay.create_executor('vm', mod=mod, ctx=ctx)
+        if len(args) == 0:
+            return ex.evaluate(mod[mod.entry_func])
+        else:
+            return ex.evaluate(mod[mod.entry_func])(*args)
+
+def test_split():
+    x = relay.var('x', shape=(12,))
+    y = relay.split(x, 3, axis=0).astuple()
+    z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
+    f = relay.Function([x], z)
+
+    x_data = np.random.rand(12,).astype('float32')
+    res = veval(f, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+
+def test_id():
+    x = relay.var('x', shape=(10, 10))
+    f = relay.Function([x], x)
+    x_data = np.random.rand(10, 10).astype('float64')
+    res = veval(f, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), x_data)
+
+def test_op():
+    x = relay.var('x', shape=(10, 10))
+    f = relay.Function([x], x + x)
+    x_data = np.random.rand(10, 10).astype('float32')
+    res = veval(f, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
+
+def any(x):
+    x = relay.op.nn.batch_flatten(x)
+    return relay.op.min(x, axis=[0, 1])
+
+def test_cond():
+    x = relay.var('x', shape=(10, 10))
+    y = relay.var('x', shape=(10, 10))
+    # f = relay.Function([x, y], relay.op.equal(x, y))
+    f = relay.Function([x, y], any(relay.op.equal(x, y)))
+    x_data = np.random.rand(10, 10).astype('float32')
+    y_data = np.random.rand(10, 10).astype('float32')
+
+    # same
+    res = veval(f, x_data, x_data)
+    np.testing.assert_allclose(res.asnumpy(), True)
+
+    # diff
+    res = veval(f, x_data, y_data)
+    tvm.testing.assert_allclose(res.asnumpy(), False)
+
+
+def test_simple_if():
+    x = relay.var('x', shape=(10, 10))
+    y = relay.var('y', shape=(10, 10))
+    f = relay.Function([x, y],
+        relay.If(any(relay.op.equal(x, y)), x, y))
+    x_data = np.random.rand(10, 10).astype('float32')
+    y_data = np.random.rand(10, 10).astype('float32')
+
+    # same
+    res = veval(f, x_data, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), x_data)
+
+    # diff
+    res = veval(f, x_data, y_data)
+    tvm.testing.assert_allclose(res.asnumpy(), y_data)
+
+def test_simple_call():
+    mod = relay.module.Module({})
+    sum_up = relay.GlobalVar('sum_up')
+    i = relay.var('i', shape=[], dtype='int32')
+    sb = ScopeBuilder()
+    sb.ret(i)
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+    mod[sum_up] = func
+    i_data = np.array(0, dtype='int32')
+    iarg = relay.var('i', shape=[], dtype='int32')
+    mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
+    result = veval(mod, i_data)
+    tvm.testing.assert_allclose(result.asnumpy(), i_data)
+
+def test_count_loop():
+    mod = relay.module.Module({})
+    sum_up = relay.GlobalVar('sum_up')
+    i = relay.var('i', shape=[], dtype='int32')
+    sb = ScopeBuilder()
+    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+        sb.ret(i)
+    with sb.else_scope():
+        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+        rec_call = relay.Call(sum_up, [one_less])
+        sb.ret(relay.add(rec_call, i))
+    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+    mod[sum_up] = func
+    i_data = np.array(0, dtype='int32')
+    iarg = relay.var('i', shape=[], dtype='int32')
+    mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
+    result = veval(mod, i_data)
+    tvm.testing.assert_allclose(result.asnumpy(), i_data)
+
+def test_sum_loop():
+    mod = relay.module.Module({})
+    sum_up = relay.GlobalVar('sum_up')
+    i = relay.var('i', shape=[], dtype='int32')
+    accum = relay.var('accum', shape=[], dtype='int32')
+    sb = ScopeBuilder()
+    with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
+        sb.ret(accum)
+    with sb.else_scope():
+        one_less = relay.subtract(i, relay.const(1, 'int32'))
+        new_accum = relay.add(accum, i)
+        sb.ret(relay.Call(sum_up, [one_less, new_accum]))
+    func = relay.Function([i, accum], sb.get())
+    mod[sum_up] = func
+    loop_bound = 0
+    i_data = np.array(loop_bound, dtype='int32')
+    accum_data = np.array(0, dtype='int32')
+    iarg = relay.var('i', shape=[], dtype='int32')
+    aarg = relay.var('accum', shape=[], dtype='int32')
+    mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
+    result = veval(mod, i_data, accum_data)
+    tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
+
+def test_tuple_fst():
+    ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
+    tup = relay.var('tup', type_annotation=ttype)
+    f = relay.Function([tup], relay.TupleGetItem(tup, 0))
+    i_data = np.random.rand(41).astype('float32')
+    j_data = np.random.rand(10).astype('float32')
+    result = veval(f, (i_data, j_data))
+    tvm.testing.assert_allclose(result.asnumpy(), i_data)
+
+def test_tuple_second():
+    ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
+    tup = relay.var('tup', type_annotation=ttype)
+    f = relay.Function([tup], relay.TupleGetItem(tup, 1))
+    i_data = np.random.rand(41).astype('float32')
+    j_data = np.random.rand(10).astype('float32')
+    result = veval(f, (i_data, j_data))
+    tvm.testing.assert_allclose(result.asnumpy(), j_data)
+
+@nottest
+def test_list_constructor():
+    # TODO(wweic): implement pattern match to support this test
+    def to_list(o):
+        if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
+            return [o.data.asnumpy().tolist()]
+        if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
+            result = []
+            for f in o.fields:
+                result.extend(to_list(f))
+            return result
+
+    mod = relay.Module()
+    p = Prelude(mod)
+
+    nil = p.nil
+    cons = p.cons
+    l = p.l
+
+    one2 = cons(relay.const(1), nil())
+    one3 = cons(relay.const(2), one2)
+    one4 = cons(relay.const(3), one3)
+    f = relay.Function([], one4)
+
+    mod[mod.entry_func] = f
+
+    result = veval(mod)()
+    obj = to_list(result)
+    import pdb; pdb.set_trace()
+    tvm.testing.assert_allclose(obj, np.array([3,2,1]))
+
+def test_let_tensor():
+    sb = relay.ScopeBuilder()
+    shape = (1,)
+    x = relay.var('x', shape=shape, dtype='float32')
+    x1 = relay.var('x1', shape=shape, dtype='float32')
+
+    x1 = sb.let(x1, x)
+    xplusone = x1 + relay.const(42.0, 'float32')
+    sb.ret(xplusone)
+    body = sb.get()
+
+    f = relay.Function([x], body)
+
+    x_data = np.random.rand(*shape).astype('float32')
+    result = veval(f, x_data)
+    tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
+
+def test_let_scalar():
+    sb = relay.ScopeBuilder()
+
+    x = relay.var('x', 'float32')
+    x1 = sb.let('x1', x)
+    xplusone = x1 + relay.const(42.0, 'float32')
+    sb.ret(xplusone)
+    body = sb.get()
+
+    f = relay.Function([x], body)
+
+    x_data = np.array(np.random.rand()).astype('float32')
+    result = veval(f, x_data)
+    tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
+
+def test_closure():
+    x = relay.var('x', shape=())
+    y = relay.var('y', shape=())
+    f = relay.Function([x], x + y)
+    ff = relay.Function([y], f)
+    clo = ff(relay.const(1.0))
+    main = clo(relay.const(2.0))
+    res = veval(main)
+    tvm.testing.assert_allclose(res.asnumpy(), 3.0)
+
+if __name__ == "__main__":
+    test_id()
+    test_op()
+    test_cond()
+    test_simple_if()
+    test_simple_call()
+    test_count_loop()
+    test_sum_loop()
+    test_tuple_fst()
+    test_tuple_second()
+    test_let_scalar()
+    test_let_tensor()
+    # TODO(@jroesch): restore when match is supported
+    # test_list_constructor()
+    test_closure()