From: Sebastian Messmer Date: Tue, 22 Jan 2019 21:21:38 +0000 (-0800) Subject: Avoid closure around kernel (#16165) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1732 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cd8f4154f41a88ccd4f55ce72af2036afbf826a7;p=platform%2Fupstream%2Fpytorch.git Avoid closure around kernel (#16165) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16165 Store kernels as direct function pointers instead of std::function. Using direct function pointers avoids a performance risk std::function would introduce. Reviewed By: ezyang Differential Revision: D13738627 fbshipit-source-id: a348906c8a201436699681980a82ca95065a06a0 --- diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 39c850a..cc7ef8a 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -21,8 +21,8 @@ template class ThreadsafeOperatorTable_ final { public: template - void emplace(Key_&& key, KernelFunction value) { - bool res = map_.write([&](ska::flat_hash_map& map) -> bool { + void emplace(Key_&& key, KernelFunction* value) { + bool res = map_.write([&](ska::flat_hash_map& map) -> bool { auto result = map.emplace(std::forward(key), std::move(value)); return result.second; }); @@ -35,7 +35,7 @@ class ThreadsafeOperatorTable_ final { void erase(const Key& key) { auto num_removed = - map_.write([&](ska::flat_hash_map& map) -> size_t { + map_.write([&](ska::flat_hash_map& map) -> size_t { return map.erase(key); }); assert(num_removed <= 1); // This is not a multi-map @@ -45,11 +45,11 @@ class ThreadsafeOperatorTable_ final { } } - const KernelFunction* lookup(const Key& key) const { - return map_.read([&](const ska::flat_hash_map& map) -> const KernelFunction* { + KernelFunction* lookup(const Key& key) const { + return map_.read([&](const ska::flat_hash_map& map) -> KernelFunction* { auto found = map.find(key); if (found != map.end()) { - return &found->second; + return found->second; } else { return nullptr; } @@ -57,7 +57,7 @@ class ThreadsafeOperatorTable_ final { } private: - LeftRight> map_; + LeftRight> map_; }; } // namespace details @@ -87,9 +87,9 @@ class DispatchTable final { * @param dispatch_key Dispatch key to define when this kernel is selected */ void registerKernel( - KernelFunction func, + KernelFunction* func, typename Schema::dispatch::dispatch_key_type dispatch_key) { - kernels_.emplace(std::move(dispatch_key), std::move(func)); + kernels_.emplace(std::move(dispatch_key), func); } /** @@ -118,14 +118,14 @@ class DispatchTable final { // static_assert(std::is_same::value, "Argument types don't match // operator signature"); - const auto& kernel_func = lookupKernelFunc_(args); - return kernel_func(args); + KernelFunction* kernel_func = lookupKernelFunc_(args); + return (*kernel_func)(args); } private: - const KernelFunction& lookupKernelFunc_(ArrayRef args) const { + KernelFunction* lookupKernelFunc_(ArrayRef args) const { auto dispatch_key = Schema::dispatch::dispatch_key(args); - const KernelFunction* found = kernels_.lookup(dispatch_key); + KernelFunction* found = kernels_.lookup(dispatch_key); if (found == nullptr) { // TODO Better error message - include op name and dispatch key (i.e. // argument types) @@ -133,7 +133,7 @@ class DispatchTable final { std::string() + "Didn't find kernel to dispatch to for operator '" + Schema::metadata::name() + "'"); } - return *found; + return found; } details::ThreadsafeOperatorTable_< diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h index 13aa17a..935b6c0 100644 --- a/aten/src/ATen/core/dispatch/KernelRegistration.h +++ b/aten/src/ATen/core/dispatch/KernelRegistration.h @@ -34,7 +34,7 @@ public: * @param kernel The concrete function implementation to register * @param dispatch_key The dispatch key to register the function to */ - KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) + KernelRegistrar(KernelFunction* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key) : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { Dispatcher::registerKernel(kernel, dispatch_key_); } @@ -88,7 +88,7 @@ private: static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0; static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1; - c10::optional kernel_; + c10::optional kernel_; c10::optional dispatch_key_; public: @@ -96,7 +96,7 @@ private: : KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {} KernelRegistrationBuilder( - c10::optional kernel, + c10::optional kernel, c10::optional dispatch_key) : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {} @@ -116,9 +116,10 @@ private: * @param kernel concrete function implementation to be registered * @return "this" for method chaining */ - KernelRegistrationBuilder kernel(KernelFunction kernel_func) && { + template + KernelRegistrationBuilder kernel() && { static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration"); - return KernelRegistrationBuilder(std::move(kernel_func), std::move(dispatch_key_)); + return KernelRegistrationBuilder(kernel_func, std::move(dispatch_key_)); } /** @@ -126,8 +127,9 @@ private: * @param kernel concrete function implementation to be registered * @return "this" for method chaining */ - KernelRegistrationBuilder kernel(typename Schema::signature::func_type* kernel_func) && { - return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func)); + template + KernelRegistrationBuilder kernel() && { + return std::move(*this).template kernel<&Schema::signature::template wrap_kernel>(); } /** diff --git a/aten/src/ATen/core/dispatch/OpSchema.h b/aten/src/ATen/core/dispatch/OpSchema.h index db6f3e7..a386b8d 100644 --- a/aten/src/ATen/core/dispatch/OpSchema.h +++ b/aten/src/ATen/core/dispatch/OpSchema.h @@ -10,8 +10,7 @@ namespace c10 { -// TODO Use folly::Function for perf -using KernelFunction = std::function)>; +using KernelFunction = IValue(ArrayRef); namespace details { @@ -128,50 +127,33 @@ struct ivalue_to_arg_type> { } }; -template struct _wrapKernel {}; -template struct _wrapKernel, FuncType> { +template struct _wrapKernel {}; +template struct _wrapKernel, FuncType, kernel> { using parameter_types = guts::typelist::typelist; template - static KernelFunction call(FuncType* kernel, guts::index_sequence) { - return [kernel] (ArrayRef args) -> IValue { - if (args.size() != sizeof...(ParamTypes)) { - throw std::runtime_error("Wrong number of arguments for operator call"); - } - return return_type_to_ivalue( - (*kernel)(ivalue_to_arg_type>>>::call(args[indices])...) - ); - }; + static IValue call(ArrayRef args, guts::index_sequence) { + if (args.size() != sizeof...(ParamTypes)) { + throw std::runtime_error("Wrong number of arguments for operator call"); + } + return return_type_to_ivalue( + (*kernel)(ivalue_to_arg_type>>>::call(args[indices])...) + ); } }; -template struct _wrapKernel, FuncType> { +template struct _wrapKernel, FuncType, kernel> { using parameter_types = guts::typelist::typelist; template - static KernelFunction call(FuncType* kernel, guts::index_sequence) { - return [kernel] (ArrayRef args) -> IValue { - if (args.size() != sizeof...(ParamTypes)) { - throw std::runtime_error("Wrong number of arguments for operator call"); - } - (*kernel)(ivalue_to_arg_type>>>::call(args[indices])...); - return IValue(); - }; + static IValue call(ArrayRef args, guts::index_sequence) { + if (args.size() != sizeof...(ParamTypes)) { + throw std::runtime_error("Wrong number of arguments for operator call"); + } + (*kernel)(ivalue_to_arg_type>>>::call(args[indices])...); + return IValue(); } }; -template -KernelFunction wrapKernel(typename SignatureTraits::func_type* kernel) { - using return_type = typename SignatureTraits::return_type; - using parameter_types = typename SignatureTraits::parameter_types; - using func_type = typename SignatureTraits::func_type; - constexpr size_t num_parameters = guts::typelist::size::value; - - return _wrapKernel::call( - kernel, - guts::make_index_sequence() - ); -} - /** * Wrapper class around a user-provided schema definition some useful information about the schema. * @@ -207,8 +189,10 @@ public: static constexpr size_t num_outputs = OpSchemaDef::num_outputs(); - static KernelFunction wrap_kernel(func_type* kernel) { - return details::wrapKernel(kernel); + template + static IValue wrap_kernel(ArrayRef args) { + constexpr size_t num_parameters = guts::typelist::size::value; + return details::_wrapKernel::call(args, guts::make_index_sequence()); } private: diff --git a/caffe2/operators/experimental/c10/cpu/add_cpu.cc b/caffe2/operators/experimental/c10/cpu/add_cpu.cc index dbebbb5..9b81641 100644 --- a/caffe2/operators/experimental/c10/cpu/add_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/add_cpu.cc @@ -74,7 +74,7 @@ void add_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Add) - .kernel(&caffe2::add_op_cpu_impl) + .kernel<&caffe2::add_op_cpu_impl>() .dispatchKey(c10::DispatchKey<2>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc index 4223905..de4b84d 100644 --- a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc @@ -45,7 +45,7 @@ void averaged_loss_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss) - .kernel(&caffe2::averaged_loss_op_cpu_impl) + .kernel<&caffe2::averaged_loss_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc index 3786786..3edb5ea 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc @@ -58,7 +58,7 @@ void batch_gather_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::BatchGather) - .kernel(&caffe2::batch_gather_op_cpu_impl) + .kernel<&caffe2::batch_gather_op_cpu_impl>() .dispatchKey(c10::DispatchKey<2>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), @@ -69,7 +69,7 @@ C10_REGISTER_KERNEL(caffe2::ops::BatchGather) caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::BatchGather) - .kernel(&caffe2::batch_gather_op_cpu_impl) + .kernel<&caffe2::batch_gather_op_cpu_impl>() .dispatchKey(c10::DispatchKey<2>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc index 5fa5d52..542a9ad 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc @@ -270,7 +270,7 @@ void batch_matmul_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul) - .kernel(&caffe2::batch_matmul_op_cpu_impl) + .kernel<&caffe2::batch_matmul_op_cpu_impl>() .dispatchKey(c10::DispatchKey<2>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc index a236d53..62c2a62 100644 --- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc @@ -79,55 +79,55 @@ void cast_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::Cast) - .kernel(&caffe2::cast_op_cpu_impl) + .kernel<&caffe2::cast_op_cpu_impl>() .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc index d049ce0..858675c 100644 --- a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc @@ -108,6 +108,6 @@ void concat_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Concat) - .kernel(&caffe2::concat_op_cpu_impl) + .kernel<&caffe2::concat_op_cpu_impl>() .dispatchKey(c10::DeviceTypeId::CPU); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc index 46df603..1db6ed2 100644 --- a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc @@ -28,7 +28,7 @@ void enforce_finite_op_impl_cpu(const at::Tensor& input_) { namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::EnforceFinite) - .kernel(&caffe2::enforce_finite_op_impl_cpu) + .kernel<&caffe2::enforce_finite_op_impl_cpu>() .dispatchKey({DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}); diff --git a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc index cc73d77..cbabe07 100644 --- a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc @@ -54,7 +54,7 @@ void expand_dims_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::ExpandDims) - .kernel(&caffe2::expand_dims_op_cpu_impl) + .kernel<&caffe2::expand_dims_op_cpu_impl>() .dispatchKey({DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}); diff --git a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc index b22a977..8a074c8 100644 --- a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc @@ -124,7 +124,7 @@ void fc_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::FullyConnected) - .kernel(&caffe2::fc_op_cpu_impl) + .kernel<&caffe2::fc_op_cpu_impl>() .dispatchKey(c10::DispatchKey<3>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc index dc878b4..376bc92 100644 --- a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc @@ -154,22 +154,22 @@ void uniform_fill_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::ConstantFill) - .kernel(&caffe2::constant_fill_op_cpu_impl) + .kernel<&caffe2::constant_fill_op_cpu_impl>() .dispatchKey(c10::DeviceTypeId::CPU); C10_REGISTER_KERNEL(caffe2::ops::UniformFill) - .kernel(&caffe2::uniform_fill_op_cpu_impl) + .kernel<&caffe2::uniform_fill_op_cpu_impl>() .dispatchKey(c10::DeviceTypeId::CPU); C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill) - .kernel(&caffe2::given_tensor_fill_op_cpu_impl) + .kernel<&caffe2::given_tensor_fill_op_cpu_impl>() .dispatchKey(c10::DeviceTypeId::CPU); C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill) - .kernel(&caffe2::given_tensor_fill_op_cpu_impl) + .kernel<&caffe2::given_tensor_fill_op_cpu_impl>() .dispatchKey(c10::DeviceTypeId::CPU); C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill) - .kernel(&caffe2::given_tensor_fill_op_cpu_impl) + .kernel<&caffe2::given_tensor_fill_op_cpu_impl>() .dispatchKey(c10::DeviceTypeId::CPU); } // namespace c10 diff --git a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc index a8eb14a..26d1623 100644 --- a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc @@ -30,7 +30,7 @@ void flatten_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Flatten) - .kernel(&caffe2::flatten_op_cpu_impl) + .kernel<&caffe2::flatten_op_cpu_impl>() .dispatchKey({DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}); diff --git a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc index 067eefc..38f7385 100644 --- a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc @@ -75,7 +75,7 @@ void mul_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Mul) - .kernel(&caffe2::mul_op_cpu_impl) + .kernel<&caffe2::mul_op_cpu_impl>() .dispatchKey(c10::DispatchKey<2>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc index ca66403..7c5b862 100644 --- a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc @@ -44,7 +44,7 @@ void relu_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Relu) - .kernel(&caffe2::relu_op_cpu_impl) + .kernel<&caffe2::relu_op_cpu_impl>() .dispatchKey({DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}); diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc index 13bf7d2..78febf5 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc @@ -27,7 +27,7 @@ void sigmoid_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::Sigmoid) - .kernel(&caffe2::sigmoid_op_cpu_impl) + .kernel<&caffe2::sigmoid_op_cpu_impl>() .dispatchKey({DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}); diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc index 6c21ae1..32853fd 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc @@ -74,7 +74,7 @@ void sigmoid_cross_entropy_with_logits_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits) - .kernel(&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl) + .kernel<&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>() .dispatchKey(c10::DispatchKey<2>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), diff --git a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc index a58f762..9471116 100644 --- a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc @@ -60,7 +60,7 @@ void sparse_lengths_sum_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl) + .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl>() .dispatchKey(c10::DispatchKey<3>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), @@ -73,7 +73,7 @@ C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl) + .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl>() .dispatchKey(c10::DispatchKey<3>{ c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, @@ -87,7 +87,7 @@ C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl) + .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl>() .dispatchKey(c10::DispatchKey<3>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0), @@ -100,7 +100,7 @@ C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) LayoutId(0), caffe2::TypeMeta::Id()}}); C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum) - .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl) + .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl>() .dispatchKey(c10::DispatchKey<3>{ c10::details::TensorParameterDispatchKey{ DeviceTypeId::CPU, diff --git a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc index d77e430..1165978 100644 --- a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc @@ -23,7 +23,7 @@ void stop_gradient_op_cpu_impl( namespace c10 { C10_REGISTER_KERNEL(caffe2::ops::StopGradient) - .kernel(&caffe2::stop_gradient_op_cpu_impl) + .kernel<&caffe2::stop_gradient_op_cpu_impl>() .dispatchKey({DeviceTypeId::CPU, LayoutId(0), caffe2::TypeMeta::Id()}); diff --git a/caffe2/operators/layer_norm_op.cc b/caffe2/operators/layer_norm_op.cc index b4180a6..fa90975 100644 --- a/caffe2/operators/layer_norm_op.cc +++ b/caffe2/operators/layer_norm_op.cc @@ -219,7 +219,7 @@ c10::IValue layer_norm_c10(c10::ArrayRef inputs) { } namespace c10 { C10_REGISTER_KERNEL(c10::core::opschema::LayerNorm) - .kernel(&layer_norm_c10) + .kernel<&layer_norm_c10>() .dispatchKey(c10::DispatchKey<1>{ c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU, LayoutId(0),