class ThreadsafeOperatorTable_ final {
public:
template <class Key_>
- void emplace(Key_&& key, KernelFunction value) {
- bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> bool {
+ void emplace(Key_&& key, KernelFunction* value) {
+ bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> bool {
auto result = map.emplace(std::forward<Key>(key), std::move(value));
return result.second;
});
void erase(const Key& key) {
auto num_removed =
- map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> size_t {
+ map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> size_t {
return map.erase(key);
});
assert(num_removed <= 1); // This is not a multi-map
}
}
- const KernelFunction* lookup(const Key& key) const {
- return map_.read([&](const ska::flat_hash_map<Key, KernelFunction>& map) -> const KernelFunction* {
+ KernelFunction* lookup(const Key& key) const {
+ return map_.read([&](const ska::flat_hash_map<Key, KernelFunction*>& map) -> KernelFunction* {
auto found = map.find(key);
if (found != map.end()) {
- return &found->second;
+ return found->second;
} else {
return nullptr;
}
}
private:
- LeftRight<ska::flat_hash_map<Key, KernelFunction>> map_;
+ LeftRight<ska::flat_hash_map<Key, KernelFunction*>> map_;
};
} // namespace details
* @param dispatch_key Dispatch key to define when this kernel is selected
*/
// Registers `func` as the kernel to run when `dispatch_key` is selected.
// The kernel is a plain function pointer; no ownership is transferred
// (function pointers refer to code with static storage duration).
void registerKernel(
    KernelFunction* func,
    typename Schema::dispatch::dispatch_key_type dispatch_key) {
  kernels_.emplace(std::move(dispatch_key), func);
}
/**
// static_assert(std::is_same<typename Schema::return_type (Args...),
// typename Schema::func_type>::value, "Argument types don't match
// operator signature");
- const auto& kernel_func = lookupKernelFunc_(args);
- return kernel_func(args);
+ KernelFunction* kernel_func = lookupKernelFunc_(args);
+ return (*kernel_func)(args);
}
private:
- const KernelFunction& lookupKernelFunc_(ArrayRef<IValue> args) const {
+ KernelFunction* lookupKernelFunc_(ArrayRef<IValue> args) const {
auto dispatch_key = Schema::dispatch::dispatch_key(args);
- const KernelFunction* found = kernels_.lookup(dispatch_key);
+ KernelFunction* found = kernels_.lookup(dispatch_key);
if (found == nullptr) {
// TODO Better error message - include op name and dispatch key (i.e.
// argument types)
std::string() + "Didn't find kernel to dispatch to for operator '" +
Schema::metadata::name() + "'");
}
- return *found;
+ return found;
}
details::ThreadsafeOperatorTable_<
* @param kernel The concrete function implementation to register
* @param dispatch_key The dispatch key to register the function to
*/
- KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
+ KernelRegistrar(KernelFunction* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
: dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
Dispatcher<OpSchemaDef>::registerKernel(kernel, dispatch_key_);
}
// Bit flags recording which builder fields have been supplied so far; used by
// static_asserts to reject duplicate .kernel()/.dispatchKey() calls.
static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0;
static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1;
// Kernel is held as a plain function pointer (may still be unset).
c10::optional<KernelFunction*> kernel_;
c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
public:
: KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {}
// Constructs a builder holding the (possibly unset) kernel pointer and
// dispatch key accumulated so far.
KernelRegistrationBuilder(
    c10::optional<KernelFunction*> kernel,
    c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
    : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
- KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(KernelFunction kernel_func) && {
+ template<KernelFunction* kernel_func>
+ KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
- return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(std::move(kernel_func), std::move(dispatch_key_));
+ return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(kernel_func, std::move(dispatch_key_));
}
/**
* @param kernel concrete function implementation to be registered
* @return "this" for method chaining
*/
- KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && {
- return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func));
+ template<typename Schema::signature::func_type* kernel_func>
+ KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+ return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<kernel_func>>();
}
/**
namespace c10 {
-// TODO Use folly::Function for perf
-using KernelFunction = std::function<IValue(ArrayRef<IValue>)>;
+using KernelFunction = IValue(ArrayRef<IValue>);
namespace details {
}
};
-template<class ReturnType, class ParamTypes, class FuncType> struct _wrapKernel {};
-template<class ReturnType, class... ParamTypes, class FuncType> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType> {
+template<class ReturnType, class ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel {};
+template<class ReturnType, class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> {
using parameter_types = guts::typelist::typelist<ParamTypes...>;
template<size_t... indices>
- static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) {
- return [kernel] (ArrayRef<IValue> args) -> IValue {
- if (args.size() != sizeof...(ParamTypes)) {
- throw std::runtime_error("Wrong number of arguments for operator call");
- }
- return return_type_to_ivalue(
- (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...)
- );
- };
+ static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
+ if (args.size() != sizeof...(ParamTypes)) {
+ throw std::runtime_error("Wrong number of arguments for operator call");
+ }
+ return return_type_to_ivalue(
+ (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...)
+ );
}
};
-template<class... ParamTypes, class FuncType> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType> {
+template<class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> {
using parameter_types = guts::typelist::typelist<ParamTypes...>;
template<size_t... indices>
- static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) {
- return [kernel] (ArrayRef<IValue> args) -> IValue {
- if (args.size() != sizeof...(ParamTypes)) {
- throw std::runtime_error("Wrong number of arguments for operator call");
- }
- (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...);
- return IValue();
- };
+ static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
+ if (args.size() != sizeof...(ParamTypes)) {
+ throw std::runtime_error("Wrong number of arguments for operator call");
+ }
+ (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...);
+ return IValue();
}
};
// (The runtime wrapKernel() helper was removed: wrapping now happens entirely
// at compile time via _wrapKernel<ReturnType, ParamTypes, FuncType, kernel>::call,
// driven by OpSchema::signature::wrap_kernel<kernel>.)
/**
* Wrapper class around a user-provided schema definition some useful information about the schema.
*
static constexpr size_t num_outputs = OpSchemaDef::num_outputs();
- static KernelFunction wrap_kernel(func_type* kernel) {
- return details::wrapKernel<signature_traits>(kernel);
+ template<func_type* kernel>
+ static IValue wrap_kernel(ArrayRef<IValue> args) {
+ constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value;
+ return details::_wrapKernel<return_type, parameter_types, func_type, kernel>::call(args, guts::make_index_sequence<num_parameters>());
}
private:
namespace c10 {
// Register the CPU Add kernel; the kernel is a compile-time template argument.
C10_REGISTER_KERNEL(caffe2::ops::Add)
.kernel<&caffe2::add_op_cpu_impl<float>>()
.dispatchKey(c10::DispatchKey<2>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU AveragedLoss kernel.
C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss)
.kernel<&caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU BatchGather kernels (int64 and int32 index variants).
C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
.kernel<&caffe2::batch_gather_op_cpu_impl<int64_t>>()
.dispatchKey(c10::DispatchKey<2>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<int64_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
.kernel<&caffe2::batch_gather_op_cpu_impl<int32_t>>()
.dispatchKey(c10::DispatchKey<2>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU BatchMatmul kernel.
C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul)
.kernel<&caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
.dispatchKey(c10::DispatchKey<2>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU Cast kernels, one per source element type.
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<float>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<int32_t>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<int32_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<bool>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<bool>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<uint8_t>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<uint8_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<int8_t>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<int8_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<uint16_t>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<uint16_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<int16_t>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<int16_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<int64_t>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<int64_t>()}});
C10_REGISTER_KERNEL(caffe2::ops::Cast)
.kernel<&caffe2::cast_op_cpu_impl<double>>()
.dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU Concat kernel.
C10_REGISTER_KERNEL(caffe2::ops::Concat)
.kernel<&caffe2::concat_op_cpu_impl<float, CPUContext>>()
.dispatchKey(c10::DeviceTypeId::CPU);
} // namespace c10
namespace c10 {
// Register the CPU EnforceFinite kernel.
C10_REGISTER_KERNEL(caffe2::ops::EnforceFinite)
.kernel<&caffe2::enforce_finite_op_impl_cpu<float>>()
.dispatchKey({DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()});
namespace c10 {
// Register the CPU ExpandDims kernel.
C10_REGISTER_KERNEL(caffe2::ops::ExpandDims)
.kernel<&caffe2::expand_dims_op_cpu_impl<float>>()
.dispatchKey({DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()});
namespace c10 {
// Register the CPU FullyConnected kernel.
C10_REGISTER_KERNEL(caffe2::ops::FullyConnected)
.kernel<&caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
.dispatchKey(c10::DispatchKey<3>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU fill kernels (constant, uniform, and given-tensor fills).
C10_REGISTER_KERNEL(caffe2::ops::ConstantFill)
.kernel<&caffe2::constant_fill_op_cpu_impl>()
.dispatchKey(c10::DeviceTypeId::CPU);
C10_REGISTER_KERNEL(caffe2::ops::UniformFill)
.kernel<&caffe2::uniform_fill_op_cpu_impl>()
.dispatchKey(c10::DeviceTypeId::CPU);
C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<float>)
.kernel<&caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
.dispatchKey(c10::DeviceTypeId::CPU);
C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int>)
.kernel<&caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
.dispatchKey(c10::DeviceTypeId::CPU);
C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int64_t>)
.kernel<&caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
.dispatchKey(c10::DeviceTypeId::CPU);
} // namespace c10
namespace c10 {
// Register the CPU Flatten kernel.
C10_REGISTER_KERNEL(caffe2::ops::Flatten)
.kernel<&caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>()
.dispatchKey({DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()});
namespace c10 {
// Register the CPU Mul kernel.
C10_REGISTER_KERNEL(caffe2::ops::Mul)
.kernel<&caffe2::mul_op_cpu_impl<float>>()
.dispatchKey(c10::DispatchKey<2>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU Relu kernel.
C10_REGISTER_KERNEL(caffe2::ops::Relu)
.kernel<&caffe2::relu_op_cpu_impl<float>>()
.dispatchKey({DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()});
namespace c10 {
// Register the CPU Sigmoid kernel.
C10_REGISTER_KERNEL(caffe2::ops::Sigmoid)
.kernel<&caffe2::sigmoid_op_cpu_impl<float>>()
.dispatchKey({DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()});
namespace c10 {
// Register the CPU SigmoidCrossEntropyWithLogits kernel.
C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits)
.kernel<&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
.dispatchKey(c10::DispatchKey<2>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
namespace c10 {
// Register the CPU SparseLengthsSum kernels (float/Half x int32/int64 index).
C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
.kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int32_t>>()
.dispatchKey(c10::DispatchKey<3>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
LayoutId(0),
caffe2::TypeMeta::Id<int>()}});
C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
.kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int32_t>>()
.dispatchKey(c10::DispatchKey<3>{
c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<int>()}});
C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
.kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int64_t>>()
.dispatchKey(c10::DispatchKey<3>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),
LayoutId(0),
caffe2::TypeMeta::Id<int>()}});
C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
.kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int64_t>>()
.dispatchKey(c10::DispatchKey<3>{
c10::details::TensorParameterDispatchKey{
DeviceTypeId::CPU,
namespace c10 {
// Register the CPU StopGradient kernel.
C10_REGISTER_KERNEL(caffe2::ops::StopGradient)
.kernel<&caffe2::stop_gradient_op_cpu_impl<float>>()
.dispatchKey({DeviceTypeId::CPU,
LayoutId(0),
caffe2::TypeMeta::Id<float>()});
}
namespace c10 {
// Register the CPU LayerNorm kernel.
C10_REGISTER_KERNEL(c10::core::opschema::LayerNorm)
.kernel<&layer_norm_c10<float>>()
.dispatchKey(c10::DispatchKey<1>{
c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
LayoutId(0),