#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
namespace detail {
+/// Checks the static C++ type `T` for correctness to catch common error cases.
+template <typename T>
+void checkStaticTypes() {
+ // Give nice error messages for some of the common error cases.
+ // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
+ static_assert(
+ !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
+ "INVALID TYPE: Only int64_t is supported as an integral argument type");
+ static_assert(
+ !std::is_same<T, float>::value,
+ "INVALID TYPE: float is not supported as an argument type, use double instead");
+}
+
+template <typename First, typename Second, typename... Rest>
+void checkStaticTypes() {
+ checkStaticTypes<First>();
+ checkStaticTypes<Second, Rest...>();
+}
+
template <typename... Ts, size_t... Is>
std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
+ checkStaticTypes<decay_t<Ts>...>();
// Arguments are named "_<index>"
return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
}
/// Create a single-element `vector` for simple (non-tuple) return types.
template <typename ReturnType>
std::vector<Argument> createReturns(ReturnType*) {
+ checkStaticTypes<decay_t<ReturnType>>();
return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
}
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
using ReturnType = typename FunctionTraits::return_type;
+
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
+
return {name, arguments, returns};
}