class OperatorKernel : public KernelCache {};
namespace detail {
+ // supported_primitive_arg_types defines which primitive types we allow in
+ // kernel functions as arguments or returns.
+ // Additionally, we support lists, dicts and optionals containing these types.
+ using supported_primitive_arg_types = guts::typelist::typelist<
+ int64_t,
+ double,
+ bool,
+ std::string,
+ at::Tensor,
+ at::Scalar
+ >;
+
// ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
// cast it to the type that should be passed to the kernel function.
// Examples: If the IValue contains a plain type like an int, return that.
// If the IValue contains an IntList, return it as ArrayRef<int>.
// TODO Should we move the IValue so we can avoid bumping the Tensor refcount?
+ template<class T, class Enable = void> struct ivalue_to_arg_type {
+ // This base case is hit whenever a type does not have a specialisation below.
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type.");
+ };
template<class T>
- struct ivalue_to_arg_type {
+ struct ivalue_to_arg_type<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
static T call(const IValue& v) {
return std::move(v).to<T>();
}
template<class T>
struct ivalue_to_arg_type<ArrayRef<T>> {
static ArrayRef<T> call(const IValue& v) {
+ // TODO Do we want to support ArrayRef<optional<T>> ?
+ static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported argument type: c10::ArrayRef<T> and T is not one of the supported primitive types.");
return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
}
};
template<class T>
- struct ivalue_to_arg_type<std::vector<T>> {
- static ArrayRef<T> call(const IValue& v) {
- // We don't support std::vector because that would prevent us from doing
- // internal optimization to how we represent lists (e.g. SmallVector).
- // Users should use ArrayRef instead.
- static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead.");
- }
- };
- template<class T>
struct ivalue_to_arg_type<optional<T>> {
static optional<T> call(const IValue& v) {
if (v.isNone()) {
return nullopt;
}
- return v.to<T>();
+ return ivalue_to_arg_type<T>::call(v);
}
};
-
+ // The following specialisations of ivalue_to_arg_type are technically not
+ // necessary since we would hit the base case and show an error message
+ // there if they didn't exist, but we can show a better error message
+ // in some common error scenarios.
template<class T>
+ struct ivalue_to_arg_type<std::vector<T>> {
+ // We don't support std::vector because that would prevent us from doing
+ // internal optimization to how we represent lists (e.g. SmallVector).
+ // Users should use ArrayRef instead.
+ static_assert(guts::false_t<std::vector<T>>::value, "You tried to register a kernel with an unsupported argument type: std::vector<T>. Please use c10::ArrayRef<T> instead.");
+ };
+ template<class T>
+ struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<float, T>::value>> {
+ // There is no reason to support float when we have double. Keep the API lean.
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: float. Please use double instead.");
+ };
+ template<class T>
+ struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_same<const char*, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported argument type: const char*. Please use std::string instead.");
+ };
+ template<class T>
+ struct ivalue_to_arg_type<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral argument type. Please use int64_t instead.");
+ };
+
+ template<class T, class Enable = void>
struct return_type_to_ivalue_ {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
static IValue call(T&& v) {
return IValue(std::move(v));
}
if (!v.has_value()) {
return IValue();
}
- return IValue(std::move(*v));
+ return return_type_to_ivalue_<T>::call(std::move(*v));
+ }
+ };
+ template<class T>
+ struct return_type_to_ivalue_<std::vector<T>> {
+ static IValue call(std::vector<T>&& v) {
+ // TODO Do we want to support vector<optional<T>> ?
+ static_assert(guts::typelist::contains<supported_primitive_arg_types, T>::value, "You tried to register a kernel with an unsupported return type: vector<T> and T is not one of the supported primitive types.");
+ return IValue(std::move(v));
}
};
+ // The following specialisations of return_type_to_ivalue_ are technically not
+ // necessary since we would hit the base case and show an error message
+ // there if they didn't exist, but we can show a better error message
+ // in some common error scenarios.
+ template<class T>
+ struct return_type_to_ivalue_<c10::ArrayRef<T>> {
+ static_assert(guts::false_t<c10::ArrayRef<T>>::value, "You tried to register a kernel with an unsupported return type: c10::ArrayRef<T>. Please use std::vector<T> instead.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<float, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: float. Please use double instead.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_same<const char*, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported return type: const char*. Please use std::string instead.");
+ };
+ template<class T>
+ struct return_type_to_ivalue_<T, guts::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
+ static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported integral return argument type. Please use int64_t instead.");
+ };
+
template<class T>
IValue return_type_to_ivalue(T&& v) {
return return_type_to_ivalue_<guts::decay_t<T>>::call(std::move(v));
};
+/**
+ * Checks if a typelist contains a certain type.
+ * Examples:
+ * contains<typelist<int, string>, string> == true_type
+ * contains<typelist<int, string>, double> == false_type
+ */
+namespace detail {
+template<class TypeList, class Type, class Enable = void> struct contains {};
+template<class Type> struct contains<typelist<>, Type, void> : std::false_type {};
+template<class Type, class Head, class... Tail>
+struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<std::is_same<Head, Type>::value>> : std::true_type {};
+template<class Type, class Head, class... Tail>
+struct contains<typelist<Head, Tail...>, Type, guts::enable_if_t<!std::is_same<Head, Type>::value>> : contains<typelist<Tail...>, Type> {};
+}
+template<class TypeList, class Type>
+using contains = typename detail::contains<TypeList, Type>::type;
+
/**
* Returns true iff the type trait is true for all types in the type list