[RUNTIME][REFACTOR] Use object protocol to support runtime::Module (#4289)
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 11 Nov 2019 18:09:29 +0000 (10:09 -0800)
committerGitHub <noreply@github.com>
Mon, 11 Nov 2019 18:09:29 +0000 (10:09 -0800)
Previously runtime::Module was supported using shared_ptr.
This PR refactors the codebase to use the Object protocol.

It will open doors to allow easier interpolation between
Object containers and module in the future.

58 files changed:
apps/android_deploy/app/src/main/jni/tvm_runtime.h
apps/android_rpc/app/src/main/jni/tvm_runtime.h
apps/bundle_deploy/runtime.cc
apps/howto_deploy/tvm_runtime_pack.cc
apps/ios_rpc/tvmrpc/TVMRuntime.mm
golang/src/tvm_runtime_pack.cc
include/tvm/runtime/module.h
include/tvm/runtime/object.h
include/tvm/runtime/packed_func.h
include/tvm/runtime/vm.h
python/tvm/relay/backend/vm.py
src/codegen/llvm/llvm_module.cc
src/codegen/source_module.cc
src/relay/backend/build_module.cc
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h
src/relay/backend/vm/profiler/compiler.cc
src/runtime/c_runtime_api.cc
src/runtime/cuda/cuda_module.cc
src/runtime/cuda/cuda_module.h
src/runtime/dso_module.cc
src/runtime/graph/debug/graph_runtime_debug.cc
src/runtime/graph/graph_runtime.cc
src/runtime/graph/graph_runtime.h
src/runtime/metal/metal_module.mm
src/runtime/micro/micro_device_api.cc
src/runtime/micro/micro_module.cc
src/runtime/micro/micro_session.cc
src/runtime/micro/micro_session.h
src/runtime/micro/tcl_socket.h
src/runtime/module.cc
src/runtime/module_util.cc
src/runtime/module_util.h
src/runtime/object.cc
src/runtime/object_internal.h [new file with mode: 0644]
src/runtime/opencl/aocl/aocl_common.h
src/runtime/opencl/aocl/aocl_device_api.cc
src/runtime/opencl/aocl/aocl_module.h
src/runtime/opencl/opencl_common.h
src/runtime/opencl/opencl_module.cc
src/runtime/opencl/opencl_module.h
src/runtime/opengl/opengl_module.cc
src/runtime/opengl/opengl_module.h
src/runtime/rocm/rocm_module.cc
src/runtime/rpc/rpc_module.cc
src/runtime/rpc/rpc_session.cc
src/runtime/rpc/rpc_session.h
src/runtime/stackvm/stackvm.cc
src/runtime/stackvm/stackvm_module.cc
src/runtime/system_lib_module.cc
src/runtime/vm/executable.cc
src/runtime/vm/profiler/vm.cc
src/runtime/vm/profiler/vm.h
src/runtime/vm/vm.cc
src/runtime/vulkan/vulkan.cc
vta/src/dpi/module.cc
web/web_runtime.cc

index 3a909e0..573612b 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file tvm_runtime.h
  * \brief Pack all tvm runtime source files
  */
@@ -35,6 +34,7 @@
 #include "../src/runtime/file_util.cc"
 #include "../src/runtime/dso_module.cc"
 #include "../src/runtime/thread_pool.cc"
+#include "../src/runtime/object.cc"
 #include "../src/runtime/threading_backend.cc"
 #include "../src/runtime/ndarray.cc"
 
index 73a8a48..e30b316 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -55,6 +55,7 @@
 #include "../src/runtime/threading_backend.cc"
 #include "../src/runtime/graph/graph_runtime.cc"
 #include "../src/runtime/ndarray.cc"
+#include "../src/runtime/object.cc"
 
 #ifdef TVM_OPENCL_RUNTIME
 #include "../src/runtime/opencl/opencl_device_api.cc"
index 968554b..f1c2ba2 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -32,5 +32,6 @@
 #include "../../src/runtime/threading_backend.cc"
 #include "../../src/runtime/thread_pool.cc"
 #include "../../src/runtime/ndarray.cc"
+#include "../../src/runtime/object.cc"
 #include "../../src/runtime/system_lib_module.cc"
 #include "../../src/runtime/graph/graph_runtime.cc"
index 6ebad81..67c9a9d 100644 (file)
@@ -47,6 +47,7 @@
 #include "../../src/runtime/threading_backend.cc"
 #include "../../src/runtime/thread_pool.cc"
 #include "../../src/runtime/ndarray.cc"
+#include "../../src/runtime/object.cc"
 
 // NOTE: all the files after this are optional modules
 // that you can include remove, depending on how much feature you use.
index 5d1d90e..a98862a 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file TVMRuntime.mm
  */
 #include "TVMRuntime.h"
@@ -35,6 +34,8 @@
 #include "../../../src/runtime/file_util.cc"
 #include "../../../src/runtime/dso_module.cc"
 #include "../../../src/runtime/ndarray.cc"
+#include "../../../src/runtime/object.cc"
+
 // RPC server
 #include "../../../src/runtime/rpc/rpc_session.cc"
 #include "../../../src/runtime/rpc/rpc_server_env.cc"
index cfbe237..c8be428 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \brief This is an all in one TVM runtime file.
  * \file tvm_runtime_pack.cc
  */
@@ -32,6 +31,7 @@
 #include "src/runtime/threading_backend.cc"
 #include "src/runtime/thread_pool.cc"
 #include "src/runtime/ndarray.cc"
+#include "src/runtime/object.cc"
 
 // NOTE: all the files after this are optional modules
 // that you can include remove, depending on how much feature you use.
index 7bbfa4d..ff096ee 100644 (file)
 #define TVM_RUNTIME_MODULE_H_
 
 #include <dmlc/io.h>
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/memory.h>
+
 #include <memory>
 #include <vector>
 #include <string>
 #include <unordered_map>
-#include "c_runtime_api.h"
 
 namespace tvm {
 namespace runtime {
 
-// The internal container of module.
 class ModuleNode;
 class PackedFunc;
 
 /*!
  * \brief Module container of TVM.
  */
-class Module {
+class Module : public ObjectRef {
  public:
   Module() {}
   // constructor from container.
-  explicit Module(std::shared_ptr<ModuleNode> n)
-      : node_(n) {}
+  explicit Module(ObjectPtr<Object> n)
+      : ObjectRef(n) {}
   /*!
    * \brief Get packed function from current module by name.
    *
@@ -59,10 +62,6 @@ class Module {
    * \note Implemented in packed_func.cc
    */
   inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
-  /*! \return internal container */
-  inline ModuleNode* operator->();
-  /*! \return internal container */
-  inline const ModuleNode* operator->() const;
   // The following functions requires link with runtime.
   /*!
    * \brief Import another module into this module.
@@ -71,7 +70,11 @@ class Module {
    * \note Cyclic dependency is not allowed among modules,
    *  An error will be thrown when cyclic dependency is detected.
    */
-  TVM_DLL void Import(Module other);
+  inline void Import(Module other);
+  /*! \return internal container */
+  inline ModuleNode* operator->();
+  /*! \return internal container */
+  inline const ModuleNode* operator->() const;
   /*!
    * \brief Load a module from file.
    * \param file_name The name of the host function module.
@@ -81,20 +84,41 @@ class Module {
    */
   TVM_DLL static Module LoadFromFile(const std::string& file_name,
                                      const std::string& format = "");
-
- private:
-  std::shared_ptr<ModuleNode> node_;
+  // refer to the corresponding container.
+  using ContainerType = ModuleNode;
+  friend class ModuleNode;
 };
 
 /*!
- * \brief Base node container of module.
- *  Do not create this directly, instead use Module.
+ * \brief Base container of module.
+ *
+ * Please subclass ModuleNode to create a specific runtime module.
+ *
+ * \code
+ *
+ *  class MyModuleNode : public ModuleNode {
+ *   public:
+ *    // implement the interface
+ *  };
+ *
+ *  // use make_object to create a specific
+ *  // instace of MyModuleNode.
+ *  Module CreateMyModule() {
+ *    ObjectPtr<MyModuleNode> n =
+ *      tvm::runtime::make_object<MyModuleNode>();
+ *    return Module(n);
+ *  }
+ *
+ * \endcode
  */
-class ModuleNode {
+class ModuleNode : public Object {
  public:
   /*! \brief virtual destructor */
   virtual ~ModuleNode() {}
-  /*! \return The module type key */
+  /*!
+   * \return The per module type key.
+   * \note This key is used to for serializing custom modules.
+   */
   virtual const char* type_key() const = 0;
   /*!
    * \brief Get a PackedFunc from module.
@@ -105,7 +129,7 @@ class ModuleNode {
    *  For benchmarking, use prepare to eliminate
    *
    * \param name the name of the function.
-   * \param sptr_to_self The shared_ptr that points to this module node.
+   * \param sptr_to_self The ObjectPtr that points to this module node.
    *
    * \return PackedFunc(nullptr) when it is not available.
    *
@@ -115,7 +139,7 @@ class ModuleNode {
    */
   virtual PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
+      const ObjectPtr<Object>& sptr_to_self) = 0;
   /*!
    * \brief Save the module to file.
    * \param file_name The file to be saved to.
@@ -138,6 +162,24 @@ class ModuleNode {
    */
   TVM_DLL virtual std::string GetSource(const std::string& format = "");
   /*!
+   * \brief Get packed function from current module by name.
+   *
+   * \param name The name of the function.
+   * \param query_imports Whether also query dependency modules.
+   * \return The result function.
+   *  This function will return PackedFunc(nullptr) if function do not exist.
+   * \note Implemented in packed_func.cc
+   */
+  TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
+  /*!
+   * \brief Import another module into this module.
+   * \param other The module to be imported.
+   *
+   * \note Cyclic dependency is not allowed among modules,
+   *  An error will be thrown when cyclic dependency is detected.
+   */
+  TVM_DLL void Import(Module other);
+  /*!
    * \brief Get a function from current environment
    *  The environment includes all the imports as well as Global functions.
    *
@@ -150,6 +192,13 @@ class ModuleNode {
     return imports_;
   }
 
+  // integration with the existing components.
+  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
+  static constexpr const char* _type_key = "runtime.Module";
+  // NOTE: ModuleNode can still be sub-classed
+  //
+  TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
+
  protected:
   friend class Module;
   /*! \brief The modules this module depend on */
@@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__";
 }  // namespace symbol
 
 // implementations of inline functions.
+
+inline void Module::Import(Module other) {
+  return (*this)->Import(other);
+}
+
 inline ModuleNode* Module::operator->() {
-  return node_.get();
+  return static_cast<ModuleNode*>(get_mutable());
 }
 
 inline const ModuleNode* Module::operator->() const {
-  return node_.get();
+  return static_cast<const ModuleNode*>(get());
 }
 
 }  // namespace runtime
 }  // namespace tvm
 
-#include "packed_func.h"
+#include <tvm/runtime/packed_func.h>  // NOLINT(*)
 #endif  // TVM_RUNTIME_MODULE_H_
index 0aa7815..20e6b5a 100644 (file)
@@ -53,6 +53,7 @@ enum TypeIndex  {
   kVMTensor = 1,
   kVMClosure = 2,
   kVMADT = 3,
+  kRuntimeModule = 4,
   kStaticIndexEnd,
   /*! \brief Type index is allocated during runtime. */
   kDynamic = kStaticIndexEnd
@@ -302,7 +303,7 @@ class Object {
   template<typename>
   friend class ObjectPtr;
   friend class TVMRetValue;
-  friend class TVMObjectCAPI;
+  friend class ObjectInternal;
 };
 
 /*!
@@ -310,11 +311,11 @@ class Object {
  *
  *  It is always important to get a reference type
  *  if we want to return a value as reference or keep
- *  the node alive beyond the scope of the function.
+ *  the object alive beyond the scope of the function.
  *
- * \param ptr The node pointer
+ * \param ptr The object pointer
  * \tparam RefType The reference type
- * \tparam ObjectType The node type
+ * \tparam ObjectType The object type
  * \return The corresponding RefType
  */
 template <typename RefType, typename ObjectType>
@@ -486,6 +487,8 @@ class ObjectPtr {
   friend class TVMArgValue;
   template <typename RefType, typename ObjType>
   friend RefType GetRef(const ObjType* ptr);
+  template <typename BaseType, typename ObjType>
+  friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
 };
 
 /*! \brief Base class of all object reference */
@@ -513,7 +516,7 @@ class ObjectRef {
   }
   /*!
    * \brief Comparator
-   * \param other Another node ref.
+   * \param other Another object ref.
    * \return the compare result.
    */
   bool operator!=(const ObjectRef& other) const {
@@ -535,7 +538,7 @@ class ObjectRef {
   const Object* get() const {
     return data_.get();
   }
-  /*! \return the internal node pointer */
+  /*! \return the internal object pointer */
   const Object* operator->() const {
     return get();
   }
@@ -595,6 +598,16 @@ class ObjectRef {
   friend SubRef Downcast(BaseRef ref);
 };
 
+/*!
+ * \brief Get an object ptr type from a raw object ptr.
+ *
+ * \param ptr The object pointer
+ * \tparam BaseType The reference type
+ * \tparam ObjectType The object type
+ * \return The corresponding RefType
+ */
+template <typename BaseType, typename ObjectType>
+inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
 
 /*! \brief ObjectRef hash functor */
 struct ObjectHash {
@@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) {
   return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
 }
 
+template <typename BaseType, typename ObjType>
+inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
+  static_assert(std::is_base_of<BaseType, ObjType>::value,
+                "Can only cast to the ref of same container type");
+  return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
+}
+
 template <typename SubRef, typename BaseRef>
 inline SubRef Downcast(BaseRef ref) {
   CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
index 645f499..57c4291 100644 (file)
@@ -496,6 +496,14 @@ class TVMPODValue_ {
     return ObjectRef(
         ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
   }
+  operator Module() const {
+    if (type_code_ == kNull) {
+      return Module(ObjectPtr<Object>(nullptr));
+    }
+    TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
+    return Module(
+        ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
+  }
   operator TVMContext() const {
     TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
     return value_.v_ctx;
@@ -574,6 +582,7 @@ class TVMArgValue : public TVMPODValue_ {
   using TVMPODValue_::operator NDArray;
   using TVMPODValue_::operator TVMContext;
   using TVMPODValue_::operator ObjectRef;
+  using TVMPODValue_::operator Module;
   using TVMPODValue_::IsObjectRef;
 
   // conversion operator.
@@ -610,10 +619,6 @@ class TVMArgValue : public TVMPODValue_ {
   operator TypedPackedFunc<FType>() const {
     return TypedPackedFunc<FType>(operator PackedFunc());
   }
-  operator Module() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
-    return *ptr<Module>();
-  }
   const TVMValue& value() const {
     return value_;
   }
@@ -665,6 +670,7 @@ class TVMRetValue : public TVMPODValue_ {
   using TVMPODValue_::operator TVMContext;
   using TVMPODValue_::operator NDArray;
   using TVMPODValue_::operator ObjectRef;
+  using TVMPODValue_::operator Module;
   using TVMPODValue_::IsObjectRef;
 
   TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
@@ -696,10 +702,6 @@ class TVMRetValue : public TVMPODValue_ {
   operator TypedPackedFunc<FType>() const {
     return TypedPackedFunc<FType>(operator PackedFunc());
   }
-  operator Module() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
-    return *ptr<Module>();
-  }
   // Assign operators
   TVMRetValue& operator=(TVMRetValue&& other) {
     this->Clear();
@@ -766,17 +768,13 @@ class TVMRetValue : public TVMPODValue_ {
   TVMRetValue& operator=(ObjectRef other) {
     return operator=(std::move(other.data_));
   }
+  TVMRetValue& operator=(Module m) {
+    SwitchToObject(kModuleHandle, std::move(m.data_));
+    return *this;
+  }
   template<typename T>
   TVMRetValue& operator=(ObjectPtr<T> other) {
-    if (other.data_ != nullptr) {
-      this->Clear();
-      type_code_ = kObjectHandle;
-      // move the handle out
-      value_.v_handle = other.data_;
-      other.data_ = nullptr;
-    } else {
-      SwitchToPOD(kNull);
-    }
+    SwitchToObject(kObjectHandle, std::move(other));
     return *this;
   }
   TVMRetValue& operator=(PackedFunc f) {
@@ -787,10 +785,6 @@ class TVMRetValue : public TVMPODValue_ {
   TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
     return operator=(f.packed());
   }
-  TVMRetValue& operator=(Module m) {
-    this->SwitchToClass(kModuleHandle, m);
-    return *this;
-  }
   TVMRetValue& operator=(const TVMRetValue& other) {  // NOLINT(*0
     this->Assign(other);
     return *this;
@@ -860,7 +854,7 @@ class TVMRetValue : public TVMPODValue_ {
         break;
       }
       case kModuleHandle: {
-        SwitchToClass<Module>(kModuleHandle, other);
+        *this = other.operator Module();
         break;
       }
       case kNDArrayContainer: {
@@ -907,16 +901,30 @@ class TVMRetValue : public TVMPODValue_ {
       *static_cast<T*>(value_.v_handle) = v;
     }
   }
+  void SwitchToObject(int type_code, ObjectPtr<Object> other) {
+    if (other.data_ != nullptr) {
+      this->Clear();
+      type_code_ = type_code;
+      // move the handle out
+      value_.v_handle = other.data_;
+      other.data_ = nullptr;
+    } else {
+      SwitchToPOD(kNull);
+    }
+  }
   void Clear() {
     if (type_code_ == kNull) return;
     switch (type_code_) {
       case kStr: delete ptr<std::string>(); break;
       case kFuncHandle: delete ptr<PackedFunc>(); break;
-      case kModuleHandle: delete ptr<Module>(); break;
       case kNDArrayContainer: {
         static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
         break;
       }
+      case kModuleHandle: {
+        static_cast<Object*>(value_.v_handle)->DecRef();
+        break;
+      }
       case kObjectHandle: {
         static_cast<Object*>(value_.v_handle)->DecRef();
         break;
@@ -1156,8 +1164,12 @@ class TVMArgsSetter {
     operator()(i, value.packed());
   }
   void operator()(size_t i, const Module& value) const {  // NOLINT(*)
-    values_[i].v_handle = const_cast<Module*>(&value);
-    type_codes_[i] = kModuleHandle;
+    if (value.defined()) {
+      values_[i].v_handle = value.data_.data_;
+      type_codes_[i] = kModuleHandle;
+    } else {
+      type_codes_[i] = kNull;
+    }
   }
   void operator()(size_t i, const NDArray& value) const {  // NOLINT(*)
     values_[i].v_handle = value.data_;
@@ -1372,19 +1384,10 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
   return ExtTypeVTable::RegisterInternal(code, vt);
 }
 
-// Implement Module::GetFunction
-// Put implementation in this file so we have seen the PackedFunc
 inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
-  PackedFunc pf = node_->GetFunction(name, node_);
-  if (pf != nullptr) return pf;
-  if (query_imports) {
-    for (const Module& m : node_->imports_) {
-      pf = m.node_->GetFunction(name, m.node_);
-      if (pf != nullptr) return pf;
-    }
-  }
-  return pf;
+  return (*this)->GetFunction(name, query_imports);
 }
+
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_PACKED_FUNC_H_
index a196afd..317b535 100644 (file)
@@ -480,7 +480,7 @@ class Executable : public ModuleNode {
    * \return PackedFunc or nullptr when it is not available.
    */
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+                         const ObjectPtr<Object>& sptr_to_self) final;
 
   /*!
    * \brief Serialize the executable into global section, constant section, and
@@ -658,7 +658,7 @@ class VirtualMachine : public runtime::ModuleNode {
    *   it should capture sptr_to_self.
    */
   virtual PackedFunc GetFunction(const std::string& name,
-                                 const std::shared_ptr<ModuleNode>& sptr_to_self);
+                                 const ObjectPtr<Object>& sptr_to_self);
 
   /*!
    * \brief Invoke a PackedFunction
index e190e3f..5a4c5f7 100644 (file)
@@ -148,7 +148,7 @@ class Executable(object):
             raise TypeError("bytecode is expected to be the type of bytearray " +
                             "or TVMByteArray, but received {}".format(type(code)))
 
-        if not isinstance(lib, tvm.module.Module):
+        if lib is not None and not isinstance(lib, tvm.module.Module):
             raise TypeError("lib is expected to be the type of tvm.module.Module" +
                             ", but received {}".format(type(lib)))
 
index 554aec3..b8a38f5 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file llvm_module.cc
  * \brief LLVM runtime module for TVM
  */
@@ -54,7 +53,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     if (name == "__tvm_is_system_module") {
       bool flag =
           (mptr_->getFunction("__tvm_module_startup") != nullptr);
@@ -325,7 +324,7 @@ TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
 
 TVM_REGISTER_API("codegen.build_llvm")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
+    auto n = make_object<LLVMModuleNode>();
     n->Init(args[0], args[1]);
     *rv = runtime::Module(n);
   });
@@ -339,7 +338,7 @@ TVM_REGISTER_API("codegen.llvm_version_major")
 
 TVM_REGISTER_API("module.loadfile_ll")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
+    auto n = make_object<LLVMModuleNode>();
     n->LoadIR(args[0]);
     *rv = runtime::Module(n);
   });
index 88be7fe..adbe7ea 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file source_module.cc
  * \brief Source code module, only for viewing
  */
@@ -51,7 +50,7 @@ class SourceModuleNode : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     LOG(FATAL) << "Source module cannot execute, to get executable module"
                << " build TVM with \'" << fmt_ << "\' runtime support";
     return PackedFunc();
@@ -67,8 +66,7 @@ class SourceModuleNode : public runtime::ModuleNode {
 };
 
 runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
-  std::shared_ptr<SourceModuleNode> n =
-      std::make_shared<SourceModuleNode>(code, fmt);
+  auto n = make_object<SourceModuleNode>(code, fmt);
   return runtime::Module(n);
 }
 
@@ -84,7 +82,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     LOG(FATAL) << "C Source module cannot execute, to get executable module"
                << " build TVM with \'" << fmt_ << "\' runtime support";
     return PackedFunc();
@@ -113,8 +111,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
 };
 
 runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
-  std::shared_ptr<CSourceModuleNode> n =
-      std::make_shared<CSourceModuleNode>(code, fmt);
+  auto n = make_object<CSourceModuleNode>(code, fmt);
   return runtime::Module(n);
 }
 
@@ -134,7 +131,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
         const std::string& name,
-        const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+        const ObjectPtr<Object>& sptr_to_self) final {
     LOG(FATAL) << "Source module cannot execute, to get executable module"
                << " build TVM with \'" << fmt_ << "\' runtime support";
     return PackedFunc();
@@ -182,8 +179,7 @@ runtime::Module DeviceSourceModuleCreate(
     std::unordered_map<std::string, FunctionInfo> fmap,
     std::string type_key,
     std::function<std::string(const std::string&)> fget_source) {
-  std::shared_ptr<DeviceSourceModuleNode> n =
-      std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
+  auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
   return runtime::Module(n);
 }
 
index 73cf6c2..9254c7e 100644 (file)
@@ -115,7 +115,7 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \return The corresponding member function.
    */
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+                         const ObjectPtr<Object>& sptr_to_self) final {
     if (name == "get_graph_json") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         *rv = this->GetGraphJSON();
@@ -489,7 +489,7 @@ class RelayBuildModule : public runtime::ModuleNode {
 };
 
 runtime::Module RelayBuildCreate() {
-  std::shared_ptr<RelayBuildModule> exec = std::make_shared<RelayBuildModule>();
+  auto exec = make_object<RelayBuildModule>();
   return runtime::Module(exec);
 }
 
index 0342aa6..e288178 100644 (file)
@@ -593,7 +593,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
  public:
   GraphRuntimeCodegenModule() {}
   virtual PackedFunc GetFunction(const std::string& name,
-                                 const std::shared_ptr<ModuleNode>& sptr_to_self) {
+                                 const ObjectPtr<Object>& sptr_to_self) {
      if (name == "init") {
        return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
          CHECK_EQ(args.num_args, 2)
@@ -654,8 +654,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
 };
 
 runtime::Module CreateGraphCodegenMod() {
-  std::shared_ptr<GraphRuntimeCodegenModule> ptr =
-    std::make_shared<GraphRuntimeCodegenModule>();
+  auto ptr = make_object<GraphRuntimeCodegenModule>();
   return runtime::Module(ptr);
 }
 
index 3cfea5c..7f828c4 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file src/relay/backend/vm/compiler.cc
  * \brief A compiler from relay::Module to the VM byte code.
  */
@@ -745,7 +744,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 
 
 PackedFunc VMCompiler::GetFunction(const std::string& name,
-                                   const std::shared_ptr<ModuleNode>& sptr_to_self) {
+                                   const ObjectPtr<Object>& sptr_to_self) {
   if (name == "compile") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       CHECK_EQ(args.num_args, 3);
@@ -974,7 +973,7 @@ void VMCompiler::LibraryCodegen() {
 }
 
 runtime::Module CreateVMCompiler() {
-  std::shared_ptr<VMCompiler> exec = std::make_shared<VMCompiler>();
+  auto exec = make_object<VMCompiler>();
   return runtime::Module(exec);
 }
 
index 215cc12..db319c4 100644 (file)
@@ -86,14 +86,14 @@ class VMCompiler : public runtime::ModuleNode {
   virtual ~VMCompiler() {}
 
   virtual PackedFunc GetFunction(const std::string& name,
-                                 const std::shared_ptr<ModuleNode>& sptr_to_self);
+                                 const ObjectPtr<Object>& sptr_to_self);
 
   const char* type_key() const {
     return "VMCompiler";
   }
 
   void InitVM() {
-    exec_ = std::make_shared<Executable>();
+    exec_ = make_object<Executable>();
   }
 
   /*!
@@ -141,7 +141,7 @@ class VMCompiler : public runtime::ModuleNode {
   /*! \brief Global shared meta data */
   VMCompilerContext context_;
   /*! \brief Compiled executable. */
-  std::shared_ptr<Executable> exec_;
+  ObjectPtr<Executable> exec_;
   /*! \brief parameters */
   std::unordered_map<std::string, runtime::NDArray> params_;
 };
index 60c441a..4727f15 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file src/relay/backend/vm/profiler/compiler.cc
  * \brief A compiler from relay::Module to the VM byte code.
  */
@@ -37,7 +36,7 @@ class VMCompilerDebug : public VMCompiler {
 };
 
 runtime::Module CreateVMCompilerDebug() {
-  std::shared_ptr<VMCompilerDebug> exec = std::make_shared<VMCompilerDebug>();
+  auto exec = make_object<VMCompilerDebug>();
   return runtime::Module(exec);
 }
 
index 13181da..3608fce 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2016 by Contributors
  * \file c_runtime_api.cc
  * \brief Device specific implementations
  */
@@ -41,6 +40,7 @@
 #include <cstdlib>
 #include <cctype>
 #include "runtime_base.h"
+#include "object_internal.h"
 
 namespace tvm {
 namespace runtime {
@@ -370,16 +370,20 @@ int TVMModLoadFromFile(const char* file_name,
                        const char* format,
                        TVMModuleHandle* out) {
   API_BEGIN();
-  Module m = Module::LoadFromFile(file_name, format);
-  *out = new Module(m);
+  TVMRetValue ret;
+  ret = Module::LoadFromFile(file_name, format);
+  TVMValue val;
+  int type_code;
+  ret.MoveToCHost(&val, &type_code);
+  *out = val.v_handle;
   API_END();
 }
 
 int TVMModImport(TVMModuleHandle mod,
                  TVMModuleHandle dep) {
   API_BEGIN();
-  static_cast<Module*>(mod)->Import(
-      *static_cast<Module*>(dep));
+  ObjectInternal::GetModuleNode(mod)->Import(
+      GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
   API_END();
 }
 
@@ -388,7 +392,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
                       int query_imports,
                       TVMFunctionHandle *func) {
   API_BEGIN();
-  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
+  PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(
       func_name, query_imports != 0);
   if (pf != nullptr) {
     *func = new PackedFunc(pf);
@@ -399,9 +403,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
 }
 
 int TVMModFree(TVMModuleHandle mod) {
-  API_BEGIN();
-  delete static_cast<Module*>(mod);
-  API_END();
+  return TVMObjectFree(mod);
 }
 
 int TVMBackendGetFuncFromEnv(void* mod_node,
index 55d9e64..e153564 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -69,7 +69,7 @@ class CUDAModuleNode : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+      const ObjectPtr<Object>& sptr_to_self) final;
 
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
@@ -166,7 +166,7 @@ class CUDAWrappedFunc {
  public:
   // initialize the CUDA function.
   void Init(CUDAModuleNode* m,
-            std::shared_ptr<ModuleNode> sptr,
+            ObjectPtr<Object> sptr,
             const std::string& func_name,
             size_t num_void_args,
             const std::vector<std::string>& thread_axis_tags) {
@@ -220,7 +220,7 @@ class CUDAWrappedFunc {
   // internal module
   CUDAModuleNode* m_;
   // the resource holder
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // The name of the function.
   std::string func_name_;
   // Device function cache per device.
@@ -233,7 +233,7 @@ class CUDAWrappedFunc {
 class CUDAPrepGlobalBarrier {
  public:
   CUDAPrepGlobalBarrier(CUDAModuleNode* m,
-                        std::shared_ptr<ModuleNode> sptr)
+                        ObjectPtr<Object> sptr)
       : m_(m), sptr_(sptr) {
     std::fill(pcache_.begin(), pcache_.end(), 0);
   }
@@ -252,14 +252,14 @@ class CUDAPrepGlobalBarrier {
   // internal module
   CUDAModuleNode* m_;
   // the resource holder
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // mark as mutable, to enable lazy initialization
   mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
 };
 
 PackedFunc CUDAModuleNode::GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) {
+      const ObjectPtr<Object>& sptr_to_self) {
   CHECK_EQ(sptr_to_self.get(), this);
   CHECK_NE(name, symbol::tvm_module_main)
       << "Device function do not have main";
@@ -279,8 +279,7 @@ Module CUDAModuleCreate(
     std::string fmt,
     std::unordered_map<std::string, FunctionInfo> fmap,
     std::string cuda_source) {
-  std::shared_ptr<CUDAModuleNode> n =
-      std::make_shared<CUDAModuleNode>(data, fmt, fmap, cuda_source);
+  auto n = make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source);
   return Module(n);
 }
 
index 54ff38d..bce0d63 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file cuda_module.h
  * \brief Execution handling of CUDA kernels
  */
index 4f69f26..abbbe12 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file dso_dll_module.cc
  * \brief Module to load from dynamic shared library.
  */
 #include <tvm/runtime/module.h>
+#include <tvm/runtime/memory.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/packed_func.h>
 #include "module_util.h"
@@ -50,7 +50,7 @@ class DSOModuleNode final : public ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     BackendPackedCFunc faddr;
     if (name == runtime::symbol::tvm_module_main) {
       const char* entry_name = reinterpret_cast<const char*>(
@@ -124,7 +124,7 @@ class DSOModuleNode final : public ModuleNode {
 
 TVM_REGISTER_GLOBAL("module.loadfile_so")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
+    auto n = make_object<DSOModuleNode>();
     n->Init(args[0]);
     *rv = runtime::Module(n);
   });
index 2b26ae5..ab28cb6 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file graph_runtime_debug.cc
  */
 #include <tvm/runtime/packed_func.h>
@@ -28,6 +27,7 @@
 #include <chrono>
 #include <sstream>
 #include "../graph_runtime.h"
+#include "../../object_internal.h"
 
 namespace tvm {
 namespace runtime {
@@ -121,7 +121,7 @@ class GraphRuntimeDebug : public GraphRuntime {
    * \param sptr_to_self Packed function pointer.
    */
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self);
+                         const ObjectPtr<Object>& sptr_to_self);
 
   /*!
    * \brief Get the node index given the name of node.
@@ -169,7 +169,7 @@ void DebugGetNodeOutput(int index, DLTensor* data_out) {
  */
 PackedFunc GraphRuntimeDebug::GetFunction(
     const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   // return member functions during query.
   if (name == "get_output_by_layer") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -207,7 +207,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(
 Module GraphRuntimeDebugCreate(const std::string& sym_json,
                                const tvm::runtime::Module& m,
                                const std::vector<TVMContext>& ctxs) {
-  std::shared_ptr<GraphRuntimeDebug> exec = std::make_shared<GraphRuntimeDebug>();
+  auto exec = make_object<GraphRuntimeDebug>();
   exec->Init(sym_json, m, ctxs);
   return Module(exec);
 }
@@ -222,15 +222,16 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
   });
 
 TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
-  .set_body([](TVMArgs args, TVMRetValue* rv) {
+.set_body([](TVMArgs args, TVMRetValue* rv) {
     CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
                                   "graph_runtime.remote_create is "
                                   "at least 4, but it has "
                                << args.num_args;
     void* mhandle = args[1];
+    ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
     const auto& contexts = GetAllContext(args);
     *rv = GraphRuntimeDebugCreate(
-        args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
+        args[0], GetRef<Module>(mnode), contexts);
   });
 
 }  // namespace runtime
index 38016ab..9ad10c1 100644 (file)
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file graph_runtime.cc
  */
-#include "graph_runtime.h"
-
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
@@ -38,6 +35,9 @@
 #include <utility>
 #include <vector>
 
+#include "graph_runtime.h"
+#include "../object_internal.h"
+
 namespace tvm {
 namespace runtime {
 namespace details {
@@ -411,7 +411,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
 
 PackedFunc GraphRuntime::GetFunction(
     const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   // Return member functions during query.
   if (name == "set_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -478,7 +478,7 @@ PackedFunc GraphRuntime::GetFunction(
 Module GraphRuntimeCreate(const std::string& sym_json,
                           const tvm::runtime::Module& m,
                           const std::vector<TVMContext>& ctxs) {
-  std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>();
+  auto exec = make_object<GraphRuntime>();
   exec->Init(sym_json, m, ctxs);
   return Module(exec);
 }
@@ -513,15 +513,17 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
   });
 
 TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
-  .set_body([](TVMArgs args, TVMRetValue* rv) {
+.set_body([](TVMArgs args, TVMRetValue* rv) {
     CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
                                   "graph_runtime.remote_create is "
                                   "at least 4, but it has "
                                << args.num_args;
     void* mhandle = args[1];
+    ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
+
     const auto& contexts = GetAllContext(args);
     *rv = GraphRuntimeCreate(
-        args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
+        args[0], GetRef<Module>(mnode), contexts);
   });
 }  // namespace runtime
 }  // namespace tvm
index e8097a8..c83d68e 100644 (file)
@@ -18,8 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
- *
  * \brief Tiny graph runtime that can run graph
  *        containing only tvm PackedFunc.
  * \file graph_runtime.h
@@ -83,7 +81,7 @@ class GraphRuntime : public ModuleNode {
    * \return The corresponding member function.
    */
   virtual PackedFunc GetFunction(const std::string& name,
-                                 const std::shared_ptr<ModuleNode>& sptr_to_self);
+                                 const ObjectPtr<Object>& sptr_to_self);
 
   /*!
    * \return The type key of the executor.
index af809d7..d9b23fc 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file metal_module.cc
  */
 #include <dmlc/memory_io.h>
@@ -54,7 +53,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+      const ObjectPtr<Object>& sptr_to_self) final;
 
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final {
@@ -187,7 +186,7 @@ class MetalWrappedFunc {
  public:
   // initialize the METAL function.
   void Init(MetalModuleNode* m,
-            std::shared_ptr<ModuleNode> sptr,
+            ObjectPtr<Object> sptr,
             const std::string& func_name,
             size_t num_buffer_args,
             size_t num_pack_args,
@@ -244,7 +243,7 @@ class MetalWrappedFunc {
   // internal module
   MetalModuleNode* m_;
   // the resource holder
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // The name of the function.
   std::string func_name_;
   // Number of buffer arguments
@@ -260,7 +259,7 @@ class MetalWrappedFunc {
 
 PackedFunc MetalModuleNode::GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) {
+      const ObjectPtr<Object>& sptr_to_self) {
   CHECK_EQ(sptr_to_self.get(), this);
   CHECK_NE(name, symbol::tvm_module_main)
       << "Device function do not have main";
@@ -281,8 +280,7 @@ Module MetalModuleCreate(
     std::unordered_map<std::string, FunctionInfo> fmap,
     std::string source) {
   metal::MetalWorkspace::Global()->Init();
-  std::shared_ptr<MetalModuleNode> n =
-      std::make_shared<MetalModuleNode>(data, fmt, fmap, source);
+  auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
   return Module(n);
 }
 
index 88328a2..d1df67f 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file micro_device_api.cc
  */
 
@@ -50,7 +49,7 @@ class MicroDeviceAPI final : public DeviceAPI {
                        size_t nbytes,
                        size_t alignment,
                        TVMType type_hint) final {
-    std::shared_ptr<MicroSession>& session = MicroSession::Current();
+    ObjectPtr<MicroSession>& session = MicroSession::Current();
     void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>();
     CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap";
     MicroDevSpace* dev_space = new MicroDevSpace();
@@ -82,11 +81,12 @@ class MicroDeviceAPI final : public DeviceAPI {
       MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
       MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
       CHECK(from_space->session == to_space->session)
-          << "attempt to copy data between different micro sessions (" << from_space->session
-          << " != " << to_space->session << ")";
+          << "attempt to copy data between different micro sessions ("
+          << from_space->session.get()
+          << " != " << to_space->session.get() << ")";
       CHECK(ctx_from.device_id == ctx_to.device_id)
         << "can only copy between the same micro device";
-      std::shared_ptr<MicroSession>& session = from_space->session;
+      ObjectPtr<MicroSession>& session = from_space->session;
       const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
 
       DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset);
@@ -99,7 +99,7 @@ class MicroDeviceAPI final : public DeviceAPI {
       // Reading from the device.
 
       MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
-      std::shared_ptr<MicroSession>& session = from_space->session;
+      ObjectPtr<MicroSession>& session = from_space->session;
       const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
 
       DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset);
@@ -109,7 +109,7 @@ class MicroDeviceAPI final : public DeviceAPI {
       // Writing to the device.
 
       MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
-      std::shared_ptr<MicroSession>& session = to_space->session;
+      ObjectPtr<MicroSession>& session = to_space->session;
       const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
 
       void* from_host_ptr = GetHostLoc(from, from_offset);
@@ -124,7 +124,7 @@ class MicroDeviceAPI final : public DeviceAPI {
   }
 
   void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
-    std::shared_ptr<MicroSession>& session = MicroSession::Current();
+    ObjectPtr<MicroSession>& session = MicroSession::Current();
 
     void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>();
     CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace";
@@ -136,7 +136,7 @@ class MicroDeviceAPI final : public DeviceAPI {
 
   void FreeWorkspace(TVMContext ctx, void* data) final {
     MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(data);
-    std::shared_ptr<MicroSession>& session = dev_space->session;
+    ObjectPtr<MicroSession>& session = dev_space->session;
     session->FreeInSection(SectionKind::kWorkspace,
                            DevBaseOffset(reinterpret_cast<std::uintptr_t>(dev_space->data)));
     delete dev_space;
index 85cd359..e66c45b 100644 (file)
@@ -18,9 +18,8 @@
  */
 
 /*!
-*  Copyright (c) 2019 by Contributors
-* \file micro_module.cc
-*/
+ * \file micro_module.cc
+ */
 
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/c_runtime_api.h>
@@ -48,7 +47,7 @@ class MicroModuleNode final : public ModuleNode {
   }
 
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+                         const ObjectPtr<Object>& sptr_to_self) final;
 
   /*!
    * \brief initializes module by establishing device connection and loads binary
@@ -76,13 +75,13 @@ class MicroModuleNode final : public ModuleNode {
   /*! \brief path to module binary */
   std::string binary_path_;
   /*! \brief global session pointer */
-  std::shared_ptr<MicroSession> session_;
+  ObjectPtr<MicroSession> session_;
 };
 
 class MicroWrappedFunc {
  public:
   MicroWrappedFunc(MicroModuleNode* m,
-                   std::shared_ptr<MicroSession> session,
+                   ObjectPtr<MicroSession> session,
                    const std::string& func_name,
                    DevBaseOffset func_offset) {
     m_ = m;
@@ -99,7 +98,7 @@ class MicroWrappedFunc {
   /*! \brief internal module */
   MicroModuleNode* m_;
   /*! \brief reference to the session for this function (to keep the session alive) */
-  std::shared_ptr<MicroSession> session_;
+  ObjectPtr<MicroSession> session_;
   /*! \brief name of the function */
   std::string func_name_;
   /*! \brief offset of the function to be called */
@@ -108,7 +107,7 @@ class MicroWrappedFunc {
 
 PackedFunc MicroModuleNode::GetFunction(
     const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   DevBaseOffset func_offset =
       session_->low_level_device()->ToDevOffset(binary_info_.symbol_map[name]);
   MicroWrappedFunc f(this, session_, name, func_offset);
@@ -118,9 +117,9 @@ PackedFunc MicroModuleNode::GetFunction(
 // register loadfile function to load module from Python frontend
 TVM_REGISTER_GLOBAL("module.loadfile_micro_dev")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
-    std::shared_ptr<MicroModuleNode> n = std::make_shared<MicroModuleNode>();
+    auto n = make_object<MicroModuleNode>();
     n->InitMicroModule(args[0]);
     *rv = runtime::Module(n);
-    });
+  });
 }  // namespace runtime
 }  // namespace tvm
index 9790154..febf726 100644 (file)
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file micro_session.cc
  */
 
 #include <dmlc/thread_local.h>
 #include <tvm/runtime/registry.h>
-#include <memory>
 #include <stack>
 #include <tuple>
 #include <vector>
@@ -36,18 +34,18 @@ namespace tvm {
 namespace runtime {
 
 struct TVMMicroSessionThreadLocalEntry {
-  std::stack<std::shared_ptr<MicroSession>> session_stack;
+  std::stack<ObjectPtr<MicroSession>> session_stack;
 };
 
 typedef dmlc::ThreadLocalStore<TVMMicroSessionThreadLocalEntry> TVMMicroSessionThreadLocalStore;
 
-std::shared_ptr<MicroSession>& MicroSession::Current() {
+ObjectPtr<MicroSession>& MicroSession::Current() {
   TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
   CHECK_GT(entry->session_stack.size(), 0) << "No current session";
   return entry->session_stack.top();
 }
 
-void MicroSession::EnterWithScope(std::shared_ptr<MicroSession> session) {
+void MicroSession::EnterWithScope(ObjectPtr<MicroSession> session) {
   TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
   entry->session_stack.push(session);
 }
@@ -121,7 +119,7 @@ void MicroSession::CreateSession(const std::string& device_type,
 void MicroSession::PushToExecQueue(DevBaseOffset func, const TVMArgs& args) {
   int32_t (*func_dev_addr)(void*, void*, int32_t) =
       reinterpret_cast<int32_t (*)(void*, void*, int32_t)>(
-      low_level_device()->ToDevPtr(func).value());
+          low_level_device()->ToDevPtr(func).value());
 
   // Create an allocator stream for the memory region after the most recent
   // allocation in the args section.
@@ -355,10 +353,10 @@ void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map,
 
 PackedFunc MicroSession::GetFunction(
     const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   if (name == "enter") {
-    return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-      MicroSession::EnterWithScope(std::dynamic_pointer_cast<MicroSession>(sptr_to_self));
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      MicroSession::EnterWithScope(GetObjectPtr<MicroSession>(this));
     });
   } else if (name == "exit") {
     return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
@@ -378,7 +376,7 @@ TVM_REGISTER_GLOBAL("micro._CreateSession")
     uint64_t base_addr = args[3];
     const std::string& server_addr = args[4];
     int port = args[5];
-    std::shared_ptr<MicroSession> session = std::make_shared<MicroSession>();
+    ObjectPtr<MicroSession> session = make_object<MicroSession>();
     session->CreateSession(
         device_type, binary_path, toolchain_prefix, base_addr, server_addr, port);
     *rv = Module(session);
index 1400f74..65b6421 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file micro_session.h
  * \brief session to manage multiple micro modules
  *
@@ -66,7 +65,7 @@ class MicroSession : public ModuleNode {
    * \return The corresponding member function.
    */
   virtual PackedFunc GetFunction(const std::string& name,
-                                 const std::shared_ptr<ModuleNode>& sptr_to_self);
+                                 const ObjectPtr<Object>& sptr_to_self);
 
   /*!
    * \return The type key of the executor.
@@ -85,7 +84,7 @@ class MicroSession : public ModuleNode {
    */
   ~MicroSession();
 
-  static std::shared_ptr<MicroSession>& Current();
+  static ObjectPtr<MicroSession>& Current();
 
   /*!
    * \brief creates session by setting up a low-level device and initting allocators for it
@@ -240,7 +239,7 @@ class MicroSession : public ModuleNode {
     * \brief Push a new session context onto the thread-local stack.
     *  The session on top of the stack is used as the current global session.
     */
-  static void EnterWithScope(std::shared_ptr<MicroSession> session);
+  static void EnterWithScope(ObjectPtr<MicroSession> session);
   /*!
     * \brief Pop a session off the thread-local context stack,
     *  restoring the previous session as the current context.
@@ -258,7 +257,7 @@ struct MicroDevSpace {
   /*! \brief data being wrapped */
   void* data;
   /*! \brief shared ptr to session where this data is valid */
-  std::shared_ptr<MicroSession> session;
+  ObjectPtr<MicroSession> session;
 };
 
 }  // namespace runtime
index 80ce185..2123312 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file tcl_socket.h
  * \brief TCP socket wrapper for communicating using Tcl commands
  */
index c0acb31..161675c 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file module.cc
  * \brief TVM module system
  */
 namespace tvm {
 namespace runtime {
 
-void Module::Import(Module other) {
+void ModuleNode::Import(Module other) {
   // specially handle rpc
-  if (!std::strcmp((*this)->type_key(), "rpc")) {
+  if (!std::strcmp(this->type_key(), "rpc")) {
     static const PackedFunc* fimport_ = nullptr;
     if (fimport_ == nullptr) {
       fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
       CHECK(fimport_ != nullptr);
     }
-    (*fimport_)(*this, other);
+    (*fimport_)(GetRef<Module>(this), other);
     return;
   }
   // cyclic detection.
-  std::unordered_set<const ModuleNode*> visited{other.node_.get()};
-  std::vector<const ModuleNode*> stack{other.node_.get()};
+  std::unordered_set<const ModuleNode*> visited{other.operator->()};
+  std::vector<const ModuleNode*> stack{other.operator->()};
   while (!stack.empty()) {
     const ModuleNode* n = stack.back();
     stack.pop_back();
     for (const Module& m : n->imports_) {
-      const ModuleNode* next = m.node_.get();
+      const ModuleNode* next = m.operator->();
       if (visited.count(next)) continue;
       visited.insert(next);
       stack.push_back(next);
     }
   }
-  CHECK(!visited.count(node_.get()))
+  CHECK(!visited.count(this))
       << "Cyclic dependency detected during import";
-  node_->imports_.emplace_back(std::move(other));
+  this->imports_.emplace_back(std::move(other));
+}
+
+PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
+  ModuleNode* self = this;
+  PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
+  if (pf != nullptr) return pf;
+  if (query_imports) {
+    for (Module& m : self->imports_) {
+      pf = m->GetFunction(name, m.data_);
+      if (pf != nullptr) return pf;
+    }
+  }
+  return pf;
 }
 
 Module Module::LoadFromFile(const std::string& file_name,
index 456d282..445bfd3 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file module_util.cc
  * \brief Utilities for module.
  */
@@ -64,7 +63,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
 }
 
 PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
-                          const std::shared_ptr<ModuleNode>& sptr_to_self) {
+                          const ObjectPtr<Object>& sptr_to_self) {
   return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
       int ret = (*faddr)(
           const_cast<TVMValue*>(args.values),
index e5bbfe3..5f56c15 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file module_util.h
  * \brief Helper utilities for module building
  */
@@ -45,7 +44,7 @@ namespace runtime {
  * \param faddr The function address
  * \param mptr The module pointer node.
  */
-PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
+PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
 /*!
  * \brief Load and append module blob to module list
  * \param mblob The module blob.
index 5d71c2f..7a8aef8 100644 (file)
@@ -27,6 +27,7 @@
 #include <vector>
 #include <utility>
 #include <unordered_map>
+#include "object_internal.h"
 #include "runtime_base.h"
 
 namespace tvm {
@@ -200,18 +201,6 @@ uint32_t Object::TypeKey2Index(const std::string& key) {
   return TypeContext::Global()->TypeKey2Index(key);
 }
 
-class TVMObjectCAPI {
- public:
-  static void Free(TVMObjectHandle obj) {
-    if (obj != nullptr) {
-      static_cast<Object*>(obj)->DecRef();
-    }
-  }
-
-  static uint32_t TypeKey2Index(const std::string& type_key) {
-    return Object::TypeKey2Index(type_key);
-  }
-};
 }  // namespace runtime
 }  // namespace tvm
 
@@ -224,13 +213,13 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
 
 int TVMObjectFree(TVMObjectHandle obj) {
   API_BEGIN();
-  tvm::runtime::TVMObjectCAPI::Free(obj);
+  tvm::runtime::ObjectInternal::ObjectFree(obj);
   API_END();
 }
 
 int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
   API_BEGIN();
-  out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index(
+  out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
       type_key);
   API_END();
 }
diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h
new file mode 100644 (file)
index 0000000..7955130
--- /dev/null
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/runtime/object_internal.h
+ * \brief Expose a few functions for CFFI purposes.
+ *        This file is not intended to be used
+ */
+#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_
+#define TVM_RUNTIME_OBJECT_INTERNAL_H_
+
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/module.h>
+#include <string>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Internal object namespace to expose
+ *        certain util functions for FFI.
+ */
+class ObjectInternal {
+ public:
+  /*!
+   * \brief Free an object handle.
+   */
+  static void ObjectFree(TVMObjectHandle obj) {
+    if (obj != nullptr) {
+      static_cast<Object*>(obj)->DecRef();
+    }
+  }
+  /*!
+   * \brief Expose TypeKey2Index
+   * \param type_key The original type key.
+   * \return the corresponding index.
+   */
+  static uint32_t ObjectTypeKey2Index(const std::string& type_key) {
+    return Object::TypeKey2Index(type_key);
+  }
+  /*!
+   * \brief Convert ModuleHandle to module node pointer.
+   * \param handle The module handle.
+   * \return the corresponding module node pointer.
+   */
+  static ModuleNode* GetModuleNode(TVMModuleHandle handle) {
+    // NOTE: we will need to convert to Object
+    // then to ModuleNode in order to get the correct
+    // address translation
+    return static_cast<ModuleNode*>(static_cast<Object*>(handle));
+  }
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif   // TVM_RUNTIME_OBJECT_INTERNAL_H_
index 48a6b8e..d9251f8 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file aocl_common.h
  * \brief AOCL common header
  */
index 2442c4d..84c29ee 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file aocl_device_api.cc
  */
 #include <tvm/runtime/registry.h>
index 2e8322f..70955cc 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
  * \file aocl_module.h
  * \brief Execution handling of OpenCL kernels for AOCL
  */
index ab84eef..bd93473 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -278,7 +278,7 @@ class OpenCLModuleNode : public ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+      const ObjectPtr<Object>& sptr_to_self) final;
   void SaveToFile(const std::string& file_name,
                   const std::string& format) final;
   void SaveToBinary(dmlc::Stream* stream) final;
index 971ae34..24687db 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file opencl_module.cc
  */
 #include <dmlc/memory_io.h>
@@ -36,7 +35,7 @@ class OpenCLWrappedFunc {
  public:
   // initialize the OpenCL function.
   void Init(OpenCLModuleNode* m,
-            std::shared_ptr<ModuleNode> sptr,
+            ObjectPtr<Object> sptr,
             OpenCLModuleNode::KTRefEntry entry,
             std::string func_name,
             std::vector<size_t> arg_size,
@@ -88,7 +87,7 @@ class OpenCLWrappedFunc {
   // The module
   OpenCLModuleNode* m_;
   // resource handle
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // global kernel id in the kernel table.
   OpenCLModuleNode::KTRefEntry entry_;
   // The name of the function.
@@ -122,7 +121,7 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace
 
 PackedFunc OpenCLModuleNode::GetFunction(
     const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   CHECK_EQ(sptr_to_self.get(), this);
   CHECK_NE(name, symbol::tvm_module_main)
       << "Device function do not have main";
@@ -251,8 +250,7 @@ Module OpenCLModuleCreate(
     std::string fmt,
     std::unordered_map<std::string, FunctionInfo> fmap,
     std::string source) {
-  std::shared_ptr<OpenCLModuleNode> n =
-      std::make_shared<OpenCLModuleNode>(data, fmt, fmap, source);
+  auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
   n->Init();
   return Module(n);
 }
index cd63382..3b7ebb9 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file opencl_module.h
  * \brief Execution handling of OPENCL kernels
  */
index 9a1f774..0d3f953 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -44,7 +44,7 @@ class OpenGLModuleNode final : public ModuleNode {
   const char* type_key() const final { return "opengl"; }
 
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+                         const ObjectPtr<Object>& sptr_to_self) final;
 
   std::string GetSource(const std::string& format) final;
 
@@ -74,7 +74,7 @@ class OpenGLModuleNode final : public ModuleNode {
 class OpenGLWrappedFunc {
  public:
   OpenGLWrappedFunc(OpenGLModuleNode* m,
-                    std::shared_ptr<ModuleNode> sptr,
+                    ObjectPtr<Object> sptr,
                     std::string func_name,
                     std::vector<size_t> arg_size,
                     const std::vector<std::string>& thread_axis_tags);
@@ -85,7 +85,7 @@ class OpenGLWrappedFunc {
   // The module
   OpenGLModuleNode* m_;
   // resource handle
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // The name of the function.
   std::string func_name_;
   // convert code for void argument
@@ -111,7 +111,7 @@ OpenGLModuleNode::OpenGLModuleNode(
 
 PackedFunc OpenGLModuleNode::GetFunction(
     const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   CHECK_EQ(sptr_to_self.get(), this);
   CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
 
@@ -191,7 +191,7 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo(
 
 OpenGLWrappedFunc::OpenGLWrappedFunc(
     OpenGLModuleNode* m,
-    std::shared_ptr<ModuleNode> sptr,
+    ObjectPtr<Object> sptr,
     std::string func_name,
     std::vector<size_t> arg_size,
     const std::vector<std::string>& thread_axis_tags)
@@ -251,9 +251,9 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
 Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
                           std::string fmt,
                           std::unordered_map<std::string, FunctionInfo> fmap) {
-  auto n = std::make_shared<OpenGLModuleNode>(std::move(shaders),
-                                              std::move(fmt),
-                                              std::move(fmap));
+  auto n = make_object<OpenGLModuleNode>(std::move(shaders),
+                                         std::move(fmt),
+                                         std::move(fmap));
   return Module(n);
 }
 
index b4459ae..f1b712e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
index 96d1948..c2bea8a 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file rocm_module.cc
  */
 #include <tvm/runtime/registry.h>
@@ -68,7 +67,7 @@ class ROCMModuleNode : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+      const ObjectPtr<Object>& sptr_to_self) final;
 
 
   void SaveToFile(const std::string& file_name,
@@ -158,7 +157,7 @@ class ROCMWrappedFunc {
  public:
   // initialize the ROCM function.
   void Init(ROCMModuleNode* m,
-            std::shared_ptr<ModuleNode> sptr,
+            ObjectPtr<Object> sptr,
             const std::string& func_name,
             size_t num_void_args,
             const std::vector<std::string>& thread_axis_tags) {
@@ -204,7 +203,7 @@ class ROCMWrappedFunc {
   // internal module
   ROCMModuleNode* m_;
   // the resource holder
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // The name of the function.
   std::string func_name_;
   // Device function cache per device.
@@ -217,7 +216,7 @@ class ROCMWrappedFunc {
 
 PackedFunc ROCMModuleNode::GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) {
+      const ObjectPtr<Object>& sptr_to_self) {
   CHECK_EQ(sptr_to_self.get(), this);
   CHECK_NE(name, symbol::tvm_module_main)
       << "Device function do not have main";
@@ -235,8 +234,7 @@ Module ROCMModuleCreate(
     std::unordered_map<std::string, FunctionInfo> fmap,
     std::string hip_source,
     std::string assembly) {
-  std::shared_ptr<ROCMModuleNode> n =
-    std::make_shared<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
+  auto n = make_object<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
   return Module(n);
 }
 
index 2e1e64a..8c4486e 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -123,7 +123,7 @@ class RPCModuleNode final : public ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     RPCFuncHandle handle = GetFuncHandle(name);
     return WrapRemote(handle);
   }
@@ -195,8 +195,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
         return wf->operator()(args, rv);
       });
   } else if (tcode == kModuleHandle) {
-    std::shared_ptr<RPCModuleNode> n =
-        std::make_shared<RPCModuleNode>(handle, sess);
+    auto n = make_object<RPCModuleNode>(handle, sess);
     *rv = Module(n);
   } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
     CHECK_EQ(args.size(), 2);
@@ -209,8 +208,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
 }
 
 Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
-  std::shared_ptr<RPCModuleNode> n =
-      std::make_shared<RPCModuleNode>(nullptr, sess);
+  auto n = make_object<RPCModuleNode>(nullptr, sess);
   return Module(n);
 }
 
@@ -237,8 +235,7 @@ TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
     CHECK_EQ(tkey, "rpc");
     auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
     void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
-    std::shared_ptr<RPCModuleNode> n =
-        std::make_shared<RPCModuleNode>(mhandle, sess);
+    auto n = make_object<RPCModuleNode>(mhandle, sess);
     *rv = Module(n);
   });
 
index 39db150..b5fec10 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -35,6 +35,7 @@
 #include <cmath>
 #include <algorithm>
 #include "rpc_session.h"
+#include "../object_internal.h"
 #include "../../common/ring_buffer.h"
 #include "../../common/socket.h"
 
@@ -1119,25 +1120,29 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
   }
   std::string file_name = args[0];
   TVMRetValue ret = (*fsys_load_)(file_name);
-  Module m = ret;
-  *rv = static_cast<void*>(new Module(m));
+  // pass via void*
+  TVMValue value;
+  int rcode;
+  ret.MoveToCHost(&value, &rcode);
+  CHECK_EQ(rcode, kModuleHandle);
+  *rv = static_cast<void*>(value.v_handle);
 }
 
 void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
   void* pmod = args[0];
   void* cmod = args[1];
-  static_cast<Module*>(pmod)->Import(
-      *static_cast<Module*>(cmod));
+  ObjectInternal::GetModuleNode(pmod)->Import(
+      GetRef<Module>(ObjectInternal::GetModuleNode(cmod)));
 }
 
 void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
   void* mhandle = args[0];
-  delete static_cast<Module*>(mhandle);
+  ObjectInternal::ObjectFree(mhandle);
 }
 
 void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
   void* mhandle = args[0];
-  PackedFunc pf = static_cast<Module*>(mhandle)->GetFunction(
+  PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction(
       args[1], false);
   if (pf != nullptr) {
     *rv = static_cast<void*>(new PackedFunc(pf));
@@ -1149,7 +1154,7 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
 void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
   void* mhandle = args[0];
   std::string fmt = args[1];
-  *rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
+  *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt);
 }
 
 void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
index 3518455..ab5f16d 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file rpc_session.h
  * \brief Base RPC session interface.
  */
index fe6913e..07014a6 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * Implementation stack VM.
  * \file stackvm.cc
  */
index 4e7d422..4f86d07 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file stackvm_module.cc
  */
 #include <tvm/runtime/registry.h>
@@ -42,7 +41,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     if (name == runtime::symbol::tvm_module_main) {
       return GetFunction(entry_func_, sptr_to_self);
     }
@@ -89,8 +88,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
 
   static Module Create(std::unordered_map<std::string, StackVM> fmap,
                        std::string entry_func) {
-    std::shared_ptr<StackVMModuleNode> n =
-        std::make_shared<StackVMModuleNode>();
+    auto n = make_object<StackVMModuleNode>();
     n->fmap_ = std::move(fmap);
     n->entry_func_ = std::move(entry_func);
     return Module(n);
@@ -101,8 +99,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
     std::string entry_func, data;
     strm->Read(&fmap);
     strm->Read(&entry_func);
-    std::shared_ptr<StackVMModuleNode> n =
-        std::make_shared<StackVMModuleNode>();
+    auto n = make_object<StackVMModuleNode>();
     n->fmap_ = std::move(fmap);
     n->entry_func_ = std::move(entry_func);
     uint64_t num_imports;
index 247fae1..8a75a36 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  */
 
 /*!
- *  Copyright (c) 2017 by Contributors
  * \file system_lib_module.cc
  * \brief SystemLib module.
  */
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/memory.h>
 #include <tvm/runtime/c_backend_api.h>
 #include <mutex>
 #include "module_util.h"
@@ -40,7 +40,7 @@ class SystemLibModuleNode : public ModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     std::lock_guard<std::mutex> lock(mutex_);
 
     if (module_blob_ != nullptr) {
@@ -83,9 +83,8 @@ class SystemLibModuleNode : public ModuleNode {
     }
   }
 
-  static const std::shared_ptr<SystemLibModuleNode>& Global() {
-    static std::shared_ptr<SystemLibModuleNode> inst =
-        std::make_shared<SystemLibModuleNode>();
+  static const ObjectPtr<SystemLibModuleNode>& Global() {
+    static auto inst = make_object<SystemLibModuleNode>();
     return inst;
   }
 
index 4c4554c..2aeecc5 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file tvm/runtime/vm/executable.cc
  * \brief The implementation of a virtual machine executable APIs.
  */
@@ -51,7 +50,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr);
 Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
 
 PackedFunc Executable::GetFunction(const std::string& name,
-    const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const ObjectPtr<Object>& sptr_to_self) {
   if (name == "get_lib") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       *rv = this->GetLib();
@@ -440,7 +439,7 @@ void LoadHeader(dmlc::Stream* strm) {
 }
 
 runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
-  std::shared_ptr<Executable> exec = std::make_shared<Executable>();
+  auto exec = make_object<Executable>();
   exec->lib = lib;
   exec->code_ = code;
   dmlc::MemoryStringStream strm(&exec->code_);
index 821de0b..ed6cddb 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file src/runtime/vm/profiler/vm.cc
  * \brief The Relay debug virtual machine.
  */
@@ -41,7 +40,7 @@ namespace runtime {
 namespace vm {
 
 PackedFunc VirtualMachineDebug::GetFunction(
-    const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) {
+    const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
   if (name == "get_stat") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       double total_duration = 0.0;
@@ -124,7 +123,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
 }
 
 runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
-  std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
+  auto vm = make_object<VirtualMachineDebug>();
   vm->LoadExecutable(exec);
   return runtime::Module(vm);
 }
index ff3296c..2e95a07 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file src/runtime/vm/profiler/vm.h
  * \brief The Relay debug virtual machine.
  */
@@ -42,7 +41,7 @@ class VirtualMachineDebug : public VirtualMachine {
   VirtualMachineDebug() : VirtualMachine() {}
 
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self) final;
+                         const ObjectPtr<Object>& sptr_to_self) final;
 
   void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
                     Index output_size, const std::vector<ObjectRef>& args) final;
index 05935b7..463c575 100644 (file)
@@ -627,7 +627,7 @@ ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
 }
 
 PackedFunc VirtualMachine::GetFunction(const std::string& name,
-                                       const std::shared_ptr<ModuleNode>& sptr_to_self) {
+                                       const ObjectPtr<Object>& sptr_to_self) {
   if (name == "invoke") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       CHECK(exec) << "The executable is not created yet.";
@@ -1052,7 +1052,7 @@ void VirtualMachine::RunLoop() {
 }
 
 runtime::Module CreateVirtualMachine(const Executable* exec) {
-  std::shared_ptr<VirtualMachine> vm = std::make_shared<VirtualMachine>();
+  auto vm = make_object<VirtualMachine>();
   vm->LoadExecutable(exec);
   return runtime::Module(vm);
 }
index e3b2ac8..daf4ae7 100644 (file)
@@ -663,7 +663,9 @@ class VulkanModuleNode;
 // a wrapped function class to get packed func.
 class VulkanWrappedFunc {
  public:
-  void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name,
+  void Init(VulkanModuleNode* m,
+            ObjectPtr<Object> sptr,
+            const std::string& func_name,
             size_t num_buffer_args, size_t num_pack_args,
             const std::vector<std::string>& thread_axis_tags) {
     m_ = m;
@@ -680,7 +682,7 @@ class VulkanWrappedFunc {
   // internal module
   VulkanModuleNode* m_;
   // the resource holder
-  std::shared_ptr<ModuleNode> sptr_;
+  ObjectPtr<Object> sptr_;
   // v The name of the function.
   std::string func_name_;
   // Number of buffer arguments
@@ -705,7 +707,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
   const char* type_key() const final { return "vulkan"; }
 
   PackedFunc GetFunction(const std::string& name,
-                         const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+                         const ObjectPtr<Object>& sptr_to_self) final {
     CHECK_EQ(sptr_to_self.get(), this);
     CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
     auto it = fmap_.find(name);
@@ -939,7 +941,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
 
 Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
                           std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
-  std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source);
+  auto n = make_object<VulkanModuleNode>(smap, fmap, source);
   return Module(n);
 }
 
index 6ef6af8..27161c4 100644 (file)
@@ -226,7 +226,7 @@ class DPIModule final : public DPIModuleNode {
 
   PackedFunc GetFunction(
       const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+      const ObjectPtr<Object>& sptr_to_self) final {
     if (name == "WriteReg") {
       return TypedPackedFunc<void(int, int)>(
           [this](int addr, int value){
@@ -413,8 +413,7 @@ class DPIModule final : public DPIModuleNode {
 };
 
 Module DPIModuleNode::Load(std::string dll_name) {
-  std::shared_ptr<DPIModule> n =
-      std::make_shared<DPIModule>();
+  auto n = make_object<DPIModule>();
   n->Init(dll_name);
   return Module(n);
 }
index 12bc53c..91f10c0 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -31,6 +31,7 @@
 #include "../src/runtime/system_lib_module.cc"
 #include "../src/runtime/module.cc"
 #include "../src/runtime/ndarray.cc"
+#include "../src/runtime/object.cc"
 #include "../src/runtime/registry.cc"
 #include "../src/runtime/file_util.cc"
 #include "../src/runtime/dso_module.cc"