#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
+#include <tvm/runtime/vm.h>
#include <string>
#include <vector>
* As a side effect, code size will explode.
*/
Expr PartialEval(const Expr& e);
+
+namespace vm {
+
+/*! \brief Compile a module, and construct the virtual machine.
+ *
+ * \param mod The module to compile.
+ * \return The constructed virtual machine.
+ */
+runtime::vm::VirtualMachine CompileModule(const Module& mod);
+
+} // namespace vm
+
} // namespace relay
} // namespace tvm
Instruction();
Instruction(const Instruction& instr);
- Instruction& operator=(const Instruction& instr) = delete;
+ Instruction& operator=(const Instruction& instr);
~Instruction();
friend std::ostream& operator<<(std::ostream& os, const Instruction&);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/backend/vm/compiler.cc
+ * \brief A compiler from relay::Module to the VM byte code.
+ */
+
+#include <tvm/logging.h>
+#include <tvm/relay/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/interpreter.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <set>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include "../../../runtime/vm/naive_allocator.h"
+#include "../../backend/compile_engine.h"
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::vm;
+
+// TODO(@jroesch): VM-specific passes; eventually declare them as proper passes.
+bool IsClosure(const Function& func);
+Module LambdaLift(const Module& module);
+Module InlinePrimitives(const Module& module);
+
+template <typename T, typename U>
+using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
+using TagMap = NodeMap<tvm::relay::Constructor, Index>;
+using TagNameMap = std::unordered_map<size_t, tvm::relay::Constructor>;
+using GlobalMap = NodeMap<GlobalVar, Index>;
+using ConstMap = NodeMap<Constant, Index>;
+using ConstTensorShapeMap = NodeMap<TensorType, std::pair<Index, NDArray>>;
+
+struct VMCompilerContext {
+ // The module context for the compilation
+ Module module;
+ // Error reporter
+ ErrorReporter err_reporter;
+ // Map from a unique integer to ADT constructor tag
+ TagNameMap tag_index_map;
+ // Map from ADT constructor tag to a unique integer
+ TagMap tag_map;
+ // Map from global var to a unique integer
+ GlobalMap global_map;
+ // Map from Const object to its index in const pool
+ ConstMap const_map;
+ // Map from Const tensor shape to its index in const pool
+ ConstTensorShapeMap const_tensor_shape_map;
+ // List of lowered functions
+ std::vector<LoweredFunc> lowered_funcs;
+};
+
+// Compute the constant pool, i.e., a mapping from each Constant node to its index in the pool.
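+//
+// In addition to literal constants, the pool also stores the result shape of
+// each call as an int64 shape tensor, so AllocTensor can load it at runtime
+// via LoadConst.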
+struct ConstantPool : ExprVisitor {
+ std::set<GlobalVar> visited;
+ Module module;
+ ConstMap const_map;
+ ConstTensorShapeMap const_tensor_shape_map;
+
+ size_t index;
+
+ explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}
+
+ void VisitExpr_(const GlobalVarNode* var_node) {
+ auto gvar = GetRef<GlobalVar>(var_node);
+ if (visited.find(gvar) == visited.end()) {
+ visited.insert(gvar);
+ this->VisitExpr(this->module->Lookup(gvar));
+ }
+ }
+
+ void AddConstantTensorShape(TensorType expr, NDArray value) {
+ auto it = this->const_tensor_shape_map.find(expr);
+ if (it == this->const_tensor_shape_map.end()) {
+ this->const_tensor_shape_map.insert({expr, std::make_pair(index++, value)});
+ }
+ }
+
+ void VisitExpr_(const ConstantNode* const_node) {
+ auto konst = GetRef<Constant>(const_node);
+ auto it = this->const_map.find(konst);
+ if (it == this->const_map.end()) {
+ this->const_map.insert({konst, index++});
+ }
+ }
+
+ NDArray GetTensorConstant(const TensorTypeNode* ttype) {
+ std::vector<int64_t> shapes;
+ for (auto sh : ttype->shape) {
+ shapes.push_back(Downcast<tvm::Integer>(sh)->value);
+ }
+ int64_t s = shapes.size();
+ DLContext cpu_ctx;
+ cpu_ctx.device_type = kDLCPU;
+ cpu_ctx.device_id = 0;
+ auto shape_tensor = NDArray::Empty({s}, Type2TVMType(Int(64)), cpu_ctx);
+ int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
+ for (size_t i = 0; i < shapes.size(); ++i) {
+ dims[i] = shapes[i];
+ }
+ return shape_tensor;
+ }
+
+ void VisitExpr_(const CallNode* call_node) {
+ for (auto arg : call_node->args) {
+ this->VisitExpr(arg);
+ }
+
+ Expr op = call_node->op;
+ auto func_node = op.as<FunctionNode>();
+ if (func_node) {
+ auto ret_type = call_node->checked_type();
+ if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
+ auto shape = GetTensorConstant(ttype);
+ auto tensor_type = GetRef<TensorType>(ttype);
+ AddConstantTensorShape(tensor_type, shape);
+ } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
+ for (size_t i = 0; i < ttype->fields.size(); ++i) {
+ auto f = ttype->fields[i];
+          auto f_type = f.as<TensorTypeNode>();
+          CHECK(f_type) << "expected a tuple of tensors";
+ auto shape = GetTensorConstant(f_type);
+ auto tensor_type = GetRef<TensorType>(f_type);
+ AddConstantTensorShape(tensor_type, shape);
+ }
+ }
+ }
+ }
+};
+
+std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
+ auto cp = ConstantPool(module);
+ for (auto& func : module->functions) {
+ cp.VisitExpr(func.first);
+ }
+ return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
+}
+
+void InstructionPrint(std::ostream& os, const Instruction& instr);
+
+struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
+ /*! \brief Store the expression a variable points to. */
+ std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map;
+
+ std::vector<Instruction> instructions;
+
+ // var -> register num
+ std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map;
+
+ size_t last_register;
+
+ // Total number of virtual registers allocated
+ size_t registers_num;
+ CompileEngine engine;
+
+ /*! \brief The functions that have been lowered. */
+ std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
+
+ /*! \brief Global shared meta data */
+ VMCompilerContext* context;
+
+  explicit VMCompiler(VMCompilerContext* context)
+ : instructions(),
+ var_register_map(),
+ last_register(0),
+ registers_num(0),
+ engine(CompileEngine::Global()),
+ context(context)
+ {}
+
+ size_t NewRegister() { return registers_num++; }
+
+ inline void Emit(const Instruction& instr) {
+ DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
+    CHECK(static_cast<int>(instr.op) < 100)
+        << "Invalid opcode " << static_cast<int>(instr.op);
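+    // Record which register receives each instruction's result, so the
+    // enclosing expression can pick it up from last_register.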
+ switch (instr.op) {
+ case Opcode::AllocDatatype:
+ case Opcode::AllocTensor:
+ case Opcode::GetField:
+ case Opcode::LoadConst:
+ case Opcode::Select:
+ case Opcode::Invoke:
+ case Opcode::AllocClosure:
+ case Opcode::Move:
+ case Opcode::InvokeClosure:
+ last_register = instr.dst;
+ break;
+ case Opcode::InvokePacked:
+ last_register = instr.packed_args[instr.arity - 1];
+ break;
+ case Opcode::If:
+ case Opcode::Ret:
+ case Opcode::Goto:
+ break;
+ }
+ instructions.push_back(instr);
+ }
+
+ void VisitExpr_(const ConstantNode* const_node) {
+ auto rconst = GetRef<Constant>(const_node);
+ auto it = this->context->const_map.find(rconst);
+ CHECK(it != this->context->const_map.end());
+ Emit(Instruction::LoadConst(it->second, NewRegister()));
+ }
+
+ void VisitExpr_(const VarNode* var_node) {
+ auto var = GetRef<Var>(var_node);
+ auto reg_it = this->var_register_map.find(var);
+ CHECK(reg_it != this->var_register_map.end());
+ last_register = reg_it->second;
+ }
+
+ void VisitExpr_(const TupleNode* tuple_node) {
+ auto tuple = GetRef<Tuple>(tuple_node);
+ std::vector<Index> fields_registers;
+
+ for (auto& field : tuple->fields) {
+ this->VisitExpr(field);
+ fields_registers.push_back(last_register);
+ }
+
+ // TODO(@jroesch): use correct tag
+ Emit(Instruction::AllocDatatype(
+ 0,
+ tuple->fields.size(),
+ fields_registers,
+ NewRegister()));
+ }
+
+ void VisitExpr_(const MatchNode* match_node) {
+    LOG(FATAL) << "translation of match nodes to the VM is "
+               << "currently unsupported" << std::endl;
+ }
+
+ void VisitExpr_(const LetNode* let_node) {
+ DLOG(INFO) << let_node->value << std::endl;
+ this->VisitExpr(let_node->value);
+ DLOG(INFO) << this->last_register << std::endl;
+ var_register_map.insert({let_node->var, this->last_register});
+ this->VisitExpr(let_node->body);
+ }
+
+ void VisitExpr_(const TupleGetItemNode* get_node) {
+ auto get = GetRef<TupleGetItem>(get_node);
+ this->VisitExpr(get->tuple);
+ auto tuple_register = last_register;
+ Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
+ }
+
+ void VisitExpr_(const GlobalVarNode* gvar) {
+ LOG(FATAL) << "Global variables should only appear in the call position";
+ }
+
+ void VisitExpr_(const IfNode* if_node) {
+ this->VisitExpr(if_node->cond);
+
+ size_t cond_register = last_register;
+
+ auto after_cond = this->instructions.size();
+
+ this->Emit(Instruction::If(cond_register, 0, 0));
+ this->VisitExpr(if_node->true_branch);
+
+ size_t true_register = last_register;
+
+ Emit(Instruction::Goto(0));
+
+    // Record the index just past the true branch (including the
+    // Goto), which is where the false branch begins.
+ auto after_true = this->instructions.size();
+
+ this->VisitExpr(if_node->false_branch);
+
+ size_t false_register = last_register;
+
+ // Compute the total number of instructions
+ // after generating false.
+ auto after_false = this->instructions.size();
+
+    // Now we compute the jump targets in order
+    // to properly patch the instructions with the
+    // requisite targets.
+
+ // After we emit the true body, and false body,
+ // we patch up the if instruction, and goto.
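+    //
+    // The resulting layout is:
+    //
+    //   [after_cond]      If cond, true_offset, false_offset
+    //                     ... true branch ...
+    //   [after_true - 1]  Goto (past the false branch)
+    //                     ... false branch ...
+    //   [after_false]     Select cond, true_register, false_register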
+ auto true_offset = 1;
+ auto false_offset = after_true - after_cond;
+ this->instructions[after_cond].true_offset = true_offset;
+ this->instructions[after_cond].false_offset = false_offset;
+
+ // Patch the Goto.
+ this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1;
+
+ Emit(Instruction::Select(cond_register, true_register, false_register, NewRegister()));
+ }
+
+ Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
+ DataType dtype = ttype->dtype;
+ TVMType dltype = Type2TVMType(dtype);
+
+ auto tensor_type = GetRef<TensorType>(ttype);
+ auto it = this->context->const_tensor_shape_map.find(tensor_type);
+ if (it == this->context->const_tensor_shape_map.end()) {
+ DLOG(INFO) << "Can not find constant shape for " << tensor_type;
+ } else {
+ Emit(Instruction::LoadConst(it->second.first, NewRegister()));
+ }
+
+ return Instruction::AllocTensor(last_register, dltype, NewRegister());
+ }
+
+ void EmitInvokePrimitive(const Function& func, std::vector<Index> args_registers,
+ const Type& ret_type) {
+ std::vector<Instruction> allocs;
+ size_t return_num = 0;
+ if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
+ // Allocate space for the return tensor.
+ auto alloc = AllocTensorFromType(ttype);
+ allocs.push_back(alloc);
+ return_num = 1;
+ } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
+ std::vector<Index> fields_registers;
+
+ for (size_t i = 0; i < ttype->fields.size(); ++i) {
+ auto f = ttype->fields[i];
+        auto f_type = f.as<TensorTypeNode>();
+        CHECK(f_type) << "expected a tuple of tensors";
+ allocs.push_back(AllocTensorFromType(f_type));
+ fields_registers.push_back(allocs.back().dst);
+ }
+ return_num = ttype->fields.size();
+ } else {
+ LOG(FATAL) << "Unsupported return value type";
+ }
+
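+    // Emit the output allocations and append their registers to the
+    // argument list: the packed call convention is inputs followed by
+    // pre-allocated outputs.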
+ for (auto& alloc : allocs) {
+ Emit(alloc);
+ args_registers.push_back(alloc.dst);
+ }
+
+ // Next generate the invoke instruction.
+ CHECK(func->IsPrimitive());
+ auto target = Target::create("llvm");
+ auto key = CCacheKeyNode::make(func, target);
+ auto cfunc = engine->Lower(key);
+ // TODO(jroesch): support lowered funcs for multiple targets
+ CHECK_EQ(cfunc->funcs.size(), 1);
+ auto op_index = -1;
+ if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) {
+ op_index = this->context->lowered_funcs.size();
+ this->context->lowered_funcs.push_back(cfunc->funcs[0]);
+ seen_funcs[cfunc->funcs[0]] = op_index;
+ } else {
+ op_index = seen_funcs[cfunc->funcs[0]];
+ }
+
+    // The packed call's arity counts inputs plus outputs: one output for
+    // a tensor result, or one per field for a tuple result.
+    size_t arity = func->params.size() + return_num;
+ Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers));
+ if (return_num > 1) {
+      // The return value is a tuple, so pack the output registers into one.
+ std::vector<Index> fields_registers;
+ for (size_t i = func->params.size(); i < arity; ++i) {
+ fields_registers.push_back(args_registers[i]);
+ }
+ Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister()));
+ }
+ }
+
+ void VisitExpr_(const CallNode* call_node) {
+ std::vector<Index> args_registers;
+
+ for (auto arg : call_node->args) {
+ CHECK(arg.as<VarNode>()) << "found: " << AsText(arg, false) << std::endl << arg;
+ this->VisitExpr(arg);
+ args_registers.push_back(last_register);
+ }
+
+ Expr op = call_node->op;
+
+ if (auto func_node = op.as<FunctionNode>()) {
+ CHECK(func_node->IsPrimitive());
+ EmitInvokePrimitive(GetRef<Function>(func_node), args_registers, call_node->checked_type());
+ } else if (auto global_node = op.as<GlobalVarNode>()) {
+ auto global = GetRef<GlobalVar>(global_node);
+ auto it = this->context->global_map.find(global);
+ CHECK(it != this->context->global_map.end());
+ DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
+ << " with func_index=" << it->second;
+
+ auto func = this->context->module->Lookup(global);
+ if (IsClosure(func)) {
+ auto arity = func->params.size();
+ std::vector<Index> free_var_registers;
+ for (size_t i = 0; i < arity; ++i) {
+ free_var_registers.push_back(var_register_map.at(func->params[i]));
+ }
+ Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister()));
+ } else {
+ Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
+ }
+ } else if (auto constructor_node = op.as<ConstructorNode>()) {
+ auto constructor = GetRef<Constructor>(constructor_node);
+ auto tag = GetConstructorTag(constructor);
+ Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister()));
+ } else if (auto var_node = op.as<VarNode>()) {
+ VisitExpr(GetRef<Var>(var_node));
+ Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister()));
+ } else {
+ LOG(FATAL) << "unsupported case in vm compiler: " << op;
+ }
+ }
+
+ size_t GetConstructorTag(tvm::relay::Constructor constructor) {
+ auto it = this->context->tag_map.find(constructor);
+ if (it != this->context->tag_map.end()) {
+ return it->second;
+ } else {
+ auto tag = this->context->tag_map.size();
+ this->context->tag_map[constructor] = tag;
+ this->context->tag_index_map[tag] = constructor;
+ return tag;
+ }
+ }
+
+ void VisitExpr_(const FunctionNode* func_node) {
+ if (!func_node->IsPrimitive()) {
+ LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
+ << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
+ << "AST: " << GetRef<Function>(func_node);
+ }
+ }
+
+ void CompileClosure(const Function& func) {
+    // We first lay out the inner function's arguments.
+ auto inner_func = Downcast<Function>(func->body);
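+    // Registers are assigned so that the closure's real arguments come
+    // first, followed by the captured free variables:
+    //   [0, |inner_params|)        call arguments
+    //   [|inner_params|, arity)    captured free variables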
+
+ size_t i = 0;
+ for (auto param : inner_func->params) {
+ auto arg_register = NewRegister();
+ CHECK_EQ(i, arg_register);
+ var_register_map.insert({param, arg_register});
+ i++;
+ }
+
+    // Next we assign registers to the captured (free) variables.
+ for (auto param : func->params) {
+ auto arg_register = NewRegister();
+ CHECK_EQ(i, arg_register);
+ var_register_map.insert({param, arg_register});
+ i++;
+ }
+
+ // We will now process the body like normal.
+ this->VisitExpr(inner_func->body);
+ }
+
+ void Compile(const Function& func) {
+ // We need to generate code specially for lifted closures.
+ if (IsClosure(func)) {
+ CompileClosure(func);
+ return;
+ }
+
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ auto arg_register = NewRegister();
+ CHECK_EQ(arg_register, i);
+ var_register_map.insert({func->params[i], arg_register});
+ }
+
+ this->VisitExpr(func->body);
+ }
+};
+
+void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
+ std::vector<PackedFunc>* packed_funcs) {
+ runtime::Module mod;
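+  // Build all lowered functions into a single runtime::Module, then look up
+  // one PackedFunc per function, preserving the order of lowered_funcs.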
+ if (lowered_funcs.size() > 0) {
+ // TODO(@jroesch): we need to read target from build config
+ Target target = Target::create("llvm");
+ if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+ mod = (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target);
+ } else {
+ LOG(FATAL) << "relay.backend.build is not registered";
+ }
+ CHECK(mod.operator->());
+ for (auto lfunc : lowered_funcs) {
+ packed_funcs->push_back(mod.GetFunction(lfunc->name));
+ }
+ }
+}
+
+VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) {
+ DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl;
+ size_t params = func->params.size();
+ VMCompiler compiler(context);
+ compiler.Compile(func);
+ // return the last evaluated expression
+ compiler.instructions.push_back(Instruction::Ret(compiler.last_register));
+
+  // TODO: refactor so that we only check IsClosure once.
+ if (IsClosure(func)) {
+ auto inner_params = Downcast<Function>(func->body)->params.size();
+ return VMFunction(var->name_hint, params + inner_params, compiler.instructions,
+ compiler.registers_num);
+ } else {
+ return VMFunction(var->name_hint, params, compiler.instructions, compiler.registers_num);
+ }
+}
+
+Module OptimizeModule(const Module& mod) {
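+  // NB: each of these passes updates the module in place.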
+ ToANormalForm(mod->entry_func, mod);
+ InlinePrimitives(mod);
+ LambdaLift(mod);
+ return InlinePrimitives(mod);
+}
+
+void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) {
+  // First we populate the global map.
+ size_t global_index = 0;
+ for (auto named_func : mod->functions) {
+ auto gvar = named_func.first;
+ global_map->insert({gvar, global_index++});
+ }
+}
+
+VirtualMachine CompileModule(const Module& mod_ref) {
+ Module mod = mod_ref;
+
+  // Run some optimizations first; this code should
+  // be moved to the pass manager.
+ mod = OptimizeModule(mod);
+
+ VirtualMachine vm;
+
+ VMCompilerContext context;
+ context.module = mod;
+
+ // Populate the global map.
+ //
+ // This maps global variables to a global index
+ // in the VMFunction table.
+ PopulateGlobalMap(&context.global_map, mod);
+
+  // Next we populate the constant map.
+ auto constant_analysis_result = LayoutConstantPool(mod);
+ context.const_map = std::get<0>(constant_analysis_result);
+ context.const_tensor_shape_map = std::get<1>(constant_analysis_result);
+
+  // Next we allocate space for the VM's global state:
+  // the function table and the constant pool.
+ vm.functions.resize(mod->functions.size());
+ vm.constants.resize(context.const_map.size() + context.const_tensor_shape_map.size());
+
+ for (auto pair : context.const_map) {
+ vm.constants[pair.second] = Object::Tensor(pair.first->data);
+ }
+
+ for (auto pair : context.const_tensor_shape_map) {
+ vm.constants[pair.second.first] = Object::Tensor(pair.second.second);
+ }
+
+ for (auto named_func : mod->functions) {
+ auto gvar = named_func.first;
+ auto func = named_func.second;
+ auto vm_func = CompileFunc(&context, gvar, func);
+
+ size_t func_index = context.global_map.at(gvar);
+ CHECK(func_index < vm.functions.size());
+ vm.functions[func_index] = vm_func;
+ }
+
+#ifdef USE_RELAY_DEBUG
+ for (auto vm_func : vm.functions) {
+ std::cout << "Function: " << vm_func.name << std::endl
+ << vm_func << "-------------" << std::endl;
+ }
+#endif // USE_RELAY_DEBUG
+
+ PopulatePackedFuncMap(context.lowered_funcs, &vm.packed_funcs);
+
+ for (auto gv : context.global_map) {
+ vm.global_map_.insert({gv.first->name_hint, gv.second});
+ }
+
+ return vm;
+}
+
+} // namespace vm
+} // namespace relay
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file tvm/relay/backend/vm/inline_primitives.cc
+ * \brief Ensure that primitives only appear in the call position.
+ */
+
+#include <tvm/logging.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <unordered_map>
+#include <vector>
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+struct PrimitiveInliner : ExprMutator {
+ Module module_;
+ std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map;
+
+ explicit PrimitiveInliner(const Module& module) : module_(module) {}
+
+ Expr VisitExpr_(const LetNode* let_node) {
+ var_map.insert({let_node->var, VisitExpr(let_node->value)});
+ return ExprMutator::VisitExpr_(let_node);
+ }
+
+ Expr VisitExpr_(const CallNode* call) {
+ Expr op = call->op;
+ // For now just collapse the chain of variables to see if
+ // they point to a primitive function.
+ const VarNode* var_node;
+
+ // Collapse a chain of let bindings
+ //
+ // let x = fn (..) { .. };
+ // let y = x
+ // let w = y
+ // in w(...)
+ while ((var_node = op.as<VarNode>())) {
+ auto var = GetRef<Var>(var_node);
+ DLOG(INFO) << "Var: " << var << std::endl;
+ auto it = var_map.find(GetRef<Var>(var_node));
+ if (it != var_map.end()) {
+ op = it->second;
+ } else {
+ return ExprMutator::VisitExpr_(call);
+ }
+ }
+
+ if (auto func = op.as<FunctionNode>()) {
+ if (func->IsPrimitive()) {
+ return CallNode::make(GetRef<Function>(func), call->args, call->attrs, call->type_args);
+ }
+ }
+
+ if (auto global = op.as<GlobalVarNode>()) {
+ return CallNode::make(GetRef<GlobalVar>(global), call->args, call->attrs, call->type_args);
+ }
+
+ return ExprMutator::VisitExpr_(call);
+ }
+
+ Expr VisitExpr_(const FunctionNode* func) {
+ if (func->IsPrimitive()) {
+ return GetRef<Function>(func);
+ } else {
+ return ExprMutator::VisitExpr_(func);
+ }
+ }
+
+ Function Inline(const Function& func) {
+ DLOG(INFO) << "Before inlining primitives: " << std::endl
+ << "func= " << AsText(func, false) << std::endl;
+
+ auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
+ func->type_params, func->attrs);
+
+ inlined = Downcast<Function>(DeadCodeElimination(inlined));
+
+ DLOG(INFO) << "After inlining primitives" << std::endl
+ << "after_func= " << AsText(inlined, false) << std::endl;
+ return inlined;
+ }
+};
+
+// TODO(@jroesch): write verifier
+
+/* This pass eliminates let-bound primitive functions introduced by the ANF
+ * transform, inlining them directly into their call sites.
+ *
+ * This makes VM code generation easier, as the call target is always
+ * a primitive function.
+ *
+ * let prim = fn(...) { ... };
+ * prim(...)
+ *
+ * will become:
+ *
+ * (fn(...) { ... })(...)
+ */
+Module InlinePrimitives(const Module& module) {
+ PrimitiveInliner inliner(module);
+
+ tvm::Map<GlobalVar, Function> updates;
+
+ // There is an ordering bug here.
+ for (auto pair : module->functions) {
+ auto global = pair.first;
+ auto func = pair.second;
+ updates.Set(global, inliner.Inline(func));
+ }
+
+ for (auto pair : updates) {
+ module->Add(pair.first, pair.second, true);
+ }
+
+ return module;
+}
+
+} // namespace vm
+} // namespace relay
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file tvm/relay/backend/vm/lambda_lift.cc
+ * \brief Lift all local functions into global functions.
+ */
+
+#include <tvm/logging.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <iostream>
+#include <string>
+#include <utility>
+#include <vector>
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+static const char* kIsClosure = "IsClosure";
+
+inline std::string GenerateName(const Function& func) {
+ size_t hash = StructuralHash()(func);
+ return std::string("lifted_name") + std::to_string(hash);
+}
+
+bool IsClosure(const Function& func) {
+ NodeRef res = FunctionGetAttr(func, kIsClosure);
+ const ir::IntImm* pval = res.as<ir::IntImm>();
+ return pval && pval->value != 0;
+}
+
+Function MarkClosure(const Function& func) {
+ return FunctionSetAttr(func, kIsClosure, tvm::Integer(1));
+}
+
+struct LambdaLifter : ExprMutator {
+ Module module_;
+ std::vector<std::pair<GlobalVar, Function>> lifted_;
+ explicit LambdaLifter(const Module& module) : module_(module) {}
+
+ Expr VisitExpr_(const FunctionNode* func_node) final {
+ auto func = GetRef<Function>(func_node);
+
+ // We should not transform primitive functions.
+ if (func->IsPrimitive()) {
+ return std::move(func);
+ }
+
+ auto free_vars = FreeVars(func);
+ auto free_type_vars = FreeTypeVars(func, module_);
+ auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
+
+    // When performing this optimization there are two
+    // cases.
+    //
+    // In the first case the function has no free variables,
+    // so we can simply lift it into the global environment
+    // without needing to allocate a closure.
+    //
+    // In the second case we must distinguish between
+    // allocating the closure and the code for the closure
+    // itself.
+    //
+    // We represent a closure allocation by lifting the
+    // closure to a global function which takes its
+    // captured arguments and then directly returns
+    // the function representing the closure's code.
+    //
+    // When we later generate code, a call to the "outer"
+    // function marked as a closure emits allocation code
+    // for the closure's environment, while the "inner"
+    // function is used to generate the closure's code.
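+    //
+    // For example (a sketch; the lifted name is hypothetical):
+    //
+    //   fn (x) { x + y }                 // y is free
+    //
+    // lifts to a global marked with the IsClosure attribute:
+    //
+    //   @lifted_name42 = fn (y) { fn (x) { x + y } }
+    //
+    // and the original occurrence is rewritten to @lifted_name42(y),
+    // which at code generation time allocates a closure capturing y.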
+ Function lifted_func;
+ if (free_vars.size() == 0) {
+ lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
+ } else {
+ lifted_func =
+ FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
+
+ lifted_func = MarkClosure(lifted_func);
+ }
+
+ CHECK(lifted_func.defined());
+
+ auto name = GenerateName(lifted_func);
+ auto global = this->module_->GetGlobalVar(name);
+
+ lifted_.push_back({global, lifted_func});
+
+ if (free_vars.size() == 0) {
+ return std::move(global);
+ } else {
+ // If we need to allocate a closure
+ // we pass the variables in its environment
+ // here.
+ Array<Expr> fvs;
+ for (auto fv : free_vars) {
+ fvs.push_back(fv);
+ }
+ return CallNode::make(global, fvs);
+ }
+ }
+
+ Function Lift(const Function& func) {
+ DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl;
+ return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type,
+ func->type_params, func->attrs);
+ }
+};
+
+/* The goal of this pass is to lift out any nested functions into top-level
+ * functions.
+ *
+ * We lift each function into a global that takes its free variables as
+ * parameters and returns the function implementing the closure's code.
+ */
+Module LambdaLift(const Module& module) {
+ LambdaLifter lifter(module);
+
+ tvm::Map<GlobalVar, Function> updates;
+
+ // There is an ordering bug here.
+ for (auto pair : module->functions) {
+ auto global = pair.first;
+ auto func = pair.second;
+ updates.Set(global, lifter.Lift(func));
+ }
+
+  for (auto& pair : lifter.lifted_) {
+    module->Add(pair.first, pair.second);
+  }
+
+ for (auto pair : updates) {
+ module->Add(pair.first, pair.second, true);
+ }
+
+ return module;
+}
+
+} // namespace vm
+} // namespace relay
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/backend/vm/vm.cc
+ * \brief The Relay virtual machine.
+ */
+
+#include <tvm/logging.h>
+#include <tvm/relay/interpreter.h>
+#include <tvm/relay/module.h>
+#include <tvm/relay/pass.h>
+#include <tvm/runtime/vm.h>
+#include <chrono>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace vm {
+
+using tvm::runtime::Object;
+using tvm::runtime::ObjectTag;
+using tvm::runtime::vm::VirtualMachine;
+
+
+VirtualMachine FromModule(const Module& module, const std::vector<TVMContext>& ctxs) {
+ auto vm = CompileModule(module);
+ vm.Init(ctxs);
+ return vm;
+}
+
+Object EvaluateModule(const Module& module, const std::vector<TVMContext>& ctxs,
+ const std::vector<Object>& vm_args) {
+ VirtualMachine vm = FromModule(module, ctxs);
+ // TODO(zhiics): This measurement is for temporary usage. Remove it later. We
+ // need to introduce a better profiling method.
+#if ENABLE_PROFILING
+ DLOG(INFO) << "Entry function is " << module->entry_func << std::endl;
+ auto start = std::chrono::high_resolution_clock::now();
+#endif // ENABLE_PROFILING
+ Object res = vm.Invoke(module->entry_func->name_hint, vm_args);
+#if ENABLE_PROFILING
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+ LOG(INFO) << "Inference time: " << duration << "ms\n";
+#endif // ENABLE_PROFILING
+ return res;
+}
+
+Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) {
+ CHECK(module.defined() && type.defined());
+ switch (obj->tag) {
+ case ObjectTag::kTensor: {
+ CHECK(type.as<TensorTypeNode>()) << "VM internal error: return value must be a tensor";
+ return TensorValueNode::make(ToNDArray(obj));
+ }
+ case ObjectTag::kDatatype: {
+ // const auto* tuple_type
+ // const auto& data_type = obj.AsDatatype();
+
+ // tvm::Array<Value> fields;
+ // for (size_t i = 0; i < data_type->fields.size(); ++i) {
+ // fields.push_back(VMToValue(tag_index_map, data_type->fields[i]));
+ // }
+
+ // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields);
+ LOG(FATAL) << "fix me";
+ }
+ default:
+ LOG(FATAL) << "unsupported return value of type: " << obj->tag;
+ return Value();
+ }
+}
+
+TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) {
+ *ret = Object::Tensor(args[0]);
+});
+
+TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) {
+ std::vector<Object> fields;
+ for (auto i = 0; i < args.size(); i++) {
+ fields.push_back(args[i]);
+ }
+ *ret = Object::Tuple(fields);
+});
+
+template <typename T>
+std::string ToString(const T& t) {
+ std::stringstream s;
+ s << t;
+ return s.str();
+}
+
+TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) {
+ Object obj = args[0];
+ *ret = ToString(obj->tag);
+});
+
+TVM_REGISTER_API("relay._vm._Datatype")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+ int itag = args[0];
+ size_t tag = static_cast<size_t>(itag);
+ std::vector<Object> fields;
+ for (int i = 1; i < args.size(); i++) {
+ fields.push_back(args[i]);
+ }
+
+ *ret = Object::Datatype(tag, fields);
+});
+
+TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) {
+ NodeRef to_compile = args[0];
+ TVMContext ctx;
+ int dev_type = args[1];
+ ctx.device_type = static_cast<DLDeviceType>(dev_type);
+ ctx.device_id = args[2];
+
+ Module module;
+ if (to_compile.as<FunctionNode>()) {
+ Function to_compile = args[0];
+ module = ModuleNode::FromExpr(to_compile);
+ } else if (to_compile.as<ModuleNode>()) {
+ module = args[0];
+ } else {
+ LOG(FATAL) << "expected function or module";
+ }
+
+ auto return_type = module->Lookup(module->entry_func)->ret_type;
+
+ std::vector<Object> vm_args;
+ for (auto i = 3; i < args.size(); i++) {
+ Object obj = args[i];
+ vm_args.push_back(obj);
+ }
+
+ auto result = EvaluateModule(module, {ctx}, vm_args);
+ DLOG(INFO) << "Evaluate VM returning: result=" << result->tag;
+ *ret = VMToValue(module, return_type, result);
+});
+
+} // namespace vm
+} // namespace relay
+} // namespace tvm
F f) {
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
+ if (inputs[0]->shape.size() == 0) {
+ return { topi::identity(inputs[0]) };
+ }
auto axes = param->axis;
if (param->exclude) {
axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
- CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
friend CalcDep;
bool HasLet(const Var& v) {
- return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
+ // TODO(@jroesch): MK fix me
+ return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
}
Expr VisitExpr_(const VarNode* op) final {
}
}
+template<typename T>
+static inline void FreeIf(T* t) {
+  // Register arrays are allocated with new[] (via Duplicate), so they
+  // must be released with delete[].
+  if (t != nullptr) {
+    delete[] t;
+  }
+}
+
+Instruction& Instruction::operator=(const Instruction& instr) {
+  if (this == &instr) {
+    return *this;
+  }
+  this->op = instr.op;
+  this->dst = instr.dst;
+
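+  // Copy the payload for the active opcode; opcodes that own register
+  // arrays release the destination's existing buffer (if any) before
+  // deep-copying the source's.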
+ switch (instr.op) {
+ case Opcode::Move:
+ this->from = instr.from;
+ return *this;
+ case Opcode::Select:
+ this->select_cond = instr.select_cond;
+ this->select_op1 = instr.select_op1;
+ this->select_op2 = instr.select_op2;
+ return *this;
+ case Opcode::Ret:
+ this->result = instr.result;
+ return *this;
+ case Opcode::AllocTensor:
+ this->shape_register = instr.shape_register;
+ this->dtype = instr.dtype;
+ return *this;
+ case Opcode::AllocDatatype:
+ this->constructor_tag = instr.constructor_tag;
+ this->num_fields = instr.num_fields;
+ FreeIf(this->datatype_fields);
+ this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
+ return *this;
+ case Opcode::AllocClosure:
+ this->clo_index = instr.clo_index;
+ this->num_freevar = instr.num_freevar;
+ FreeIf(this->free_vars);
+ this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
+ return *this;
+ case Opcode::InvokePacked:
+ this->packed_index = instr.packed_index;
+ this->arity = instr.arity;
+ this->output_size = instr.output_size;
+ FreeIf(this->packed_args);
+ this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
+ return *this;
+ case Opcode::InvokeClosure:
+ this->closure = instr.closure;
+ this->closure_args_num = instr.closure_args_num;
+ FreeIf(this->closure_args);
+ this->closure_args = Duplicate<RegName>(instr.closure_args, instr.closure_args_num);
+ return *this;
+ case Opcode::Invoke:
+ this->func_index = instr.func_index;
+ this->num_args = instr.num_args;
+ FreeIf(this->invoke_args_registers);
+ this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
+ return *this;
+ case Opcode::If:
+ this->if_cond = instr.if_cond;
+ this->true_offset = instr.true_offset;
+ this->false_offset = instr.false_offset;
+ return *this;
+ case Opcode::LoadConst:
+ this->const_index = instr.const_index;
+ return *this;
+ case Opcode::GetField:
+ this->object = instr.object;
+ this->field_index = instr.field_index;
+ return *this;
+ case Opcode::Goto:
+ this->pc_offset = instr.pc_offset;
+ return *this;
+ default:
+ std::ostringstream out;
+ out << "Invalid instruction " << static_cast<int>(instr.op);
+ throw std::runtime_error(out.str());
+ }
+}
+
Instruction::~Instruction() {
switch (this->op) {
case Opcode::Move:
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmarking Relay VM using models from MXNet."""
+import numpy as np
+
+import tvm
+from tvm.contrib import graph_runtime
+from tvm import relay
+from tvm.relay import testing
+
+
+def benchmark_execution(net,
+ params,
+ measure=False,
+ data_shape=(1, 3, 224, 224),
+ out_shape=(1, 1000),
+ dtype='float32'):
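+    """Run `net` through both the graph runtime and the VM and check that
+    their outputs agree; optionally report graph runtime timings."""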
+ def get_tvm_output(net, data, params, target, ctx, dtype='float32'):
+ with relay.build_config(opt_level=1):
+ graph, lib, params = relay.build(net, target, params=params)
+
+ m = graph_runtime.create(graph, lib, ctx)
+ # set inputs
+ m.set_input("data", data)
+ m.set_input(**params)
+ m.run()
+ out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
+
+ if measure:
+ print("Evaluate graph runtime inference time cost...")
+ ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20)
+ # Measure in millisecond.
+ prof_res = np.array(ftimer().results) * 1000
+ print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
+ (np.mean(prof_res), np.std(prof_res)))
+
+ return out.asnumpy()
+
+ def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'):
+ ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
+ result = ex.evaluate(net)(data, **params)
+ return result.asnumpy().astype(dtype)
+
+ # random input
+ data = np.random.uniform(size=data_shape).astype(dtype)
+ target = "llvm"
+ ctx = tvm.cpu(0)
+
+ tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params,
+ target, ctx, dtype)
+ vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params,
+ target, ctx, dtype)
+ tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+def test_mlp():
+ image_shape = (1, 28, 28)
+ net, params = testing.mlp.get_workload(1)
+ benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10))
+
+
+def test_vgg():
+ for n in [11, 16]:
+ net, params = testing.vgg.get_workload(1, num_layers=n)
+ benchmark_execution(net, params)
+
+
+def test_resnet():
+ for n in [18, 50]:
+ net, params = testing.resnet.get_workload(batch_size=1, num_layers=n)
+ benchmark_execution(net, params, True)
+
+
+def test_squeezenet():
+ for version in ['1.0', '1.1']:
+ net, params = testing.squeezenet.get_workload(version=version)
+ benchmark_execution(net, params)
+
+
+def test_inception_v3():
+ image_shape = (3, 299, 299)
+ net, params = testing.inception_v3.get_workload(image_shape=image_shape)
+ benchmark_execution(net, params, data_shape=image_shape)
+
+
+def test_dqn():
+ image_shape = (4, 84, 84)
+ net, params = testing.dqn.get_workload(
+ batch_size=1, image_shape=image_shape)
+ benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18))
+
+
+def test_dcgan():
+ image_shape = (1, 100)
+ net, params = testing.dcgan.get_workload(batch_size=1)
+ benchmark_execution(net, params, data_shape=image_shape)
+
+
+def test_mobilenet():
+ net, params = testing.mobilenet.get_workload(batch_size=1)
+ benchmark_execution(net, params)
+
+
+def test_densenet():
+ net, params = testing.densenet.get_workload(batch_size=1)
+ benchmark_execution(net, params)
+
+
+if __name__ == '__main__':
+ test_resnet()
+ test_vgg()
+ test_squeezenet()
+ test_mobilenet()
+ test_densenet()
+ # The following networks fail
+ # test_inception_v3()
+ # test_mlp()
+ # test_dqn()
+ # test_dcgan()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from nose.tools import nottest
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay.scope_builder import ScopeBuilder
+from tvm.relay.prelude import Prelude
+
+def veval(f, *args, ctx=tvm.cpu()):
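+    """Evaluate a Relay expression or module on the VM and return the result."""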
+ if isinstance(f, relay.Expr):
+ ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx)
+ if len(args) == 0:
+ return ex.evaluate(f)
+ else:
+ return ex.evaluate(f)(*args)
+ else:
+ assert isinstance(f, relay.Module), "expected expression or module"
+ mod = f
+ ex = relay.create_executor('vm', mod=mod, ctx=ctx)
+ if len(args) == 0:
+ return ex.evaluate(mod[mod.entry_func])
+ else:
+ return ex.evaluate(mod[mod.entry_func])(*args)
+
+def test_split():
+ x = relay.var('x', shape=(12,))
+ y = relay.split(x, 3, axis=0).astuple()
+ z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0)
+ f = relay.Function([x], z)
+
+ x_data = np.random.rand(12,).astype('float32')
+ res = veval(f, x_data)
+ tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+
+def test_id():
+ x = relay.var('x', shape=(10, 10))
+ f = relay.Function([x], x)
+ x_data = np.random.rand(10, 10).astype('float64')
+ res = veval(f, x_data)
+ tvm.testing.assert_allclose(res.asnumpy(), x_data)
+
+def test_op():
+ x = relay.var('x', shape=(10, 10))
+ f = relay.Function([x], x + x)
+ x_data = np.random.rand(10, 10).astype('float32')
+ res = veval(f, x_data)
+ tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
+
+def any(x):
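+    # NB: this reduces with `min`, so the result is true only when every
+    # element is true (logical "all", despite the name).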
+ x = relay.op.nn.batch_flatten(x)
+ return relay.op.min(x, axis=[0, 1])
+
+def test_cond():
+ x = relay.var('x', shape=(10, 10))
+    y = relay.var('y', shape=(10, 10))
+ f = relay.Function([x, y], any(relay.op.equal(x, y)))
+ x_data = np.random.rand(10, 10).astype('float32')
+ y_data = np.random.rand(10, 10).astype('float32')
+
+ # same
+ res = veval(f, x_data, x_data)
+    tvm.testing.assert_allclose(res.asnumpy(), True)
+
+ # diff
+ res = veval(f, x_data, y_data)
+ tvm.testing.assert_allclose(res.asnumpy(), False)
+
+
+def test_simple_if():
+ x = relay.var('x', shape=(10, 10))
+ y = relay.var('y', shape=(10, 10))
+ f = relay.Function([x, y],
+ relay.If(any(relay.op.equal(x, y)), x, y))
+ x_data = np.random.rand(10, 10).astype('float32')
+ y_data = np.random.rand(10, 10).astype('float32')
+
+ # same
+ res = veval(f, x_data, x_data)
+ tvm.testing.assert_allclose(res.asnumpy(), x_data)
+
+ # diff
+ res = veval(f, x_data, y_data)
+ tvm.testing.assert_allclose(res.asnumpy(), y_data)
+
+def test_simple_call():
+ mod = relay.module.Module({})
+ sum_up = relay.GlobalVar('sum_up')
+ i = relay.var('i', shape=[], dtype='int32')
+ sb = ScopeBuilder()
+ sb.ret(i)
+ func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+ mod[sum_up] = func
+ i_data = np.array(0, dtype='int32')
+ iarg = relay.var('i', shape=[], dtype='int32')
+ mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
+ result = veval(mod, i_data)
+ tvm.testing.assert_allclose(result.asnumpy(), i_data)
+
+def test_count_loop():
+ mod = relay.module.Module({})
+ sum_up = relay.GlobalVar('sum_up')
+ i = relay.var('i', shape=[], dtype='int32')
+ sb = ScopeBuilder()
+ with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
+ sb.ret(i)
+ with sb.else_scope():
+ one_less = relay.subtract(i, relay.const(1, dtype='int32'))
+ rec_call = relay.Call(sum_up, [one_less])
+ sb.ret(relay.add(rec_call, i))
+ func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32'))
+ mod[sum_up] = func
+ i_data = np.array(0, dtype='int32')
+ iarg = relay.var('i', shape=[], dtype='int32')
+ mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg))
+ result = veval(mod, i_data)
+ tvm.testing.assert_allclose(result.asnumpy(), i_data)
+
+def test_sum_loop():
+ mod = relay.module.Module({})
+ sum_up = relay.GlobalVar('sum_up')
+ i = relay.var('i', shape=[], dtype='int32')
+ accum = relay.var('accum', shape=[], dtype='int32')
+ sb = ScopeBuilder()
+ with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))):
+ sb.ret(accum)
+ with sb.else_scope():
+ one_less = relay.subtract(i, relay.const(1, 'int32'))
+ new_accum = relay.add(accum, i)
+ sb.ret(relay.Call(sum_up, [one_less, new_accum]))
+ func = relay.Function([i, accum], sb.get())
+ mod[sum_up] = func
+ loop_bound = 0
+ i_data = np.array(loop_bound, dtype='int32')
+ accum_data = np.array(0, dtype='int32')
+ iarg = relay.var('i', shape=[], dtype='int32')
+ aarg = relay.var('accum', shape=[], dtype='int32')
+ mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg))
+ result = veval(mod, i_data, accum_data)
+ tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1)))
+
+def test_tuple_fst():
+ ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
+ tup = relay.var('tup', type_annotation=ttype)
+ f = relay.Function([tup], relay.TupleGetItem(tup, 0))
+    i_data = np.random.rand(1).astype('float32')
+ j_data = np.random.rand(10).astype('float32')
+ result = veval(f, (i_data, j_data))
+ tvm.testing.assert_allclose(result.asnumpy(), i_data)
+
+def test_tuple_second():
+ ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))])
+ tup = relay.var('tup', type_annotation=ttype)
+ f = relay.Function([tup], relay.TupleGetItem(tup, 1))
+    i_data = np.random.rand(1).astype('float32')
+ j_data = np.random.rand(10).astype('float32')
+ result = veval(f, (i_data, j_data))
+ tvm.testing.assert_allclose(result.asnumpy(), j_data)
+
+@nottest
+def test_list_constructor():
+ # TODO(wweic): implement pattern match to support this test
+ def to_list(o):
+ if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
+ return [o.data.asnumpy().tolist()]
+ if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
+ result = []
+ for f in o.fields:
+ result.extend(to_list(f))
+ return result
+
+ mod = relay.Module()
+ p = Prelude(mod)
+
+ nil = p.nil
+ cons = p.cons
+ l = p.l
+
+ one2 = cons(relay.const(1), nil())
+ one3 = cons(relay.const(2), one2)
+ one4 = cons(relay.const(3), one3)
+ f = relay.Function([], one4)
+
+ mod[mod.entry_func] = f
+
+ result = veval(mod)()
+ obj = to_list(result)
+ tvm.testing.assert_allclose(obj, np.array([3,2,1]))
+
+def test_let_tensor():
+ sb = relay.ScopeBuilder()
+ shape = (1,)
+ x = relay.var('x', shape=shape, dtype='float32')
+ x1 = relay.var('x1', shape=shape, dtype='float32')
+
+ x1 = sb.let(x1, x)
+ xplusone = x1 + relay.const(42.0, 'float32')
+ sb.ret(xplusone)
+ body = sb.get()
+
+ f = relay.Function([x], body)
+
+ x_data = np.random.rand(*shape).astype('float32')
+ result = veval(f, x_data)
+ tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
+
+def test_let_scalar():
+ sb = relay.ScopeBuilder()
+
+ x = relay.var('x', 'float32')
+ x1 = sb.let('x1', x)
+ xplusone = x1 + relay.const(42.0, 'float32')
+ sb.ret(xplusone)
+ body = sb.get()
+
+ f = relay.Function([x], body)
+
+ x_data = np.array(np.random.rand()).astype('float32')
+ result = veval(f, x_data)
+ tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
+
+def test_closure():
+ x = relay.var('x', shape=())
+ y = relay.var('y', shape=())
+ f = relay.Function([x], x + y)
+ ff = relay.Function([y], f)
+ clo = ff(relay.const(1.0))
+ main = clo(relay.const(2.0))
+ res = veval(main)
+ tvm.testing.assert_allclose(res.asnumpy(), 3.0)
+
+if __name__ == "__main__":
+ test_id()
+ test_op()
+ test_cond()
+ test_simple_if()
+ test_simple_call()
+ test_count_loop()
+ test_sum_loop()
+ test_tuple_fst()
+ test_tuple_second()
+ test_let_scalar()
+ test_let_tensor()
+ # TODO(@jroesch): restore when match is supported
+ # test_list_constructor()
+ test_closure()