[Relay][Vm] Some performance improvement to VM (#5901)
authorHaichen Shen <shenhaichen@gmail.com>
Thu, 25 Jun 2020 14:55:40 +0000 (07:55 -0700)
committerGitHub <noreply@github.com>
Thu, 25 Jun 2020 14:55:40 +0000 (07:55 -0700)
* make alignment constant

* tweak copyto and loadscalarint

* some safety check

* x

* lint

* fix

include/tvm/runtime/vm.h
src/relay/backend/vm/compiler.cc
src/runtime/vm/executable.cc
src/runtime/vm/vm.cc

index 552edc5..b9ccbf9 100644 (file)
@@ -241,7 +241,7 @@ struct Instruction {
       /*! \brief The size of the allocation. */
       RegName allocation_size;
       /*! \brief The alignment of the allocation. */
-      RegName alignment;
+      Index alignment;
       /*! \brief The hint of the dtype. */
       DLDataType dtype_hint;
     } alloc_storage;
@@ -386,7 +386,7 @@ struct Instruction {
    * \param dst The destination to place the storage.
    * \return The alloc storage instruction.
    */
-  static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint,
+  static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
                                   RegName dst);
 
   Instruction();
@@ -733,7 +733,7 @@ class VirtualMachine : public runtime::ModuleNode {
    * \param reg The register to read from.
    * \return The read scalar.
    */
-  int32_t LoadScalarInt(RegName reg) const;
+  inline int64_t LoadScalarInt(RegName reg) const;
 
   /*!
    * \brief Invoke a VM function.
index 0af1949..0b839a2 100644 (file)
@@ -204,6 +204,9 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause
 
 std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
   std::vector<int64_t> raw_shape;
+  if (shape->ndim == 0) {
+    return raw_shape;
+  }
   CHECK_EQ(shape->ndim, 1u);
   CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got "
                                   << DLDataType2String(shape->dtype);
@@ -425,10 +428,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     // Prepare input and output registers
     std::vector<Index> argument_registers;
     for (auto input : inputs) {
-      auto reg = var_register_map_.find(Downcast<Var>(input));
-      CHECK(reg != var_register_map_.end())
-          << "internal error: all variables should be in the register mapping";
-      argument_registers.push_back(reg->second);
+      VisitExpr(input);
+      argument_registers.push_back(last_register_);
     }
 
     for (auto output : outputs) {
@@ -457,10 +458,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
                         << "please file a bug in the memory manifestation pass";
 
     for (auto input : input_tuple->fields) {
-      auto reg = var_register_map_.find(Downcast<Var>(input));
-      CHECK(reg != var_register_map_.end())
-          << "internal error: all variables should be in the register mapping";
-      argument_registers.push_back(reg->second);
+      VisitExpr(input);
+      argument_registers.push_back(last_register_);
     }
 
     for (auto output : output_tuple->fields) {
@@ -566,16 +565,20 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
                    this->VisitExpr(args[0]);
                    auto size_register = last_register_;
 
-                   this->VisitExpr(args[1]);
-                   auto alignment_register = last_register_;
+                   CHECK(args[1].as<ConstantNode>());
+                   NDArray alignment_arr = args[1].as<ConstantNode>()->data;
+                   CHECK_EQ(alignment_arr->dtype.code, 0U)
+                       << "The dtype of constant shape must be int32 or int64, but got "
+                       << DLDataType2String(alignment_arr->dtype);
+                   CHECK_EQ(alignment_arr->dtype.bits, 64U);
+                   Index alignment = reinterpret_cast<int64_t*>(alignment_arr->data)[0];
 
                    // Get the dtype hint from the attributes.
                    auto alloc_attrs = attrs.as<AllocStorageAttrs>();
                    CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs";
                    auto dtype = alloc_attrs->dtype;
 
-                   Emit(Instruction::AllocStorage(size_register, alignment_register, dtype,
-                                                  NewRegister()));
+                   Emit(Instruction::AllocStorage(size_register, alignment, dtype, NewRegister()));
                  })
           .Match("memory.shape_func",
                  [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
@@ -890,7 +893,9 @@ transform::Sequential MemoryOpt(tvm::Target host_target) {
   pass_seqs.push_back(transform::FoldConstant());
 
   // Lift constants to the top-level of the block to simplify VM code generation.
-  pass_seqs.push_back(transform::LiftConstants());
+  // TODO(@icemelon9, @jroesch): Remove this pass for now because some
+  //  instructions need to access to constant
+  // pass_seqs.push_back(transform::LiftConstants());
 
   return transform::Sequential(pass_seqs);
 }
index 47bdd1c..65b1a2f 100644 (file)
@@ -552,7 +552,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
     case Opcode::AllocTensor: {
       // Number of fields = 7 + instr.alloc_tensor.ndim
       DCHECK_GE(instr.fields.size(), 7U);
-      DCHECK_EQ(instr.fields.size(), 7U + static_cast<size_t>(instr.fields[4]));
+      DCHECK_EQ(instr.fields.size(), 7U + static_cast<size_t>(instr.fields[5]));
 
       RegName storage_reg = instr.fields[0];
       RegName offset = instr.fields[1];
index 42bca37..0c0ca35 100644 (file)
@@ -529,8 +529,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
     }
     case Opcode::AllocTensorReg: {
       os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $"
-         << instr.alloc_tensor_reg.storage << " $" << instr.alloc_tensor_reg.offset << " $"
-         << instr.alloc_tensor_reg.shape_register << " ";
+         << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " ";
       DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
       break;
     }
@@ -581,7 +580,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       break;
     }
     case Opcode::AllocStorage: {
-      os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " $"
+      os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " "
          << instr.alloc_storage.alignment << " "
          << DLDataType2String(instr.alloc_storage.dtype_hint);
       break;
@@ -822,6 +821,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
     CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
     packed_funcs_[packed_index] = pf;
   }
+  for (size_t i = 0; i < packed_funcs_.size(); ++i) {
+    CHECK(packed_funcs_[i] != nullptr) << "Packed function " << i << " is not initialized";
+  }
 }
 
 void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { ctxs_ = ctxs; }
@@ -834,18 +836,34 @@ inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
   return frames_.back().register_file[r];
 }
 
-inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
-  int32_t result;
+inline int64_t VirtualMachine::LoadScalarInt(Index r) const {
+  int64_t result = 0;
   const auto& obj = ReadRegister(r);
-  auto nd_array = Downcast<NDArray>(obj);
-  NDArray array = nd_array.CopyTo({kDLCPU, 0});
+  NDArray array = Downcast<NDArray>(CopyTo(obj, {kDLCPU, 0}));
 
-  if (array->dtype.bits <= 8) {
-    result = reinterpret_cast<int8_t*>(array->data)[0];
-  } else if (array->dtype.bits <= 16) {
-    result = reinterpret_cast<int16_t*>(array->data)[0];
-  } else {
-    result = reinterpret_cast<int32_t*>(array->data)[0];
+  switch (array->dtype.bits) {
+    case 1: {
+      result = reinterpret_cast<bool*>(array->data)[0];
+      break;
+    }
+    case 8: {
+      result = reinterpret_cast<int8_t*>(array->data)[0];
+      break;
+    }
+    case 16: {
+      result = reinterpret_cast<int16_t*>(array->data)[0];
+      break;
+    }
+    case 32: {
+      result = reinterpret_cast<int32_t*>(array->data)[0];
+      break;
+    }
+    case 64: {
+      result = reinterpret_cast<int64_t*>(array->data)[0];
+      break;
+    }
+    default:
+      LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(array->dtype);
   }
   return result;
 }
@@ -908,8 +926,8 @@ void VirtualMachine::RunLoop() {
         goto main_loop;
       }
       case Opcode::InvokePacked: {
-        DLOG(INFO) << "InvokedPacked "
-                   << "arity=" << instr.arity;
+        DLOG(INFO) << "InvokedPacked " << instr.packed_index << " arity=" << instr.arity;
+        CHECK_LE(instr.packed_index, packed_funcs_.size());
         const auto& func = packed_funcs_[instr.packed_index];
         const auto& arity = instr.arity;
         std::vector<ObjectRef> args;
@@ -996,9 +1014,8 @@ void VirtualMachine::RunLoop() {
         DLContext cpu_ctx;
         cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
-        auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
-        const auto shape_arr = Downcast<NDArray>(shape_tensor_obj);
-        NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx);
+        auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
+        NDArray shape_tensor = Downcast<NDArray>(CopyTo(shape_obj, cpu_ctx));
         auto shape = ToShape(shape_tensor);
         auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage);
         auto storage = Downcast<Storage>(storage_obj);
@@ -1030,7 +1047,7 @@ void VirtualMachine::RunLoop() {
       }
       case Opcode::AllocStorage: {
         auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
-        auto alignment = LoadScalarInt(instr.alloc_storage.alignment);
+        auto alignment = instr.alloc_storage.alignment;
 
         DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment
                    << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint);