[RUNTIME] Improve PackedFunc robustness (#5517)
authorTianqi Chen <tqchen@users.noreply.github.com>
Wed, 6 May 2020 18:50:09 +0000 (11:50 -0700)
committerGitHub <noreply@github.com>
Wed, 6 May 2020 18:50:09 +0000 (11:50 -0700)
* [RUNTIME] Improve PackedFunc robustness

- Add static assert to warn about unsupported type deduction.
- Always inline template expansions for PackedFunc calls.

* Fix style issue

include/tvm/runtime/packed_func.h
src/ir/op.cc
src/relay/quantize/quantize.cc

index cf6d5fa..0726292 100644 (file)
 #define TVM_RUNTIME_HEADER_ONLY 0
 #endif
 
+// Always inline macro only use in template
+// expansion cases where we know inline is important.
+#ifdef _MSC_VER
+#define TVM_ALWAYS_INLINE __forceinline inline
+#else
+#define TVM_ALWAYS_INLINE inline __attribute__((always_inline))
+#endif
+
 namespace tvm {
 namespace runtime {
 
@@ -273,7 +281,7 @@ class TypedPackedFunc<R(Args...)> {
    * \param args The arguments
    * \returns The return value.
    */
-  inline R operator()(Args ...args) const;
+  TVM_ALWAYS_INLINE R operator()(Args ...args) const;
   /*!
    * \brief convert to PackedFunc
    * \return the internal PackedFunc
@@ -1076,11 +1084,15 @@ struct func_signature_helper {
 template<typename T, typename R, typename ...Args>
 struct func_signature_helper<R (T::*)(Args...)> {
   using FType = R(Args...);
+  static_assert(!std::is_reference<R>::value,
+                "TypedPackedFunc return reference");
 };
 
 template<typename T, typename R, typename ...Args>
 struct func_signature_helper<R (T::*)(Args...) const> {
   using FType = R(Args...);
+  static_assert(!std::is_reference<R>::value,
+                "TypedPackedFunc return reference");
 };
 
 /*!
@@ -1096,12 +1108,16 @@ struct function_signature {
 template<typename R, typename ...Args>
 struct function_signature<R(Args...)> {
   using FType = R(Args...);
+  static_assert(!std::is_reference<R>::value,
+                "TypedPackedFunc return reference");
 };
 
 // handle case of function ptr.
 template<typename R, typename ...Args>
 struct function_signature<R (*)(Args...)> {
   using FType = R(Args...);
+  static_assert(!std::is_reference<R>::value,
+                "TypedPackedFunc return reference");
 };
 }  // namespace detail
 
@@ -1114,66 +1130,66 @@ class TVMArgsSetter {
   template<typename T,
            typename = typename std::enable_if<
              std::is_integral<T>::value>::type>
-  void operator()(size_t i, T value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, T value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
     type_codes_[i] = kDLInt;
   }
-  void operator()(size_t i, uint64_t value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
     CHECK_LE(value,
              static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
     type_codes_[i] = kDLInt;
   }
-  void operator()(size_t i, double value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, double value) const {
     values_[i].v_float64 = value;
     type_codes_[i] = kDLFloat;
   }
-  void operator()(size_t i, std::nullptr_t value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const {
     values_[i].v_handle = value;
     type_codes_[i] = kTVMNullptr;
   }
-  void operator()(size_t i, const TVMArgValue& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const {
     values_[i] = value.value_;
     type_codes_[i] = value.type_code_;
   }
-  void operator()(size_t i, void* value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const {
     values_[i].v_handle = value;
     type_codes_[i] = kTVMOpaqueHandle;
   }
-  void operator()(size_t i, DLTensor* value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const {
     values_[i].v_handle = value;
     type_codes_[i] = kTVMDLTensorHandle;
   }
-  void operator()(size_t i, TVMContext value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const {
     values_[i].v_ctx = value;
     type_codes_[i] = kTVMContext;
   }
-  void operator()(size_t i, DLDataType value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const {
     values_[i].v_type = value;
     type_codes_[i] = kTVMDataType;
   }
-  void operator()(size_t i, DataType dtype) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const {
     operator()(i, dtype.operator DLDataType());
   }
-  void operator()(size_t i, const char* value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const {
     values_[i].v_str = value;
     type_codes_[i] = kTVMStr;
   }
   // setters for container types
-  void operator()(size_t i, const std::string& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const {
     values_[i].v_str = value.c_str();
     type_codes_[i] = kTVMStr;
   }
-  void operator()(size_t i, const TVMByteArray& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const {
     values_[i].v_handle = const_cast<TVMByteArray*>(&value);
     type_codes_[i] = kTVMBytes;
   }
-  void operator()(size_t i, const PackedFunc& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const {
     values_[i].v_handle = const_cast<PackedFunc*>(&value);
     type_codes_[i] = kTVMPackedFuncHandle;
   }
   template<typename FType>
-  void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
     operator()(i, value.packed());
   }
   void operator()(size_t i, const TVMRetValue& value) const {
@@ -1191,7 +1207,7 @@ class TVMArgsSetter {
            typename = typename std::enable_if<
              std::is_base_of<ObjectRef, TObjectRef>::value>
            ::type>
-  void operator()(size_t i, const TObjectRef& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const {
     this->SetObject(i, value);
   }
 
@@ -1200,7 +1216,7 @@ class TVMArgsSetter {
              std::is_base_of<ObjectRef,
                              typename std::remove_reference<TObjectRef>::type>::value>
            ::type>
-  void operator()(size_t i, TObjectRef&& value) const {
+  TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const {
     this->SetObject(i, std::forward<TObjectRef>(value));
   }
 
@@ -1230,10 +1246,10 @@ namespace detail {
 template<typename R, int nleft, int index, typename F>
 struct unpack_call_dispatcher {
   template<typename ...Args>
-  static void run(const F& f,
-                  const TVMArgs& args_pack,
-                  TVMRetValue* rv,
-                  Args&&... unpacked_args) {
+  TVM_ALWAYS_INLINE static void run(const F& f,
+                                    const TVMArgs& args_pack,
+                                    TVMRetValue* rv,
+                                    Args&&... unpacked_args) {
     // construct a movable argument value
     // which allows potential move of argument to the input of F.
     unpack_call_dispatcher<R, nleft - 1, index + 1, F>
@@ -1247,27 +1263,33 @@ struct unpack_call_dispatcher {
 template<typename R, int index, typename F>
 struct unpack_call_dispatcher<R, 0, index, F> {
   template<typename ...Args>
-  static void run(const F& f,
-                  const TVMArgs& args_pack,
-                  TVMRetValue* rv,
-                  Args&&... unpacked_args) {
-    *rv = R(f(std::forward<Args>(unpacked_args)...));
+  TVM_ALWAYS_INLINE static void run(const F& f,
+                                    const TVMArgs& args_pack,
+                                    TVMRetValue* rv,
+                                    Args&&... unpacked_args) {
+    using RetType = decltype(f(std::forward<Args>(unpacked_args)...));
+    if (std::is_same<RetType, R>::value) {
+      *rv = f(std::forward<Args>(unpacked_args)...);
+    } else {
+      *rv = R(f(std::forward<Args>(unpacked_args)...));
+    }
   }
 };
 
 template<int index, typename F>
 struct unpack_call_dispatcher<void, 0, index, F> {
   template<typename ...Args>
-  static void run(const F& f,
-                  const TVMArgs& args_pack,
-                  TVMRetValue* rv,
-                  Args&&... unpacked_args) {
+  TVM_ALWAYS_INLINE static void run(const F& f,
+                                    const TVMArgs& args_pack,
+                                    TVMRetValue* rv,
+                                    Args&&... unpacked_args) {
     f(std::forward<Args>(unpacked_args)...);
   }
 };
 
 template<typename R, int nargs, typename F>
-inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
+TVM_ALWAYS_INLINE void unpack_call(
+    const F& f, const TVMArgs& args, TVMRetValue* rv) {
   CHECK_EQ(nargs, args.size())
       << "Expect " << nargs << " arguments but get " << args.size();
   unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
@@ -1280,22 +1302,23 @@ struct unpack_call_by_signature {
 template<typename R, typename ...Args>
 struct unpack_call_by_signature<R(Args...)> {
   template<typename F>
-  static void run(const F& f,
-                  const TVMArgs& args,
-                  TVMRetValue* rv) {
+  TVM_ALWAYS_INLINE static void run(
+      const F& f,
+      const TVMArgs& args,
+      TVMRetValue* rv) {
     unpack_call<R, sizeof...(Args)>(f, args, rv);
   }
 };
 
 template<typename R, typename ...Args>
-inline R call_packed(const PackedFunc& pf, Args&& ...args) {
+TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&& ...args) {
   return R(pf(std::forward<Args>(args)...));
 }
 
 template<typename R>
 struct typed_packed_call_dispatcher {
   template<typename ...Args>
-  static inline R run(const PackedFunc& pf, Args&& ...args) {
+  TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&& ...args) {
     return pf(std::forward<Args>(args)...);
   }
 };
@@ -1303,7 +1326,7 @@ struct typed_packed_call_dispatcher {
 template<>
 struct typed_packed_call_dispatcher<void> {
   template<typename ...Args>
-  static inline void run(const PackedFunc& pf, Args&& ...args) {
+  TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&& ...args) {
     pf(std::forward<Args>(args)...);
   }
 };
@@ -1334,7 +1357,7 @@ inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
 }
 
 template<typename R, typename ...Args>
-inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
+TVM_ALWAYS_INLINE R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
   return detail::typed_packed_call_dispatcher<R>
       ::run(packed_, std::forward<Args>(args)...);
 }
index b024165..bd8a6e2 100644 (file)
@@ -148,7 +148,10 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
     return ret;
   });
 
-TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get);
+TVM_REGISTER_GLOBAL("relay.op._GetOp")
+.set_body_typed([](std::string name) -> Op {
+  return Op::Get(name);
+});
 
 TVM_REGISTER_GLOBAL("relay.op._OpGetAttr")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
index 631d8c0..431e18b 100644 (file)
@@ -135,7 +135,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 });
 
 TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig")
-.set_body_typed(QConfig::Current);
+.set_body_typed([]() -> QConfig {
+  return QConfig::Current();
+});
 
 TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope")
 .set_body_typed(QConfig::EnterQConfigScope);