Avoid closure around kernel (#16165)
authorSebastian Messmer <messmer@fb.com>
Tue, 22 Jan 2019 21:21:38 +0000 (13:21 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 22 Jan 2019 21:29:11 +0000 (13:29 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16165

Store kernels as direct function pointers instead of std::function.
Using direct function pointers avoids a performance risk std::function would introduce.

Reviewed By: ezyang

Differential Revision: D13738627

fbshipit-source-id: a348906c8a201436699681980a82ca95065a06a0

21 files changed:
aten/src/ATen/core/dispatch/DispatchTable.h
aten/src/ATen/core/dispatch/KernelRegistration.h
aten/src/ATen/core/dispatch/OpSchema.h
caffe2/operators/experimental/c10/cpu/add_cpu.cc
caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc
caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc
caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc
caffe2/operators/experimental/c10/cpu/cast_cpu.cc
caffe2/operators/experimental/c10/cpu/concat_cpu.cc
caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc
caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc
caffe2/operators/experimental/c10/cpu/fc_cpu.cc
caffe2/operators/experimental/c10/cpu/filler_cpu.cc
caffe2/operators/experimental/c10/cpu/flatten_cpu.cc
caffe2/operators/experimental/c10/cpu/mul_cpu.cc
caffe2/operators/experimental/c10/cpu/relu_cpu.cc
caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc
caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc
caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc
caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
caffe2/operators/layer_norm_op.cc

index 39c850a..cc7ef8a 100644 (file)
@@ -21,8 +21,8 @@ template <class Key>
 class ThreadsafeOperatorTable_ final {
  public:
   template <class Key_>
-  void emplace(Key_&& key, KernelFunction value) {
-    bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> bool {
+  void emplace(Key_&& key, KernelFunction* value) {
+    bool res = map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> bool {
       auto result = map.emplace(std::forward<Key>(key), std::move(value));
       return result.second;
     });
@@ -35,7 +35,7 @@ class ThreadsafeOperatorTable_ final {
 
   void erase(const Key& key) {
     auto num_removed =
-        map_.write([&](ska::flat_hash_map<Key, KernelFunction>& map) -> size_t {
+        map_.write([&](ska::flat_hash_map<Key, KernelFunction*>& map) -> size_t {
           return map.erase(key);
         });
     assert(num_removed <= 1); // This is not a multi-map
@@ -45,11 +45,11 @@ class ThreadsafeOperatorTable_ final {
     }
   }
 
-  const KernelFunction* lookup(const Key& key) const {
-    return map_.read([&](const ska::flat_hash_map<Key, KernelFunction>& map) -> const KernelFunction* {
+  KernelFunction* lookup(const Key& key) const {
+    return map_.read([&](const ska::flat_hash_map<Key, KernelFunction*>& map) -> KernelFunction* {
       auto found = map.find(key);
       if (found != map.end()) {
-        return &found->second;
+        return found->second;
       } else {
         return nullptr;
       }
@@ -57,7 +57,7 @@ class ThreadsafeOperatorTable_ final {
   }
 
  private:
-  LeftRight<ska::flat_hash_map<Key, KernelFunction>> map_;
+  LeftRight<ska::flat_hash_map<Key, KernelFunction*>> map_;
 };
 } // namespace details
 
@@ -87,9 +87,9 @@ class DispatchTable final {
    * @param dispatch_key Dispatch key to define when this kernel is selected
    */
   void registerKernel(
-      KernelFunction func,
+      KernelFunction* func,
       typename Schema::dispatch::dispatch_key_type dispatch_key) {
-    kernels_.emplace(std::move(dispatch_key), std::move(func));
+    kernels_.emplace(std::move(dispatch_key), func);
   }
 
   /**
@@ -118,14 +118,14 @@ class DispatchTable final {
     //      static_assert(std::is_same<typename Schema::return_type (Args...),
     //      typename Schema::func_type>::value, "Argument types don't match
     //      operator signature");
-    const auto& kernel_func = lookupKernelFunc_(args);
-    return kernel_func(args);
+    KernelFunction* kernel_func = lookupKernelFunc_(args);
+    return (*kernel_func)(args);
   }
 
  private:
-  const KernelFunction& lookupKernelFunc_(ArrayRef<IValue> args) const {
+  KernelFunction* lookupKernelFunc_(ArrayRef<IValue> args) const {
     auto dispatch_key = Schema::dispatch::dispatch_key(args);
-    const KernelFunction* found = kernels_.lookup(dispatch_key);
+    KernelFunction* found = kernels_.lookup(dispatch_key);
     if (found == nullptr) {
       // TODO Better error message - include op name and dispatch key (i.e.
       // argument types)
@@ -133,7 +133,7 @@ class DispatchTable final {
           std::string() + "Didn't find kernel to dispatch to for operator '" +
           Schema::metadata::name() + "'");
     }
-    return *found;
+    return found;
   }
 
   details::ThreadsafeOperatorTable_<
index 13aa17a..935b6c0 100644 (file)
@@ -34,7 +34,7 @@ public:
    * @param kernel The concrete function implementation to register
    * @param dispatch_key  The dispatch key to register the function to
    */
-  KernelRegistrar(KernelFunction kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
+  KernelRegistrar(KernelFunction* kernel, typename Schema::dispatch::dispatch_key_type dispatch_key)
   : dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
     Dispatcher<OpSchemaDef>::registerKernel(kernel, dispatch_key_);
   }
@@ -88,7 +88,7 @@ private:
   static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0;
   static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1;
 
-  c10::optional<KernelFunction> kernel_;
+  c10::optional<KernelFunction*> kernel_;
   c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
 
  public:
@@ -96,7 +96,7 @@ private:
       : KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {}
 
   KernelRegistrationBuilder(
-      c10::optional<KernelFunction> kernel,
+      c10::optional<KernelFunction*> kernel,
       c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
       : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}
 
@@ -116,9 +116,10 @@ private:
    * @param kernel concrete function implementation to be registered
    * @return "this" for method chaining
    */
-  KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(KernelFunction kernel_func) && {
+  template<KernelFunction* kernel_func>
+  KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
     static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
-    return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(std::move(kernel_func), std::move(dispatch_key_));
+    return KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT>(kernel_func, std::move(dispatch_key_));
   }
 
   /**
@@ -126,8 +127,9 @@ private:
    * @param kernel concrete function implementation to be registered
    * @return "this" for method chaining
    */
-  KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel(typename Schema::signature::func_type* kernel_func) && {
-    return std::move(*this).kernel(Schema::signature::wrap_kernel(kernel_func));
+  template<typename Schema::signature::func_type* kernel_func>
+  KernelRegistrationBuilder<OpSchemaDef, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
+    return std::move(*this).template kernel<&Schema::signature::template wrap_kernel<kernel_func>>();
   }
 
   /**
index db6f3e7..a386b8d 100644 (file)
@@ -10,8 +10,7 @@
 
 namespace c10 {
 
-// TODO Use folly::Function for perf
-using KernelFunction = std::function<IValue(ArrayRef<IValue>)>;
+using KernelFunction = IValue(ArrayRef<IValue>);
 
 namespace details {
 
@@ -128,50 +127,33 @@ struct ivalue_to_arg_type<ArrayRef<T>> {
   }
 };
 
-template<class ReturnType, class ParamTypes, class FuncType> struct _wrapKernel {};
-template<class ReturnType, class... ParamTypes, class FuncType> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType> {
+template<class ReturnType, class ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel {};
+template<class ReturnType, class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<ReturnType, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> {
   using parameter_types = guts::typelist::typelist<ParamTypes...>;
 
   template<size_t... indices>
-  static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) {
-    return [kernel] (ArrayRef<IValue> args) -> IValue {
-      if (args.size() != sizeof...(ParamTypes)) {
-        throw std::runtime_error("Wrong number of arguments for operator call");
-      }
-      return return_type_to_ivalue(
-        (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...)
-      );
-    };
+  static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
+    if (args.size() != sizeof...(ParamTypes)) {
+      throw std::runtime_error("Wrong number of arguments for operator call");
+    }
+    return return_type_to_ivalue(
+      (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...)
+    );
   }
 };
-template<class... ParamTypes, class FuncType> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType> {
+template<class... ParamTypes, class FuncType, FuncType* kernel> struct _wrapKernel<void, guts::typelist::typelist<ParamTypes...>, FuncType, kernel> {
   using parameter_types = guts::typelist::typelist<ParamTypes...>;
 
   template<size_t... indices>
-  static KernelFunction call(FuncType* kernel, guts::index_sequence<indices...>) {
-    return [kernel] (ArrayRef<IValue> args) -> IValue {
-      if (args.size() != sizeof...(ParamTypes)) {
-        throw std::runtime_error("Wrong number of arguments for operator call");
-      }
-      (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...);
-      return IValue();
-    };
+  static IValue call(ArrayRef<IValue> args, guts::index_sequence<indices...>) {
+    if (args.size() != sizeof...(ParamTypes)) {
+      throw std::runtime_error("Wrong number of arguments for operator call");
+    }
+    (*kernel)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<indices, parameter_types>>>>::call(args[indices])...);
+    return IValue();
   }
 };
 
-template<class SignatureTraits>
-KernelFunction wrapKernel(typename SignatureTraits::func_type* kernel) {
-  using return_type = typename SignatureTraits::return_type;
-  using parameter_types = typename SignatureTraits::parameter_types;
-  using func_type = typename SignatureTraits::func_type;
-  constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value;
-
-  return _wrapKernel<return_type, parameter_types, func_type>::call(
-    kernel,
-    guts::make_index_sequence<num_parameters>()
-  );
-}
-
 /**
  * Wrapper class around a user-provided schema definition some useful information about the schema.
  *
@@ -207,8 +189,10 @@ public:
 
   static constexpr size_t num_outputs = OpSchemaDef::num_outputs();
 
-  static KernelFunction wrap_kernel(func_type* kernel) {
-    return details::wrapKernel<signature_traits>(kernel);
+  template<func_type* kernel>
+  static IValue wrap_kernel(ArrayRef<IValue> args) {
+    constexpr size_t num_parameters = guts::typelist::size<parameter_types>::value;
+    return details::_wrapKernel<return_type, parameter_types, func_type, kernel>::call(args, guts::make_index_sequence<num_parameters>());
   }
 
 private:
index dbebbb5..9b81641 100644 (file)
@@ -74,7 +74,7 @@ void add_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Add)
-    .kernel(&caffe2::add_op_cpu_impl<float>)
+    .kernel<&caffe2::add_op_cpu_impl<float>>()
     .dispatchKey(c10::DispatchKey<2>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
index 4223905..de4b84d 100644 (file)
@@ -45,7 +45,7 @@ void averaged_loss_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss)
-    .kernel(&caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>)
+    .kernel<&caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
index 3786786..3edb5ea 100644 (file)
@@ -58,7 +58,7 @@ void batch_gather_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
-    .kernel(&caffe2::batch_gather_op_cpu_impl<int64_t>)
+    .kernel<&caffe2::batch_gather_op_cpu_impl<int64_t>>()
     .dispatchKey(c10::DispatchKey<2>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
@@ -69,7 +69,7 @@ C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
             caffe2::TypeMeta::Id<int64_t>()}});
 
 C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
-    .kernel(&caffe2::batch_gather_op_cpu_impl<int32_t>)
+    .kernel<&caffe2::batch_gather_op_cpu_impl<int32_t>>()
     .dispatchKey(c10::DispatchKey<2>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
index 5fa5d52..542a9ad 100644 (file)
@@ -270,7 +270,7 @@ void batch_matmul_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul)
-    .kernel(&caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>)
+    .kernel<&caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
     .dispatchKey(c10::DispatchKey<2>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
index a236d53..62c2a62 100644 (file)
@@ -79,55 +79,55 @@ void cast_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<float>)
+    .kernel<&caffe2::cast_op_cpu_impl<float>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<float>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<int32_t>)
+    .kernel<&caffe2::cast_op_cpu_impl<int32_t>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<int32_t>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<bool>)
+    .kernel<&caffe2::cast_op_cpu_impl<bool>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<bool>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<uint8_t>)
+    .kernel<&caffe2::cast_op_cpu_impl<uint8_t>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<uint8_t>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<int8_t>)
+    .kernel<&caffe2::cast_op_cpu_impl<int8_t>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<int8_t>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<uint16_t>)
+    .kernel<&caffe2::cast_op_cpu_impl<uint16_t>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<uint16_t>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<int16_t>)
+    .kernel<&caffe2::cast_op_cpu_impl<int16_t>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<int16_t>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<int64_t>)
+    .kernel<&caffe2::cast_op_cpu_impl<int64_t>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
         caffe2::TypeMeta::Id<int64_t>()}});
 C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel(&caffe2::cast_op_cpu_impl<double>)
+    .kernel<&caffe2::cast_op_cpu_impl<double>>()
     .dispatchKey(c10::DispatchKey<1>{c10::details::TensorParameterDispatchKey{
         DeviceTypeId::CPU,
         LayoutId(0),
index d049ce0..858675c 100644 (file)
@@ -108,6 +108,6 @@ void concat_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Concat)
-    .kernel(&caffe2::concat_op_cpu_impl<float, CPUContext>)
+    .kernel<&caffe2::concat_op_cpu_impl<float, CPUContext>>()
     .dispatchKey(c10::DeviceTypeId::CPU);
 } // namespace c10
index 46df603..1db6ed2 100644 (file)
@@ -28,7 +28,7 @@ void enforce_finite_op_impl_cpu(const at::Tensor& input_) {
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::EnforceFinite)
-    .kernel(&caffe2::enforce_finite_op_impl_cpu<float>)
+    .kernel<&caffe2::enforce_finite_op_impl_cpu<float>>()
     .dispatchKey({DeviceTypeId::CPU,
                   LayoutId(0),
                   caffe2::TypeMeta::Id<float>()});
index cc73d77..cbabe07 100644 (file)
@@ -54,7 +54,7 @@ void expand_dims_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::ExpandDims)
-    .kernel(&caffe2::expand_dims_op_cpu_impl<float>)
+    .kernel<&caffe2::expand_dims_op_cpu_impl<float>>()
     .dispatchKey({DeviceTypeId::CPU,
                   LayoutId(0),
                   caffe2::TypeMeta::Id<float>()});
index b22a977..8a074c8 100644 (file)
@@ -124,7 +124,7 @@ void fc_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::FullyConnected)
-    .kernel(&caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>)
+    .kernel<&caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
     .dispatchKey(c10::DispatchKey<3>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
index dc878b4..376bc92 100644 (file)
@@ -154,22 +154,22 @@ void uniform_fill_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::ConstantFill)
-    .kernel(&caffe2::constant_fill_op_cpu_impl)
+    .kernel<&caffe2::constant_fill_op_cpu_impl>()
     .dispatchKey(c10::DeviceTypeId::CPU);
 
 C10_REGISTER_KERNEL(caffe2::ops::UniformFill)
-    .kernel(&caffe2::uniform_fill_op_cpu_impl)
+    .kernel<&caffe2::uniform_fill_op_cpu_impl>()
     .dispatchKey(c10::DeviceTypeId::CPU);
 
 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<float>)
-    .kernel(&caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>)
+    .kernel<&caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
     .dispatchKey(c10::DeviceTypeId::CPU);
 
 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int>)
-    .kernel(&caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>)
+    .kernel<&caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
     .dispatchKey(c10::DeviceTypeId::CPU);
 
 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill<int64_t>)
-    .kernel(&caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>)
+    .kernel<&caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
     .dispatchKey(c10::DeviceTypeId::CPU);
 } // namespace c10
index a8eb14a..26d1623 100644 (file)
@@ -30,7 +30,7 @@ void flatten_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Flatten)
-    .kernel(&caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>)
+    .kernel<&caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>()
     .dispatchKey({DeviceTypeId::CPU,
                   LayoutId(0),
                   caffe2::TypeMeta::Id<float>()});
index 067eefc..38f7385 100644 (file)
@@ -75,7 +75,7 @@ void mul_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Mul)
-    .kernel(&caffe2::mul_op_cpu_impl<float>)
+    .kernel<&caffe2::mul_op_cpu_impl<float>>()
     .dispatchKey(c10::DispatchKey<2>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
index ca66403..7c5b862 100644 (file)
@@ -44,7 +44,7 @@ void relu_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Relu)
-    .kernel(&caffe2::relu_op_cpu_impl<float>)
+    .kernel<&caffe2::relu_op_cpu_impl<float>>()
     .dispatchKey({DeviceTypeId::CPU,
                   LayoutId(0),
                   caffe2::TypeMeta::Id<float>()});
index 13bf7d2..78febf5 100644 (file)
@@ -27,7 +27,7 @@ void sigmoid_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::Sigmoid)
-    .kernel(&caffe2::sigmoid_op_cpu_impl<float>)
+    .kernel<&caffe2::sigmoid_op_cpu_impl<float>>()
     .dispatchKey({DeviceTypeId::CPU,
                   LayoutId(0),
                   caffe2::TypeMeta::Id<float>()});
index 6c21ae1..32853fd 100644 (file)
@@ -74,7 +74,7 @@ void sigmoid_cross_entropy_with_logits_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits)
-    .kernel(&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl)
+    .kernel<&caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
     .dispatchKey(c10::DispatchKey<2>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
index a58f762..9471116 100644 (file)
@@ -60,7 +60,7 @@ void sparse_lengths_sum_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
-    .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl<float, int32_t>)
+    .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int32_t>>()
     .dispatchKey(c10::DispatchKey<3>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
@@ -73,7 +73,7 @@ C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
                                                  LayoutId(0),
                                                  caffe2::TypeMeta::Id<int>()}});
 C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
-    .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int32_t>)
+    .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int32_t>>()
     .dispatchKey(c10::DispatchKey<3>{
         c10::details::TensorParameterDispatchKey{
             DeviceTypeId::CPU,
@@ -87,7 +87,7 @@ C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
                                                  LayoutId(0),
                                                  caffe2::TypeMeta::Id<int>()}});
 C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
-    .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl<float, int64_t>)
+    .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<float, int64_t>>()
     .dispatchKey(c10::DispatchKey<3>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),
@@ -100,7 +100,7 @@ C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
                                                  LayoutId(0),
                                                  caffe2::TypeMeta::Id<int>()}});
 C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
-    .kernel(&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int64_t>)
+    .kernel<&caffe2::sparse_lengths_sum_op_cpu_impl<at::Half, int64_t>>()
     .dispatchKey(c10::DispatchKey<3>{
         c10::details::TensorParameterDispatchKey{
             DeviceTypeId::CPU,
index d77e430..1165978 100644 (file)
@@ -23,7 +23,7 @@ void stop_gradient_op_cpu_impl(
 
 namespace c10 {
 C10_REGISTER_KERNEL(caffe2::ops::StopGradient)
-    .kernel(&caffe2::stop_gradient_op_cpu_impl<float>)
+    .kernel<&caffe2::stop_gradient_op_cpu_impl<float>>()
     .dispatchKey({DeviceTypeId::CPU,
                   LayoutId(0),
                   caffe2::TypeMeta::Id<float>()});
index b4180a6..fa90975 100644 (file)
@@ -219,7 +219,7 @@ c10::IValue layer_norm_c10(c10::ArrayRef<c10::IValue> inputs) {
 }
 namespace c10 {
 C10_REGISTER_KERNEL(c10::core::opschema::LayerNorm)
-    .kernel(&layer_norm_c10<float>)
+    .kernel<&layer_norm_c10<float>>()
     .dispatchKey(c10::DispatchKey<1>{
         c10::details::TensorParameterDispatchKey{DeviceTypeId::CPU,
                                                  LayoutId(0),