From 6a4d71ff40915611bd42b62994992b879e6be610 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 11 May 2019 18:08:13 -0400 Subject: [PATCH] [Relay][Runtime] Add VM compiler. (#3139) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit * Implement the VM compiler * Fix issues * Fix ASF headers * Fix test issue * Apply typo fixes. * Update src/relay/backend/vm/compiler.cc Co-Authored-By: 雾雨魔理沙 * Refactor compiler * Fix * Fix * Fix in benchmark * Fix * Address comments --- include/tvm/relay/pass.h | 13 + include/tvm/runtime/vm.h | 2 +- src/relay/backend/vm/compiler.cc | 616 ++++++++++++++++++++++++ src/relay/backend/vm/inline_primitives.cc | 146 ++++++ src/relay/backend/vm/lambda_lift.cc | 166 +++++++ src/relay/backend/vm/vm.cc | 159 ++++++ src/relay/op/tensor/reduce.cc | 4 +- src/relay/pass/dead_code.cc | 7 +- src/runtime/vm/vm.cc | 80 +++ tests/python/relay/benchmarking/benchmark_vm.py | 133 +++++ tests/python/relay/test_vm.py | 264 ++++++++++ 11 files changed, 1585 insertions(+), 5 deletions(-) create mode 100644 src/relay/backend/vm/compiler.cc create mode 100644 src/relay/backend/vm/inline_primitives.cc create mode 100644 src/relay/backend/vm/lambda_lift.cc create mode 100644 src/relay/backend/vm/vm.cc create mode 100644 tests/python/relay/benchmarking/benchmark_vm.py create mode 100644 tests/python/relay/test_vm.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 43831fc..3106792 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -65,6 +65,7 @@ #include #include #include +#include #include #include @@ -593,6 +594,18 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); * As a side effect, code size will explode. */ Expr PartialEval(const Expr& e); + +namespace vm { + +/*! \brief Compile a module, and construct the virtual machine. + * + * \param mod The module to compile. + * \return The constructed virtual machine. + */ +runtime::vm::VirtualMachine CompileModule(const Module& mod); + +} // namespace vm + } // namespace relay } // namespace tvm diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 0a0a4de..8911ad4 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -265,7 +265,7 @@ struct Instruction { Instruction(); Instruction(const Instruction& instr); - Instruction& operator=(const Instruction& instr) = delete; + Instruction& operator=(const Instruction& instr); ~Instruction(); friend std::ostream& operator<<(std::ostream& os, const Instruction&); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc new file mode 100644 index 0000000..97f03c6 --- /dev/null +++ b/src/relay/backend/vm/compiler.cc @@ -0,0 +1,616 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/compiler.cc + * \brief A compiler from relay::Module to the VM byte code. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../../../runtime/vm/naive_allocator.h" +#include "../../backend/compile_engine.h" + +namespace tvm { +namespace relay { +namespace vm { + +using namespace tvm::runtime; +using namespace tvm::runtime::vm; + +// (@jroesch): VM passes, eventually declare as passes. +bool IsClosure(const Function& func); +Module LambdaLift(const Module& module); +Module InlinePrimitives(const Module& module); + +template +using NodeMap = std::unordered_map; +using TagMap = NodeMap; +using TagNameMap = std::unordered_map; +using GlobalMap = NodeMap; +using ConstMap = NodeMap; +using ConstTensorShapeMap = NodeMap>; + +struct VMCompilerContext { + // The module context for the compilation + Module module; + // Error reporter + ErrorReporter err_reporter; + // Map from a unique integer to ADT constructor tag + TagNameMap tag_index_map; + // Map from ADT constructor tag to a unique integer + TagMap tag_map; + // Map from global var to a unique integer + GlobalMap global_map; + // Map from Const object to its index in const pool + ConstMap const_map; + // Map from Const tensor shape to its index in const pool + ConstTensorShapeMap const_tensor_shape_map; + // List of lowered functions + std::vector lowered_funcs; +}; + +// Compute the constant pool, i.e a mapping from Constant node to constant index. +struct ConstantPool : ExprVisitor { + std::set visited; + Module module; + ConstMap const_map; + ConstTensorShapeMap const_tensor_shape_map; + + size_t index; + + explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {} + + void VisitExpr_(const GlobalVarNode* var_node) { + auto gvar = GetRef(var_node); + if (visited.find(gvar) == visited.end()) { + visited.insert(gvar); + this->VisitExpr(this->module->Lookup(gvar)); + } + } + + void AddConstantTensorShape(TensorType expr, NDArray value) { + auto it = this->const_tensor_shape_map.find(expr); + if (it == this->const_tensor_shape_map.end()) { + this->const_tensor_shape_map.insert({expr, std::make_pair(index++, value)}); + } + } + + void VisitExpr_(const ConstantNode* const_node) { + auto konst = GetRef(const_node); + auto it = this->const_map.find(konst); + if (it == this->const_map.end()) { + this->const_map.insert({konst, index++}); + } + } + + NDArray GetTensorConstant(const TensorTypeNode* ttype) { + std::vector shapes; + for (auto sh : ttype->shape) { + shapes.push_back(Downcast(sh)->value); + } + int64_t s = shapes.size(); + DLContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + auto shape_tensor = NDArray::Empty({s}, Type2TVMType(Int(64)), cpu_ctx); + int64_t* dims = static_cast(shape_tensor->data); + for (size_t i = 0; i < shapes.size(); ++i) { + dims[i] = shapes[i]; + } + return shape_tensor; + } + + void VisitExpr_(const CallNode* call_node) { + for (auto arg : call_node->args) { + this->VisitExpr(arg); + } + + Expr op = call_node->op; + auto func_node = op.as(); + if (func_node) { + auto ret_type = call_node->checked_type(); + if (const TensorTypeNode* ttype = ret_type.as()) { + auto shape = GetTensorConstant(ttype); + auto tensor_type = GetRef(ttype); + AddConstantTensorShape(tensor_type, shape); + } else if (const TupleTypeNode* ttype = ret_type.as()) { + for (size_t i = 0; i < ttype->fields.size(); ++i) { + auto f = ttype->fields[i]; + auto f_type = 
f.as(); + auto shape = GetTensorConstant(f_type); + auto tensor_type = GetRef(f_type); + AddConstantTensorShape(tensor_type, shape); + } + } + } + } +}; + +std::tuple LayoutConstantPool(const Module& module) { + auto cp = ConstantPool(module); + for (auto& func : module->functions) { + cp.VisitExpr(func.first); + } + return std::make_tuple(cp.const_map, cp.const_tensor_shape_map); +} + +void InstructionPrint(std::ostream& os, const Instruction& instr); + +struct VMCompiler : ExprFunctor { + /*! \brief Store the expression a variable points to. */ + std::unordered_map expr_map; + + std::vector instructions; + + // var -> register num + std::unordered_map var_register_map; + + size_t last_register; + + // Total number of virtual registers allocated + size_t registers_num; + CompileEngine engine; + + /*! \brief The functions that have been lowered. */ + std::unordered_map seen_funcs; + + /*! \brief Global shared meta data */ + VMCompilerContext* context; + + VMCompiler(VMCompilerContext* context) + : instructions(), + var_register_map(), + last_register(0), + registers_num(0), + engine(CompileEngine::Global()), + context(context) + {} + + size_t NewRegister() { return registers_num++; } + + inline void Emit(const Instruction& instr) { + DLOG(INFO) << "VMCompiler::Emit: instr=" << instr; + CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; + switch (instr.op) { + case Opcode::AllocDatatype: + case Opcode::AllocTensor: + case Opcode::GetField: + case Opcode::LoadConst: + case Opcode::Select: + case Opcode::Invoke: + case Opcode::AllocClosure: + case Opcode::Move: + case Opcode::InvokeClosure: + last_register = instr.dst; + break; + case Opcode::InvokePacked: + last_register = instr.packed_args[instr.arity - 1]; + break; + case Opcode::If: + case Opcode::Ret: + case Opcode::Goto: + break; + } + instructions.push_back(instr); + } + + void VisitExpr_(const ConstantNode* const_node) { + auto rconst = GetRef(const_node); + auto it = this->context->const_map.find(rconst); + CHECK(it != this->context->const_map.end()); + Emit(Instruction::LoadConst(it->second, NewRegister())); + } + + void VisitExpr_(const VarNode* var_node) { + auto var = GetRef(var_node); + auto reg_it = this->var_register_map.find(var); + CHECK(reg_it != this->var_register_map.end()); + last_register = reg_it->second; + } + + void VisitExpr_(const TupleNode* tuple_node) { + auto tuple = GetRef(tuple_node); + std::vector fields_registers; + + for (auto& field : tuple->fields) { + this->VisitExpr(field); + fields_registers.push_back(last_register); + } + + // TODO(@jroesch): use correct tag + Emit(Instruction::AllocDatatype( + 0, + tuple->fields.size(), + fields_registers, + NewRegister())); + } + + void VisitExpr_(const MatchNode* match_node) { + auto match = GetRef(match_node); + LOG(FATAL) << "translation of match nodes to the VM is" + << "currently unsupported" << std::endl; + } + + void VisitExpr_(const LetNode* let_node) { + DLOG(INFO) << let_node->value << std::endl; + this->VisitExpr(let_node->value); + DLOG(INFO) << this->last_register << std::endl; + var_register_map.insert({let_node->var, this->last_register}); + this->VisitExpr(let_node->body); + } + + void VisitExpr_(const TupleGetItemNode* get_node) { + auto get = GetRef(get_node); + this->VisitExpr(get->tuple); + auto tuple_register = last_register; + Emit(Instruction::GetField(tuple_register, get->index, NewRegister())); + } + + void VisitExpr_(const GlobalVarNode* gvar) { + LOG(FATAL) << "Global variables should only appear in the call position"; + } + 
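+  // Note on the conditional translation below: an `If` instruction is first
+  // emitted with placeholder branch offsets, the true branch is compiled and
+  // followed by a placeholder `Goto` that will eventually skip over the false
+  // branch, and then the false branch is compiled. Once both branch lengths
+  // are known, the `If` and `Goto` offsets are back-patched, and a final
+  // `Select` merges the two branch results into a single fresh register.
+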
+ void VisitExpr_(const IfNode* if_node) { + this->VisitExpr(if_node->cond); + + size_t cond_register = last_register; + + auto after_cond = this->instructions.size(); + + this->Emit(Instruction::If(cond_register, 0, 0)); + this->VisitExpr(if_node->true_branch); + + size_t true_register = last_register; + + Emit(Instruction::Goto(0)); + + // Finally store how many instructions there are in the + // true branch. + auto after_true = this->instructions.size(); + + this->VisitExpr(if_node->false_branch); + + size_t false_register = last_register; + + // Compute the total number of instructions + // after generating false. + auto after_false = this->instructions.size(); + + // Now we will compute the jump targets in order + // to properly patch the instruction with the + // the requiste targets. + + // After we emit the true body, and false body, + // we patch up the if instruction, and goto. + auto true_offset = 1; + auto false_offset = after_true - after_cond; + this->instructions[after_cond].true_offset = true_offset; + this->instructions[after_cond].false_offset = false_offset; + + // Patch the Goto. + this->instructions[after_true - 1].pc_offset = (after_false - after_true) + 1; + + Emit(Instruction::Select(cond_register, true_register, false_register, NewRegister())); + } + + Instruction AllocTensorFromType(const TensorTypeNode* ttype) { + DataType dtype = ttype->dtype; + TVMType dltype = Type2TVMType(dtype); + + auto tensor_type = GetRef(ttype); + auto it = this->context->const_tensor_shape_map.find(tensor_type); + if (it == this->context->const_tensor_shape_map.end()) { + DLOG(INFO) << "Can not find constant shape for " << tensor_type; + } else { + Emit(Instruction::LoadConst(it->second.first, NewRegister())); + } + + return Instruction::AllocTensor(last_register, dltype, NewRegister()); + } + + void EmitInvokePrimitive(const Function& func, std::vector args_registers, + const Type& ret_type) { + std::vector allocs; + size_t return_num = 0; + if (const TensorTypeNode* ttype = ret_type.as()) { + // Allocate space for the return tensor. + auto alloc = AllocTensorFromType(ttype); + allocs.push_back(alloc); + return_num = 1; + } else if (const TupleTypeNode* ttype = ret_type.as()) { + std::vector fields_registers; + + for (size_t i = 0; i < ttype->fields.size(); ++i) { + auto f = ttype->fields[i]; + auto f_type = f.as(); + allocs.push_back(AllocTensorFromType(f_type)); + fields_registers.push_back(allocs.back().dst); + } + return_num = ttype->fields.size(); + } else { + LOG(FATAL) << "Unsupported return value type"; + } + + for (auto& alloc : allocs) { + Emit(alloc); + args_registers.push_back(alloc.dst); + } + + // Next generate the invoke instruction. 
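+    // The primitive is lowered through the compile engine (the target is
+    // currently hard-coded to "llvm") and deduplicated via `seen_funcs`, so
+    // each distinct lowered function occupies exactly one slot in the lowered
+    // function table. The packed call uses a destination-passing convention:
+    // the output registers allocated above are appended to the input
+    // registers, hence arity = number of parameters + number of outputs.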
+ CHECK(func->IsPrimitive()); + auto target = Target::create("llvm"); + auto key = CCacheKeyNode::make(func, target); + auto cfunc = engine->Lower(key); + // TODO(jroesch): support lowered funcs for multiple targets + CHECK_EQ(cfunc->funcs.size(), 1); + auto op_index = -1; + if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) { + op_index = this->context->lowered_funcs.size(); + this->context->lowered_funcs.push_back(cfunc->funcs[0]); + seen_funcs[cfunc->funcs[0]] = op_index; + } else { + op_index = seen_funcs[cfunc->funcs[0]]; + } + + // If Tensor, 1 + // If Tuple, size of tuple + size_t arity = func->params.size() + return_num; + Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers)); + if (return_num > 1) { + // return value is a tuple, we need to create a tuple + std::vector fields_registers; + for (size_t i = func->params.size(); i < arity; ++i) { + fields_registers.push_back(args_registers[i]); + } + Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister())); + } + } + + void VisitExpr_(const CallNode* call_node) { + std::vector args_registers; + + for (auto arg : call_node->args) { + CHECK(arg.as()) << "found: " << AsText(arg, false) << std::endl << arg; + this->VisitExpr(arg); + args_registers.push_back(last_register); + } + + Expr op = call_node->op; + + if (auto func_node = op.as()) { + CHECK(func_node->IsPrimitive()); + EmitInvokePrimitive(GetRef(func_node), args_registers, call_node->checked_type()); + } else if (auto global_node = op.as()) { + auto global = GetRef(global_node); + auto it = this->context->global_map.find(global); + CHECK(it != this->context->global_map.end()); + DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint + << " with func_index=" << it->second; + + auto func = this->context->module->Lookup(global); + if (IsClosure(func)) { + auto arity = func->params.size(); + std::vector free_var_registers; + for (size_t i = 0; i < arity; ++i) { + free_var_registers.push_back(var_register_map.at(func->params[i])); + } + Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister())); + } else { + Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); + } + } else if (auto constructor_node = op.as()) { + auto constructor = GetRef(constructor_node); + auto tag = GetConstructorTag(constructor); + Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister())); + } else if (auto var_node = op.as()) { + VisitExpr(GetRef(var_node)); + Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister())); + } else { + LOG(FATAL) << "unsupported case in vm compiler: " << op; + } + } + + size_t GetConstructorTag(tvm::relay::Constructor constructor) { + auto it = this->context->tag_map.find(constructor); + if (it != this->context->tag_map.end()) { + return it->second; + } else { + auto tag = this->context->tag_map.size(); + this->context->tag_map[constructor] = tag; + this->context->tag_index_map[tag] = constructor; + return tag; + } + } + + void VisitExpr_(const FunctionNode* func_node) { + if (!func_node->IsPrimitive()) { + LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl + << "Program: " << AsText(GetRef(func_node), false) << std::endl + << "AST: " << GetRef(func_node); + } + } + + void CompileClosure(const Function& func) { + // We first layout the function arguments. 
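+    // For a lifted closure, registers are assigned to the inner function's
+    // parameters first and to the captured variables (the parameters of the
+    // lifted outer function) afterwards; CompileFunc later sizes the
+    // resulting VMFunction's parameter count to the sum of the two.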
+ auto inner_func = Downcast(func->body); + + size_t i = 0; + for (auto param : inner_func->params) { + auto arg_register = NewRegister(); + CHECK_EQ(i, arg_register); + var_register_map.insert({param, arg_register}); + i++; + } + + // We then assign register num to the free variables + for (auto param : func->params) { + auto arg_register = NewRegister(); + CHECK_EQ(i, arg_register); + var_register_map.insert({param, arg_register}); + i++; + } + + // We will now process the body like normal. + this->VisitExpr(inner_func->body); + } + + void Compile(const Function& func) { + // We need to generate code specially for lifted closures. + if (IsClosure(func)) { + CompileClosure(func); + return; + } + + for (size_t i = 0; i < func->params.size(); ++i) { + auto arg_register = NewRegister(); + CHECK_EQ(arg_register, i); + var_register_map.insert({func->params[i], arg_register}); + } + + this->VisitExpr(func->body); + } +}; + +void PopulatePackedFuncMap(const std::vector& lowered_funcs, + std::vector* packed_funcs) { + runtime::Module mod; + if (lowered_funcs.size() > 0) { + // TODO(@jroesch): we need to read target from build config + Target target = Target::create("llvm"); + if (const auto* f = runtime::Registry::Get("relay.backend.build")) { + mod = (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), target); + } else { + LOG(FATAL) << "relay.backend.build is not registered"; + } + CHECK(mod.operator->()); + for (auto lfunc : lowered_funcs) { + packed_funcs->push_back(mod.GetFunction(lfunc->name)); + } + } +} + +VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { + DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl; + size_t params = func->params.size(); + VMCompiler compiler(context); + compiler.Compile(func); + // return the last evaluated expression + compiler.instructions.push_back(Instruction::Ret(compiler.last_register)); + + // Would like to refactor this so we only check if closure once. + if (IsClosure(func)) { + auto inner_params = Downcast(func->body)->params.size(); + return VMFunction(var->name_hint, params + inner_params, compiler.instructions, + compiler.registers_num); + } else { + return VMFunction(var->name_hint, params, compiler.instructions, compiler.registers_num); + } +} + +Module OptimizeModule(const Module& mod) { + ToANormalForm(mod->entry_func, mod); + InlinePrimitives(mod); + LambdaLift(mod); + return InlinePrimitives(mod); +} + +void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) { + // First we populate global map. + size_t global_index = 0; + for (auto named_func : mod->functions) { + auto gvar = named_func.first; + global_map->insert({gvar, global_index++}); + } +} + +VirtualMachine CompileModule(const Module& mod_ref) { + Module mod = mod_ref; + + // Run some optimizations first, this code should + // be moved to pass manager. + mod = OptimizeModule(mod); + + VirtualMachine vm; + + VMCompilerContext context; + context.module = mod; + + // Populate the global map. + // + // This maps global variables to a global index + // in the VMFunction table. + PopulateGlobalMap(&context.global_map, mod); + + // Next we populate constant map. + auto constant_analysis_result = LayoutConstantPool(mod); + context.const_map = std::get<0>(constant_analysis_result); + context.const_tensor_shape_map = std::get<1>(constant_analysis_result); + + // Next we get ready by allocating space for + // the global state. 
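+  // The constant pool holds both the user-provided constants and the shape
+  // tensors recorded by the constant-pool analysis (loaded when emitting
+  // AllocTensor), so it is sized as the sum of the two maps; the function
+  // table is indexed by the global map populated above.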
+ vm.functions.resize(mod->functions.size()); + vm.constants.resize(context.const_map.size() + context.const_tensor_shape_map.size()); + + for (auto pair : context.const_map) { + vm.constants[pair.second] = Object::Tensor(pair.first->data); + } + + for (auto pair : context.const_tensor_shape_map) { + vm.constants[pair.second.first] = Object::Tensor(pair.second.second); + } + + for (auto named_func : mod->functions) { + auto gvar = named_func.first; + auto func = named_func.second; + auto vm_func = CompileFunc(&context, gvar, func); + + size_t func_index = context.global_map.at(gvar); + CHECK(func_index < vm.functions.size()); + vm.functions[func_index] = vm_func; + } + +#ifdef USE_RELAY_DEBUG + for (auto vm_func : vm.functions) { + std::cout << "Function: " << vm_func.name << std::endl + << vm_func << "-------------" << std::endl; + } +#endif // USE_RELAY_DEBUG + + PopulatePackedFuncMap(context.lowered_funcs, &vm.packed_funcs); + + for (auto gv : context.global_map) { + vm.global_map_.insert({gv.first->name_hint, gv.second}); + } + + return vm; +} + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc new file mode 100644 index 0000000..b033a37 --- /dev/null +++ b/src/relay/backend/vm/inline_primitives.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/backend/vm/inline_primitives.cc + * \brief Ensure that primitives only appear in the call position. + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm::runtime; + +namespace tvm { +namespace relay { +namespace vm { + +struct PrimitiveInliner : ExprMutator { + Module module_; + std::unordered_map var_map; + + explicit PrimitiveInliner(const Module& module) : module_(module) {} + + Expr VisitExpr_(const LetNode* let_node) { + var_map.insert({let_node->var, VisitExpr(let_node->value)}); + return ExprMutator::VisitExpr_(let_node); + } + + Expr VisitExpr_(const CallNode* call) { + Expr op = call->op; + // For now just collapse the chain of variables to see if + // they point to a primitive function. + const VarNode* var_node; + + // Collapse a chain of let bindings + // + // let x = fn (..) { .. }; + // let y = x + // let w = y + // in w(...) 
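+    //
+    // Here `w` resolves, through the collected let bindings, to the function
+    // bound to `x`; if that function is a primitive the call is rewritten to
+    // apply the function node directly, so later code generation only ever
+    // sees primitives in the call position.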
+ while ((var_node = op.as())) { + auto var = GetRef(var_node); + DLOG(INFO) << "Var: " << var << std::endl; + auto it = var_map.find(GetRef(var_node)); + if (it != var_map.end()) { + op = it->second; + } else { + return ExprMutator::VisitExpr_(call); + } + } + + if (auto func = op.as()) { + if (func->IsPrimitive()) { + return CallNode::make(GetRef(func), call->args, call->attrs, call->type_args); + } + } + + if (auto global = op.as()) { + return CallNode::make(GetRef(global), call->args, call->attrs, call->type_args); + } + + return ExprMutator::VisitExpr_(call); + } + + Expr VisitExpr_(const FunctionNode* func) { + if (func->IsPrimitive()) { + return GetRef(func); + } else { + return ExprMutator::VisitExpr_(func); + } + } + + Function Inline(const Function& func) { + DLOG(INFO) << "Before inlining primitives: " << std::endl + << "func= " << AsText(func, false) << std::endl; + + auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, + func->type_params, func->attrs); + + inlined = Downcast(DeadCodeElimination(inlined)); + + DLOG(INFO) << "After inlining primitives" << std::endl + << "after_func= " << AsText(inlined, false) << std::endl; + return inlined; + } +}; + +// TODO(@jroesch): write verifier + +/* This pass will eliminate primitives which have been lifted by the ANF + * transform inlining them directly into call sites. + * + * This makes VM related code generation easier as the call target is always + * a primitive function. + * + * let prim = fn(...) { ... }; + * prim(...) + * + * will become: + * + * (fn(...) { ... })(...) + */ +Module InlinePrimitives(const Module& module) { + PrimitiveInliner inliner(module); + + tvm::Map updates; + + // There is an ordering bug here. + for (auto pair : module->functions) { + auto global = pair.first; + auto func = pair.second; + updates.Set(global, inliner.Inline(func)); + } + + for (auto pair : updates) { + module->Add(pair.first, pair.second, true); + } + + return module; +} + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc new file mode 100644 index 0000000..13d8112 --- /dev/null +++ b/src/relay/backend/vm/lambda_lift.cc @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/backend/vm/lambda_lift.cc + * \brief Lift all local functions into global functions. 
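+ *
+ * A function with no free variables is lifted into a global directly; a
+ * function that captures free variables is lifted into a global taking the
+ * captured values as explicit parameters, marked with the `IsClosure`
+ * attribute, and the original expression is replaced by a call to that
+ * global applied to the captured values.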
+ */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm::runtime; + +namespace tvm { +namespace relay { +namespace vm { + +static const char* kIsClosure = "IsClosure"; + +inline std::string GenerateName(const Function& func) { + size_t hash = StructuralHash()(func); + return std::string("lifted_name") + std::to_string(hash); +} + +bool IsClosure(const Function& func) { + NodeRef res = FunctionGetAttr(func, kIsClosure); + const ir::IntImm* pval = res.as(); + return pval && pval->value != 0; +} + +Function MarkClosure(const Function& func) { + return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); +} + +struct LambdaLifter : ExprMutator { + Module module_; + std::vector> lifted_; + explicit LambdaLifter(const Module& module) : module_(module) {} + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + + // We should not transform primitive functions. + if (func->IsPrimitive()) { + return std::move(func); + } + + auto free_vars = FreeVars(func); + auto free_type_vars = FreeTypeVars(func, module_); + auto body = Downcast(ExprMutator::VisitExpr_(func_node)); + + // When performing this optimization there are two + // cases. + // + // The first case in which we have no free variables + // we can just lift the function into the global + // environment without needing to allocate a closure. + // + // + // The second case requires that we generate a special + // function with makes a distinction between allocating + // a closure, and then the code for the closure. + // + // We represent a closure allocation by lifting the + // closure to a global function which takes its + // captured arguments and then directly returns + // the function representing the closure's code. + // + // When we generate code later on a call to the "outer" + // function marked as a closure is used to emit allocation + // code for the closure's environment. + // + // The "inner" function is should be used to generate the + // code for the closure. + Function lifted_func; + if (free_vars.size() == 0) { + lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars); + } else { + lifted_func = + FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars); + + lifted_func = MarkClosure(lifted_func); + } + + CHECK(lifted_func.defined()); + + auto name = GenerateName(lifted_func); + auto global = this->module_->GetGlobalVar(name); + + lifted_.push_back({global, lifted_func}); + + if (free_vars.size() == 0) { + return std::move(global); + } else { + // If we need to allocate a closure + // we pass the variables in its environment + // here. + Array fvs; + for (auto fv : free_vars) { + fvs.push_back(fv); + } + return CallNode::make(global, fvs); + } + } + + Function Lift(const Function& func) { + DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl; + return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, + func->type_params, func->attrs); + } +}; + +/* The goal of this pass is to lift out any nested functions into top-level + * functions. + * + * We will lift the functions out into globals which take the set of the free vars + * and then return a function whcih has b + */ +Module LambdaLift(const Module& module) { + LambdaLifter lifter(module); + + tvm::Map updates; + + // There is an ordering bug here. 
+ for (auto pair : module->functions) { + auto global = pair.first; + auto func = pair.second; + updates.Set(global, lifter.Lift(func)); + } + + for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) { + module->Add(i->first, i->second); + } + + for (auto pair : updates) { + module->Add(pair.first, pair.second, true); + } + + return module; +} + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc new file mode 100644 index 0000000..34d067b --- /dev/null +++ b/src/relay/backend/vm/vm.cc @@ -0,0 +1,159 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/vm.cc + * \brief The Relay virtual machine. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace vm { + +using tvm::runtime::Object; +using tvm::runtime::ObjectTag; +using tvm::runtime::vm::VirtualMachine; + + +VirtualMachine FromModule(const Module& module, const std::vector& ctxs) { + auto vm = CompileModule(module); + vm.Init(ctxs); + return vm; +} + +Object EvaluateModule(const Module& module, const std::vector ctxs, + const std::vector& vm_args) { + VirtualMachine vm = FromModule(module, ctxs); + // TODO(zhiics): This measurement is for temporary usage. Remove it later. We + // need to introduce a better profiling method. 
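+  // When ENABLE_PROFILING is defined, the block below times a single
+  // vm.Invoke of the module's entry function using std::chrono and logs the
+  // wall-clock duration in milliseconds.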
+#if ENABLE_PROFILING + DLOG(INFO) << "Entry function is " << module->entry_func << std::endl; + auto start = std::chrono::high_resolution_clock::now(); +#endif // ENABLE_PROFILING + Object res = vm.Invoke(module->entry_func->name_hint, vm_args); +#if ENABLE_PROFILING + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start).count(); + LOG(INFO) << "Inference time: " << duration << "ms\n"; +#endif // ENABLE_PROFILING + return res; +} + +Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) { + CHECK(module.defined() && type.defined()); + switch (obj->tag) { + case ObjectTag::kTensor: { + CHECK(type.as()) << "VM internal error: return value must be a tensor"; + return TensorValueNode::make(ToNDArray(obj)); + } + case ObjectTag::kDatatype: { + // const auto* tuple_type + // const auto& data_type = obj.AsDatatype(); + + // tvm::Array fields; + // for (size_t i = 0; i < data_type->fields.size(); ++i) { + // fields.push_back(VMToValue(tag_index_map, data_type->fields[i])); + // } + + // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields); + LOG(FATAL) << "fix me"; + } + default: + LOG(FATAL) << "unsupported return value of type: " << obj->tag; + return Value(); + } +} + +TVM_REGISTER_API("relay._vm._Tensor").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Object::Tensor(args[0]); +}); + +TVM_REGISTER_API("relay._vm._Tuple").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector fields; + for (auto i = 0; i < args.size(); i++) { + fields.push_back(args[i]); + } + *ret = Object::Tuple(fields); +}); + +template +std::string ToString(const T& t) { + std::stringstream s; + s << t; + return s.str(); +} + +TVM_REGISTER_API("relay._vm._ObjectTag").set_body([](TVMArgs args, TVMRetValue* ret) { + Object obj = args[0]; + *ret = ToString(obj->tag); +}); + +TVM_REGISTER_API("relay._vm._Datatype") +.set_body([](TVMArgs args, TVMRetValue* ret) { + int itag = args[0]; + size_t tag = static_cast(itag); + std::vector fields; + for (int i = 1; i < args.size(); i++) { + fields.push_back(args[i]); + } + + *ret = Object::Datatype(tag, fields); +}); + +TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef to_compile = args[0]; + TVMContext ctx; + int dev_type = args[1]; + ctx.device_type = static_cast(dev_type); + ctx.device_id = args[2]; + + Module module; + if (to_compile.as()) { + Function to_compile = args[0]; + module = ModuleNode::FromExpr(to_compile); + } else if (to_compile.as()) { + module = args[0]; + } else { + LOG(FATAL) << "expected function or module"; + } + + auto return_type = module->Lookup(module->entry_func)->ret_type; + + std::vector vm_args; + for (auto i = 3; i < args.size(); i++) { + Object obj = args[i]; + vm_args.push_back(obj); + } + + auto result = EvaluateModule(module, {ctx}, vm_args); + DLOG(INFO) << "Evaluate VM returning: result=" << result->tag; + *ret = VMToValue(module, return_type, result); +}); + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index b889b6c..a4ebd1e 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -154,6 +154,9 @@ Array ReduceCompute(const Attrs& attrs, F f) { const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); + if (inputs[0]->shape.size() == 0) { + return { topi::identity(inputs[0]) }; + } auto axes = param->axis; if (param->exclude) { axes = 
GetExcludeAxes(inputs[0]->shape.size(), param->axis); @@ -251,7 +254,6 @@ bool ReduceRel(const Array& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; - CHECK(static_cast(data->shape.size()) != 0); std::vector&& in_shape = AsVector(data->shape); const ReduceAttrs* param = attrs.as(); diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index c5c4f33..533c214 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -124,7 +124,8 @@ class CalcDep : private ExprVisitor { friend CalcDep; bool HasLet(const Var& v) { - return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); + // TODO(@jroesch): MK fix me + return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); } Expr VisitExpr_(const VarNode* op) final { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index d7ea53e..b2d326e 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -118,6 +118,86 @@ Instruction::Instruction(const Instruction& instr) { } } +template +static inline void FreeIf(T* t) { + if (t != nullptr) { + delete t; + } +} + +Instruction& Instruction::operator=(const Instruction& instr) { + this->op = instr.op; + this->dst = instr.dst; + + switch (instr.op) { + case Opcode::Move: + this->from = instr.from; + return *this; + case Opcode::Select: + this->select_cond = instr.select_cond; + this->select_op1 = instr.select_op1; + this->select_op2 = instr.select_op2; + return *this; + case Opcode::Ret: + this->result = instr.result; + return *this; + case Opcode::AllocTensor: + this->shape_register = instr.shape_register; + this->dtype = instr.dtype; + return *this; + case Opcode::AllocDatatype: + this->constructor_tag = instr.constructor_tag; + this->num_fields = instr.num_fields; + FreeIf(this->datatype_fields); + this->datatype_fields = Duplicate(instr.datatype_fields, instr.num_fields); + return *this; + case Opcode::AllocClosure: + this->clo_index = instr.clo_index; + this->num_freevar = instr.num_freevar; + FreeIf(this->free_vars); + this->free_vars = Duplicate(instr.free_vars, instr.num_freevar); + return *this; + case Opcode::InvokePacked: + this->packed_index = instr.packed_index; + this->arity = instr.arity; + this->output_size = instr.output_size; + FreeIf(this->packed_args); + this->packed_args = Duplicate(instr.packed_args, instr.arity); + return *this; + case Opcode::InvokeClosure: + this->closure = instr.closure; + this->closure_args_num = instr.closure_args_num; + FreeIf(this->closure_args); + this->closure_args = Duplicate(instr.closure_args, instr.closure_args_num); + return *this; + case Opcode::Invoke: + this->func_index = instr.func_index; + this->num_args = instr.num_args; + FreeIf(this->invoke_args_registers); + this->invoke_args_registers = Duplicate(instr.invoke_args_registers, instr.num_args); + return *this; + case Opcode::If: + this->if_cond = instr.if_cond; + this->true_offset = instr.true_offset; + this->false_offset = instr.false_offset; + return *this; + case Opcode::LoadConst: + this->const_index = instr.const_index; + return *this; 
+ case Opcode::GetField: + this->object = instr.object; + this->field_index = instr.field_index; + return *this; + case Opcode::Goto: + this->pc_offset = instr.pc_offset; + return *this; + default: + std::ostringstream out; + out << "Invalid instruction " << static_cast(instr.op); + throw std::runtime_error(out.str()); + } +} + Instruction::~Instruction() { switch (this->op) { case Opcode::Move: diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py new file mode 100644 index 0000000..e359ade --- /dev/null +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Benchmarking Relay VM using models from MXNet.""" +import numpy as np + +import tvm +from tvm.contrib import graph_runtime +from tvm import relay +from tvm.relay import testing + + +def benchmark_execution(net, + params, + measure=False, + data_shape=(1, 3, 224, 224), + out_shape=(1, 1000), + dtype='float32'): + def get_tvm_output(net, data, params, target, ctx, dtype='float32'): + with relay.build_config(opt_level=1): + graph, lib, params = relay.build(net, target, params=params) + + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input("data", data) + m.set_input(**params) + m.run() + out = m.get_output(0, tvm.nd.empty(out_shape, dtype)) + + if measure: + print("Evaluate graph runtime inference time cost...") + ftimer = m.module.time_evaluator("run", ctx, number=1, repeat=20) + # Measure in millisecond. 
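+            # time_evaluator returns per-run times in seconds; number=1 and
+            # repeat=20 yield 20 measurements, which are scaled by 1000 below
+            # so the mean and standard deviation are reported in milliseconds.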
+ prof_res = np.array(ftimer().results) * 1000 + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + + return out.asnumpy() + + def get_tvm_vm_output(net, data, params, target, ctx, dtype='float32'): + ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx) + result = ex.evaluate(net)(data, **params) + return result.asnumpy().astype(dtype) + + # random input + data = np.random.uniform(size=data_shape).astype(dtype) + target = "llvm" + ctx = tvm.cpu(0) + + tvm_out = get_tvm_output(net, tvm.nd.array(data.astype(dtype)), params, + target, ctx, dtype) + vm_out = get_tvm_vm_output(net, tvm.nd.array(data.astype(dtype)), params, + target, ctx, dtype) + tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_mlp(): + image_shape = (1, 28, 28) + net, params = testing.mlp.get_workload(1) + benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 10)) + + +def test_vgg(): + for n in [11, 16]: + net, params = testing.vgg.get_workload(1, num_layers=n) + benchmark_execution(net, params) + + +def test_resnet(): + for n in [18, 50]: + net, params = testing.resnet.get_workload(batch_size=1, num_layers=n) + benchmark_execution(net, params, True) + + +def test_squeezenet(): + for version in ['1.0', '1.1']: + net, params = testing.squeezenet.get_workload(version=version) + benchmark_execution(net, params) + + +def test_inception_v3(): + image_shape = (3, 299, 299) + net, params = testing.inception_v3.get_workload(image_shape=image_shape) + benchmark_execution(net, params, data_shape=image_shape) + + +def test_dqn(): + image_shape = (4, 84, 84) + net, params = testing.dqn.get_workload( + batch_size=1, image_shape=image_shape) + benchmark_execution(net, params, data_shape=image_shape, out_shape=(1, 18)) + + +def test_dcgan(): + image_shape = (1, 100) + net, params = testing.dcgan.get_workload(batch_size=1) + benchmark_execution(net, params, data_shape=image_shape) + + +def test_mobilenet(): + net, params = testing.mobilenet.get_workload(batch_size=1) + benchmark_execution(net, params) + + +def test_densenet(): + net, params = testing.densenet.get_workload(batch_size=1) + benchmark_execution(net, params) + + +if __name__ == '__main__': + test_resnet() + test_vgg() + test_squeezenet() + test_mobilenet() + test_densenet() + # The following networks fail + # test_inception_v3() + # test_mlp() + # test_dqn() + # test_dcgan() diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py new file mode 100644 index 0000000..bc99418 --- /dev/null +++ b/tests/python/relay/test_vm.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
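+"""Unit tests for compiling and running Relay programs on the Relay VM."""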
+import os +from nose.tools import nottest + +import tvm +import numpy as np +from tvm import relay +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.prelude import Prelude + +def veval(f, *args, ctx=tvm.cpu()): + if isinstance(f, relay.Expr): + ex = relay.create_executor('vm', mod=relay.Module(), ctx=ctx) + if len(args) == 0: + return ex.evaluate(f) + else: + return ex.evaluate(f)(*args) + else: + assert isinstance(f, relay.Module), "expected expression or module" + mod = f + ex = relay.create_executor('vm', mod=mod, ctx=ctx) + if len(args) == 0: + return ex.evaluate(mod[mod.entry_func]) + else: + return ex.evaluate(mod[mod.entry_func])(*args) + +def test_split(): + x = relay.var('x', shape=(12,)) + y = relay.split(x, 3, axis=0).astuple() + z = relay.concatenate([relay.TupleGetItem(y, 0)], axis=0) + f = relay.Function([x], z) + + x_data = np.random.rand(12,).astype('float32') + res = veval(f, x_data) + tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0]) + +def test_id(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x) + x_data = np.random.rand(10, 10).astype('float64') + res = veval(f, x_data) + tvm.testing.assert_allclose(res.asnumpy(), x_data) + +def test_op(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x + x) + x_data = np.random.rand(10, 10).astype('float32') + res = veval(f, x_data) + tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) + +def any(x): + x = relay.op.nn.batch_flatten(x) + return relay.op.min(x, axis=[0, 1]) + +def test_cond(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('x', shape=(10, 10)) + # f = relay.Function([x, y], relay.op.equal(x, y)) + f = relay.Function([x, y], any(relay.op.equal(x, y))) + x_data = np.random.rand(10, 10).astype('float32') + y_data = np.random.rand(10, 10).astype('float32') + + # same + res = veval(f, x_data, x_data) + np.testing.assert_allclose(res.asnumpy(), True) + + # diff + res = veval(f, x_data, y_data) + tvm.testing.assert_allclose(res.asnumpy(), False) + + +def test_simple_if(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(10, 10)) + f = relay.Function([x, y], + relay.If(any(relay.op.equal(x, y)), x, y)) + x_data = np.random.rand(10, 10).astype('float32') + y_data = np.random.rand(10, 10).astype('float32') + + # same + res = veval(f, x_data, x_data) + tvm.testing.assert_allclose(res.asnumpy(), x_data) + + # diff + res = veval(f, x_data, y_data) + tvm.testing.assert_allclose(res.asnumpy(), y_data) + +def test_simple_call(): + mod = relay.module.Module({}) + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + sb = ScopeBuilder() + sb.ret(i) + func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) + mod[sum_up] = func + i_data = np.array(0, dtype='int32') + iarg = relay.var('i', shape=[], dtype='int32') + mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg)) + result = veval(mod, i_data) + tvm.testing.assert_allclose(result.asnumpy(), i_data) + +def test_count_loop(): + mod = relay.module.Module({}) + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + sb = ScopeBuilder() + with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))): + sb.ret(i) + with sb.else_scope(): + one_less = relay.subtract(i, relay.const(1, dtype='int32')) + rec_call = relay.Call(sum_up, [one_less]) + sb.ret(relay.add(rec_call, i)) + func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) + mod[sum_up] = func + i_data = np.array(0, 
dtype='int32') + iarg = relay.var('i', shape=[], dtype='int32') + mod[mod.entry_func] = relay.Function([iarg], sum_up(iarg)) + result = veval(mod, i_data) + tvm.testing.assert_allclose(result.asnumpy(), i_data) + +def test_sum_loop(): + mod = relay.module.Module({}) + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + accum = relay.var('accum', shape=[], dtype='int32') + sb = ScopeBuilder() + with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))): + sb.ret(accum) + with sb.else_scope(): + one_less = relay.subtract(i, relay.const(1, 'int32')) + new_accum = relay.add(accum, i) + sb.ret(relay.Call(sum_up, [one_less, new_accum])) + func = relay.Function([i, accum], sb.get()) + mod[sum_up] = func + loop_bound = 0 + i_data = np.array(loop_bound, dtype='int32') + accum_data = np.array(0, dtype='int32') + iarg = relay.var('i', shape=[], dtype='int32') + aarg = relay.var('accum', shape=[], dtype='int32') + mod[mod.entry_func] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) + result = veval(mod, i_data, accum_data) + tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) + +def test_tuple_fst(): + ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) + tup = relay.var('tup', type_annotation=ttype) + f = relay.Function([tup], relay.TupleGetItem(tup, 0)) + i_data = np.random.rand(41).astype('float32') + j_data = np.random.rand(10).astype('float32') + result = veval(f, (i_data, j_data)) + tvm.testing.assert_allclose(result.asnumpy(), i_data) + +def test_tuple_second(): + ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) + tup = relay.var('tup', type_annotation=ttype) + f = relay.Function([tup], relay.TupleGetItem(tup, 1)) + i_data = np.random.rand(41).astype('float32') + j_data = np.random.rand(10).astype('float32') + result = veval(f, (i_data, j_data)) + tvm.testing.assert_allclose(result.asnumpy(), j_data) + +@nottest +def test_list_constructor(): + # TODO(wweic): implement pattern match to support this test + def to_list(o): + if isinstance(o, tvm.relay.backend.interpreter.TensorValue): + return [o.data.asnumpy().tolist()] + if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + result = [] + for f in o.fields: + result.extend(to_list(f)) + return result + + mod = relay.Module() + p = Prelude(mod) + + nil = p.nil + cons = p.cons + l = p.l + + one2 = cons(relay.const(1), nil()) + one3 = cons(relay.const(2), one2) + one4 = cons(relay.const(3), one3) + f = relay.Function([], one4) + + mod[mod.entry_func] = f + + result = veval(mod)() + obj = to_list(result) + import pdb; pdb.set_trace() + tvm.testing.assert_allclose(obj, np.array([3,2,1])) + +def test_let_tensor(): + sb = relay.ScopeBuilder() + shape = (1,) + x = relay.var('x', shape=shape, dtype='float32') + x1 = relay.var('x1', shape=shape, dtype='float32') + + x1 = sb.let(x1, x) + xplusone = x1 + relay.const(42.0, 'float32') + sb.ret(xplusone) + body = sb.get() + + f = relay.Function([x], body) + + x_data = np.random.rand(*shape).astype('float32') + result = veval(f, x_data) + tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0) + +def test_let_scalar(): + sb = relay.ScopeBuilder() + + x = relay.var('x', 'float32') + x1 = sb.let('x1', x) + xplusone = x1 + relay.const(42.0, 'float32') + sb.ret(xplusone) + body = sb.get() + + f = relay.Function([x], body) + + x_data = np.array(np.random.rand()).astype('float32') + result = veval(f, x_data) + tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0) + +def test_closure(): + 
+    x = relay.var('x', shape=())
+    y = relay.var('y', shape=())
+    f = relay.Function([x], x + y)
+    ff = relay.Function([y], f)
+    clo = ff(relay.const(1.0))
+    main = clo(relay.const(2.0))
+    res = veval(main)
+    tvm.testing.assert_allclose(res.asnumpy(), 3.0)
+
+if __name__ == "__main__":
+    test_id()
+    test_op()
+    test_cond()
+    test_simple_if()
+    test_simple_call()
+    test_count_loop()
+    test_sum_loop()
+    test_tuple_fst()
+    test_tuple_second()
+    test_let_scalar()
+    test_let_tensor()
+    # TODO(@jroesch): restore when match is supported
+    # test_list_constructor()
+    test_closure()
-- 
2.7.4