Add functor- and function-based kernel registration API (#18162)
author Sebastian Messmer <messmer@fb.com>
Sat, 30 Mar 2019 07:03:44 +0000 (00:03 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 07:07:19 +0000 (00:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18162

- Adds the API to register a functor- and function-based kernel.
- Changes the experimental c10 ops to use this new API instead of the old one
- Deletes the old APIs in KernelRegistration.h and OpSchemaRegistration.h

Reviewed By: dzhulgakov

Differential Revision: D14514239

fbshipit-source-id: 35b2f6e8f62964e54886450a6a5fac812ed20f26

69 files changed:
aten/src/ATen/core/dispatch/DispatchTable.h
aten/src/ATen/core/dispatch/Dispatcher.cpp
aten/src/ATen/core/dispatch/Dispatcher.h
aten/src/ATen/core/dispatch/KernelCache.h
aten/src/ATen/core/dispatch/KernelRegistration.h [deleted file]
aten/src/ATen/core/dispatch/OpSchemaRegistration.h [deleted file]
aten/src/ATen/core/dispatch/README.md
aten/src/ATen/core/op_registration/base.h
aten/src/ATen/core/op_registration/kernel_function.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_function_test.cpp [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_functor.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_functor_test.cpp [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_stackbased.h
aten/src/ATen/core/op_registration/op_registration.h
aten/src/ATen/core/op_registration/test_helpers.h
c10/util/C++17.h
c10/util/Metaprogramming.h
caffe2/operators/CMakeLists.txt
caffe2/operators/experimental/c10/cpu/add_cpu.cc
caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc
caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc
caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc
caffe2/operators/experimental/c10/cpu/cast_cpu.cc
caffe2/operators/experimental/c10/cpu/concat_cpu.cc
caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc
caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc
caffe2/operators/experimental/c10/cpu/fc_cpu.cc
caffe2/operators/experimental/c10/cpu/filler_cpu.cc
caffe2/operators/experimental/c10/cpu/flatten_cpu.cc
caffe2/operators/experimental/c10/cpu/mul_cpu.cc
caffe2/operators/experimental/c10/cpu/relu_cpu.cc
caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc
caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc
caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc
caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
caffe2/operators/experimental/c10/schemas/add.cc [deleted file]
caffe2/operators/experimental/c10/schemas/add.h [deleted file]
caffe2/operators/experimental/c10/schemas/averaged_loss.cc [deleted file]
caffe2/operators/experimental/c10/schemas/averaged_loss.h [deleted file]
caffe2/operators/experimental/c10/schemas/batch_gather.cc [deleted file]
caffe2/operators/experimental/c10/schemas/batch_gather.h [deleted file]
caffe2/operators/experimental/c10/schemas/batch_matmul.cc [deleted file]
caffe2/operators/experimental/c10/schemas/batch_matmul.h [deleted file]
caffe2/operators/experimental/c10/schemas/cast.cc [deleted file]
caffe2/operators/experimental/c10/schemas/cast.h [deleted file]
caffe2/operators/experimental/c10/schemas/concat.cc [deleted file]
caffe2/operators/experimental/c10/schemas/concat.h [deleted file]
caffe2/operators/experimental/c10/schemas/enforce_finite.cc [deleted file]
caffe2/operators/experimental/c10/schemas/enforce_finite.h [deleted file]
caffe2/operators/experimental/c10/schemas/expand_dims.cc [deleted file]
caffe2/operators/experimental/c10/schemas/expand_dims.h [deleted file]
caffe2/operators/experimental/c10/schemas/fc.cc [deleted file]
caffe2/operators/experimental/c10/schemas/fc.h [deleted file]
caffe2/operators/experimental/c10/schemas/filler.cc [deleted file]
caffe2/operators/experimental/c10/schemas/filler.h [deleted file]
caffe2/operators/experimental/c10/schemas/flatten.cc [deleted file]
caffe2/operators/experimental/c10/schemas/flatten.h [deleted file]
caffe2/operators/experimental/c10/schemas/mul.cc [deleted file]
caffe2/operators/experimental/c10/schemas/mul.h [deleted file]
caffe2/operators/experimental/c10/schemas/relu.cc [deleted file]
caffe2/operators/experimental/c10/schemas/relu.h [deleted file]
caffe2/operators/experimental/c10/schemas/sigmoid.cc [deleted file]
caffe2/operators/experimental/c10/schemas/sigmoid.h [deleted file]
caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc [deleted file]
caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h [deleted file]
caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc [deleted file]
caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h [deleted file]
caffe2/operators/experimental/c10/schemas/stop_gradient.cc [deleted file]
caffe2/operators/experimental/c10/schemas/stop_gradient.h [deleted file]

index 78b048b..15865cc 100644 (file)
@@ -14,6 +14,7 @@
 #include <type_traits>
 #include <sstream>
 #include <unordered_map>
+#include <functional>
 
 namespace c10 {
 
@@ -23,7 +24,7 @@ namespace c10 {
  * so we can create a new cache instance when a kernel is looked up
  * from the dispatch table.
  */
-using KernelCacheCreatorFunction = std::unique_ptr<c10::KernelCache> ();
+using KernelCacheCreatorFunction = std::function<std::unique_ptr<c10::KernelCache> ()>;
 /**
  * The dispatch table stores a pointer to a kernel function and a pointer
  * to a function initializing a cache for the kernel. If the kernel wants
@@ -34,7 +35,7 @@ using KernelCacheCreatorFunction = std::unique_ptr<c10::KernelCache> ();
  */
 struct DispatchTableEntry final {
   /*not-nullable*/ KernelFunction* kernel_func;
-  /*not-nullable*/ KernelCacheCreatorFunction* cache_creator_func;
+  /*not-nullable*/ KernelCacheCreatorFunction cache_creator_func;
 };
 
 namespace detail {
index 85a148f..8626fae 100644 (file)
@@ -103,9 +103,9 @@ void Dispatcher::deregisterSchema(const OperatorHandle& op) {
   }
 }
 
-void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
+void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func) {
   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
-  op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
+  op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, std::move(cache_creator_func)});
 }
 
 void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
index 35b32a3..08645a2 100644 (file)
@@ -39,8 +39,8 @@ public:
   }
 
 private:
-  explicit OpKernel(KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
-  : kernel_(kernel), cache_((*cache_creator)()) {}
+  explicit OpKernel(KernelFunction* kernel, const KernelCacheCreatorFunction& cache_creator)
+  : kernel_(kernel), cache_(cache_creator()) {}
   friend class Dispatcher;
 
   KernelFunction* kernel_;
@@ -126,7 +126,7 @@ public:
   /**
    * Register an operator to the dispatch table for an operator.
    */
-  void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func);
+  void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func);
 
   /**
    * Remove an operator from the dispatch table for an operator.
index e879b9d..3ff270b 100644 (file)
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <c10/macros/Macros.h>
+
 namespace c10 {
 
 /**
diff --git a/aten/src/ATen/core/dispatch/KernelRegistration.h b/aten/src/ATen/core/dispatch/KernelRegistration.h
deleted file mode 100644 (file)
index 4fef909..0000000
+++ /dev/null
@@ -1,279 +0,0 @@
-#pragma once
-
-#include <c10/util/Optional.h>
-#include <ATen/core/dispatch/Dispatcher.h>
-
-/**
- * To register your own kernel for an operator, do in one (!) cpp file:
- *   C10_REGISTER_KERNEL(OperatorHandle)
- *      .kernel<decltype(&kernel_func), &kernel_func>()
- *      .dispatchKey(dispatch_key);
- *
- * Example:
- *
- *  Tensor my_kernel_cpu(Tensor in) {...}
- *
- *  C10_REGISTER_KERNEL(MyOpSchema)
- *      .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>()
- *      .dispatchKey(CPUTensorId());
- */
-
-namespace c10 {
-
-// TODO Test different order for builder
-// TODO Test no dispatch key defined
-
-/**
- * Class which, on construction, registers an operator in the dispatch table. The intent is that
- * this class is constructed at static initialization time so that operators automatically get
- * registered when a dlopen() occurs.
- *
- * You shouldn't call this directly; instead, use the C10_REGISTER_KERNEL macros.
- */
-class KernelRegistrar final {
-public:
-  using OpHandleGetter = const OperatorHandle& ();
-
-  /**
-   * @param op The operator to register the kernel for
-   * @param dispatch_key  The dispatch key to register the function to
-   * @param kernel The concrete function implementation to register
-   * @param cache_creator A function initializing the cache for the kernel
-   */
-  explicit KernelRegistrar(OpHandleGetter *op, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
-  : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
-    Dispatcher::singleton().registerKernel(op_(), dispatch_key_, kernel, cache_creator);
-  }
-
-  KernelRegistrar(KernelRegistrar&& rhs)
-  : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(true) {
-    rhs.owns_registration_ = false;
-  }
-
-  // not needed for now
-  KernelRegistrar& operator=(KernelRegistrar&& rhs) = delete;
-
-  ~KernelRegistrar() {
-    if (owns_registration_) {
-      Dispatcher::singleton().deregisterKernel(op_(), dispatch_key_);
-    }
-  }
-
-private:
-  OpHandleGetter *op_;
-  const TensorTypeId dispatch_key_;
-  bool owns_registration_;
-
-  C10_DISABLE_COPY_AND_ASSIGN(KernelRegistrar);
-};
-
-namespace detail {
-// ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
-// cast it to the type that should be passed to the kernel function.
-// Examples: If the IValue contains a plain type like an int, return that.
-//           If the IValue contains an IntList, return it as ArrayRef<int>.
-template<class T>
-struct ivalue_to_arg_type {
-  static T call(const IValue& v) {
-    return std::move(v).to<T>();
-  }
-};
-template<class T>
-struct ivalue_to_arg_type<ArrayRef<T>> {
-  static ArrayRef<T> call(const IValue& v) {
-    return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
-  }
-};
-
-// call_with_ivalue_args: Take a function pointer and an ArrayRef<IValue>
-// containing the arguments to call the function pointer with, and call it.
-// The extra_args are appended as additional arguments at the end of the function call.
-// Example:
-// int myfunc(int a, ArrayRef<int> b, string c);
-// int main() {
-//   std::vector<IValue> ivalue_args = {IValue(2), IntList::create(3, 4)};
-//   call_with_ivalue_args<decltype(myfunc), &myfunc>(ivalue_args, "extra_arg");
-// }
-template<class FuncType, FuncType* func, class... ExtraArgs, size_t... ivalue_arg_indices>
-typename guts::function_traits<FuncType>::return_type call_with_ivalue_args_(ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>, ExtraArgs&&... extra_args) {
-  using IValueArgTypes = typename guts::function_traits<FuncType>::parameter_types;
-  return (*func)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])..., std::forward<ExtraArgs>(extra_args)...);
-}
-
-template<class FuncType, FuncType* func, class... ExtraArgs>
-typename guts::function_traits<FuncType>::return_type call_with_ivalue_args(ArrayRef<IValue> ivalue_args, ExtraArgs&&... extra_args) {
-  constexpr size_t num_ivalue_args = guts::function_traits<FuncType>::number_of_parameters - sizeof...(ExtraArgs);
-  AT_ASSERTM(num_ivalue_args == ivalue_args.size(), "Wrong number of ivalue arguments");
-  return call_with_ivalue_args_<FuncType, func>(ivalue_args, guts::make_index_sequence<num_ivalue_args>(), std::forward<ExtraArgs>(extra_args)...);
-}
-
-template<class OutputType>
-struct push_outputs final {
-  static void call(OutputType&& output, Stack* stack) {
-    push_outputs<std::tuple<OutputType>>(std::tuple<OutputType>(std::move(output)), stack);
-  }
-};
-template<class... OutputTypes>
-struct push_outputs<std::tuple<OutputTypes...>> final {
-  static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
-    for (size_t i = 0; i < sizeof...(OutputTypes); ++i) {
-      torch::jit::push(return_type_to_ivalue(std::move(output)));
-    }
-  }
-};
-
-// SFINAE over (1) does the operator kernel have a cache and (2) does it return a value or void
-template<class CacheTypeOrVoid, class FuncType, FuncType* kernel, class Enable = void> struct wrap_kernel {};
-// SFINAE version for kernels with output and with cache
-template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
-struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
-  static typename guts::function_traits<FuncType>::return_type call(Stack* stack, KernelCache* cache) {
-    constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument
-    auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache));
-    push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack);
-  }
-};
-// SFINAE version for kernels with output and without a cache
-template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
-struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && !std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
-  static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) {
-    constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters;
-    auto output = call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs));
-    push_outputs<typename guts::function_traits<FuncType>::return_type>(std::move(output), stack);
-  }
-};
-// SFINAE version for kernels without output and with a cache
-template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
-struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<!std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
-  static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* cache) {
-    constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters - 1; // -1 because it takes the kernel cache as last argument
-    call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs), static_cast<CacheTypeOrVoid*>(cache));
-  }
-};
-// SFINAE version for kernels without output and without a cache
-template<class CacheTypeOrVoid, class FuncType, FuncType* kernel>
-struct wrap_kernel<CacheTypeOrVoid, FuncType, kernel, guts::enable_if_t<std::is_same<void, CacheTypeOrVoid>::value && std::is_same<void, typename guts::function_traits<FuncType>::return_type>::value>> final {
-  static typename guts::function_traits<FuncType>::return_type call(Stack* stack, c10::KernelCache* /*cache*/) {
-    constexpr size_t num_inputs = guts::function_traits<FuncType>::number_of_parameters;
-    call_with_ivalue_args<FuncType, kernel>(torch::jit::last(*stack, num_inputs));
-  }
-};
-}
-
-/**
- * Helper class for building a KernelRegistrar.  This permits "keyword-argument" like syntax
- * when performing operator registration, e.g., as in:
- *
- * C10_REGISTER_KERNEL(::ops::add_notensor)
- *      .kernel(&add_notensor_op)
- *      .dispatchKey("bla");
- *
- * Expanded, this macro invocation looks like:
- *
- * static KernelRegistrar<::ops::add_notensor> _anon0 =
- *    KernelRegistrationBuilder<::ops::add_notensor, false, false>()
- *      .kernel(&add_notensor_op)
- *      .dispatchKey("bla");
- *
- * The resulting full expression is implicitly convertible to a KernelRegistrar.
- */
-template<class CacheTypeOrVoid, uint64_t FieldsPresentFlags>
-class KernelRegistrationBuilder final {
-private:
-  static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 0;
-  static constexpr uint64_t KERNEL_PRESENT = 0x01 << 1;
-  static constexpr uint64_t CACHE_PRESENT = 0x01 << 2;
-
-  using OpHandleGetter = KernelRegistrar::OpHandleGetter;
-
-  static std::unique_ptr<c10::KernelCache> defaultCacheCreator() {
-    return nullptr;
-  }
-
-  template<class Cache>
-  static std::unique_ptr<c10::KernelCache> cacheCreator() {
-    static_assert(std::is_default_constructible<Cache>::value, "Cache class must be default constructible");
-    return guts::make_unique<Cache>();
-  }
-
-  OpHandleGetter *op_;
-  c10::optional<TensorTypeId> dispatch_key_;
-  KernelFunction* kernel_;
-  KernelCacheCreatorFunction* cache_creator_;
-
- public:
-  constexpr explicit KernelRegistrationBuilder(OpHandleGetter *op)
-      : KernelRegistrationBuilder(std::move(op), c10::nullopt, nullptr, &defaultCacheCreator) {}
-
-  constexpr explicit KernelRegistrationBuilder(
-      OpHandleGetter *op,
-      c10::optional<TensorTypeId> dispatch_key,
-      KernelFunction* kernel,
-      KernelCacheCreatorFunction* cache_creator)
-      : op_(std::move(op)), dispatch_key_(std::move(dispatch_key)), kernel_(kernel), cache_creator_(cache_creator)  {}
-
-  /**
-   * Implicit coercion to KernelRegistrar that finalizes the builder and
-   * creates the object.
-   * @return Produced KernelRegistrar
-   */
-  operator KernelRegistrar() && {
-    static_assert(FieldsPresentFlags & KERNEL_PRESENT, "Forgot to call .kernel() in kernel registration");
-    static_assert(FieldsPresentFlags & DISPATCH_KEY_PRESENT, "Forgot to call .dispatchKey() in kernel registration");
-    return KernelRegistrar(op_, std::move(*dispatch_key_), kernel_, cache_creator_);
-  }
-
-  /**
-   * Specify the dispatch key for this dispatch registration
-   * @param dispatch_key dispatch key to register the function to
-   * @return "this" for method chaining
-   */
-  AT_CPP14_CONSTEXPR KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT> dispatchKey(TensorTypeId dispatch_key) && {
-    static_assert(!(FieldsPresentFlags & DISPATCH_KEY_PRESENT), "Tried to define kernel twice in same op registration");
-    return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | DISPATCH_KEY_PRESENT>(std::move(op_), std::move(dispatch_key), kernel_, cache_creator_);
-  }
-
-  /**
-   * Specify the concrete function implementation for this dispatch registration
-   * @param kernel concrete function implementation to be registered
-   * @return "this" for method chaining
-   */
-  template<KernelFunction* kernel_func>
-  AT_CPP14_CONSTEXPR KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
-    static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Tried to define kernel twice in same op registration");
-    // TODO Better error message when kernel function mismatches, one common mismatch is missing cache parameter or cache parameter present while not expected.
-    return KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_func, cache_creator_);
-  }
-
-  /**
-   * Specify the concrete function implementation for this dispatch registration
-   * @param kernel concrete function implementation to be registered
-   * @return "this" for method chaining
-   */
-  template<class FuncType, FuncType* kernel_func>
-  AT_CPP14_CONSTEXPR KernelRegistrationBuilder<CacheTypeOrVoid, FieldsPresentFlags | KERNEL_PRESENT> kernel() && {
-    // TODO Better error message if FuncType is not a func type
-    return std::move(*this).template kernel<&detail::wrap_kernel<CacheTypeOrVoid, FuncType, kernel_func>::call>();
-  }
-
-  /**
-   * Specify the dispatch key for this dispatch registration
-   * @param dispatch_key dispatch key to register the function to
-   * @return "this" for method chaining
-   */
-  template<class Cache>
-  AT_CPP14_CONSTEXPR KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT> withCache() && {
-    static_assert(!(FieldsPresentFlags & CACHE_PRESENT), "Tried to define cache twice in same op registration");
-    static_assert(std::is_base_of<c10::KernelCache, Cache>::value, "Cache must inherit from c10::KernelCache");
-
-    static_assert(!(FieldsPresentFlags & KERNEL_PRESENT), "Cannot set the cache after the kernel function is already set. Please call .withCache() first and .kernel() later in the chain.");
-
-    return KernelRegistrationBuilder<Cache, FieldsPresentFlags | CACHE_PRESENT>(std::move(op_), std::move(dispatch_key_), kernel_, &cacheCreator<Cache>);
-  }
-};
-
-} // namespace c10
-
-// NB: Semicolon after applying this macro is MANDATORY
-#define C10_REGISTER_KERNEL(OperatorHandle)                                                           \
-  static KernelRegistrar MACRO_CONCAT(__kernelRegistrationBuilder_, __COUNTER__) = KernelRegistrationBuilder<void, 0>(OperatorHandle)
diff --git a/aten/src/ATen/core/dispatch/OpSchemaRegistration.h b/aten/src/ATen/core/dispatch/OpSchemaRegistration.h
deleted file mode 100644 (file)
index 2e07c3a..0000000
+++ /dev/null
@@ -1,46 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/Dispatcher.h>
-
-namespace c10 {
-namespace detail {
-class OpSchemaRegistrar final {
-public:
-  explicit OpSchemaRegistrar(FunctionSchema schema)
-  : opHandle_(c10::Dispatcher::singleton().registerSchema(std::move(schema))) {}
-
-  ~OpSchemaRegistrar() {
-    c10::Dispatcher::singleton().deregisterSchema(opHandle_);
-  }
-
-  const c10::OperatorHandle& opHandle() const {
-    return opHandle_;
-  }
-
-private:
-  c10::OperatorHandle opHandle_;
-};
-}  // namespace detail
-}  // namespace c10
-
-/**
- * Macro for defining an operator schema.  Every operator schema must
- * invoke C10_DECLARE_OP_SCHEMA in a header and C10_DEFINE_OP_SCHEMA in one (!)
- * cpp file.  Internally, this arranges for the dispatch table for
- * the operator to be created.
- */
-#define C10_DECLARE_OP_SCHEMA(Name)                                             \
-  CAFFE2_API const c10::OperatorHandle& Name();                                 \
-
-#define C10_DEFINE_OP_SCHEMA(Name, Schema)                                      \
-  C10_EXPORT const c10::OperatorHandle& Name() {                                \
-    /* must be meyers singleton to make sure this is registered before any */   \
-    /* kernels referencing it are registered. */                                \
-    static ::c10::detail::OpSchemaRegistrar registrar(Schema);                  \
-    return registrar.opHandle();                                                \
-  }                                                                             \
-  namespace {                                                                   \
-    /* to make sure the schema is registered even if it is not referenced by */ \
-    /* a kernel registration, call it in a global object.                    */ \
-    const c10::OperatorHandle& _c10_op_registration_instance_##Name = Name();   \
-  }
index e29c93c..d2f935e 100644 (file)
@@ -8,5 +8,3 @@ This folder contains the following files:
 - DispatchTable.h: Implementation of the actual dispatch mechanism. Hash table with kernels, lookup, ...
 - KernelCache.h: An interface operator kernels can use to inherit from if they need to keep around a cache between invocations
 - KernelFunction.h: The core interface (i.e. function pointer) for calling a kernel
-- OpSchemaRegistration.h (deprecated): The mechanisms to register new operators with the c10 dispatcher
-- KernelRegistration.h (deprecated): The mechanisms to register kernels with the c10 dispatcher
index 8a9430a..3c4ea05 100644 (file)
@@ -22,9 +22,9 @@ namespace detail {
   // You shouldn't call this directly; instead, use the RegisterOperators class.
   class OperatorRegistrar final {
   public:
-    explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
+    explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
     : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
-      Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, cache_creator);
+      Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator));
     }
 
     OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
@@ -55,7 +55,7 @@ namespace detail {
   struct KernelRegistrationConfig final {
     TensorTypeId dispatch_key;
     KernelFunction* kernel_func = nullptr;
-    KernelCacheCreatorFunction* cache_creator_func = nullptr;
+    KernelCacheCreatorFunction cache_creator_func = nullptr;
   };
 
   // is_registration_config_parameter is a concept that returns true_type iff its argument is
diff --git a/aten/src/ATen/core/op_registration/kernel_function.h b/aten/src/ATen/core/op_registration/kernel_function.h
new file mode 100644 (file)
index 0000000..ba14d5a
--- /dev/null
@@ -0,0 +1,49 @@
+#pragma once
+
+#include <ATen/core/op_registration/kernel_functor.h>
+
+namespace c10 {
+namespace detail {
+  // WrapKernelFunction: Wraps a compile time function pointer into a kernel functor.
+  // Since it is a compile time function pointer, many compilers can inline it
+  // into the wrapper and you don't get any performance overhead for wrapping.
+  template<class FuncType, FuncType* kernel_func, class ReturnType, class ParameterList> class WrapKernelFunction_ {};
+  template<class FuncType, FuncType* kernel_func, class ReturnType, class... Parameters>
+  class WrapKernelFunction_<FuncType, kernel_func, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
+  public:
+    auto operator()(Parameters&&... args) -> decltype((*kernel_func)(std::forward<Parameters>(args)...)) {
+      return (*kernel_func)(std::forward<Parameters>(args)...);
+    }
+  };
+  template<class FuncType, FuncType* kernel_func, class Enable = guts::enable_if_t<guts::is_function_type<FuncType>::value>>
+  struct WrapKernelFunction final {
+    using type = WrapKernelFunction_<
+        FuncType,
+        kernel_func,
+        typename guts::function_traits<FuncType>::return_type,
+        typename guts::function_traits<FuncType>::parameter_types
+    >;
+  };
+}
+
+/**
+ * Use this to register an operator whose kernel is implemented by a function:
+ *
+ * Example:
+ *
+ * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(),
+ * >         c10::dispatchKey(CPUTensorId()));
+ */
+template<class FuncType, FuncType* kernel_func>
+inline constexpr auto kernel() ->
+// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
+guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value,
+decltype(kernel<typename detail::WrapKernelFunction<FuncType, kernel_func>::type>())> {
+  return kernel<typename detail::WrapKernelFunction<FuncType, kernel_func>::type>();
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp
new file mode 100644 (file)
index 0000000..02231d3
--- /dev/null
@@ -0,0 +1,539 @@
+#include <gtest/gtest.h>
+#include <ATen/core/op_registration/test_helpers.h>
+
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/core/Tensor.h>
+
+using c10::RegisterOperators;
+using c10::FunctionSchema;
+using c10::Argument;
+using c10::IntType;
+using c10::ListType;
+using c10::kernel;
+using c10::dispatchKey;
+using c10::TensorTypeId;
+using c10::KernelCache;
+using c10::Stack;
+using c10::guts::make_unique;
+using c10::ivalue::TensorList;
+using c10::ivalue::IntList;
+using c10::intrusive_ptr;
+using c10::ArrayRef;
+using std::unique_ptr;
+using at::Tensor;
+
+namespace {
+
+C10_DECLARE_TENSOR_TYPE(TensorType1);
+C10_DEFINE_TENSOR_TYPE(TensorType1);
+C10_DECLARE_TENSOR_TYPE(TensorType2);
+C10_DEFINE_TENSOR_TYPE(TensorType2);
+
+// Kernel that must never be dispatched to; registered under other dispatch
+// keys/schemas to verify that registration doesn't leak across entries.
+void errorKernel(const Tensor&) {
+  EXPECT_TRUE(false); // this kernel should never be called
+}
+
+// Schema for _test::error: (Tensor dummy) -> ().
+FunctionSchema errorOpSchema(
+    "_test::error",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+// Kernel under test: returns its int argument plus one. The tensor argument
+// only carries the dispatch key and is otherwise unused.
+int incrementKernel(const Tensor& tensor, int input) {
+  int incremented = input;
+  ++incremented;
+  return incremented;
+}
+
+// Kernel under test: returns its int argument minus one. The tensor argument
+// only carries the dispatch key and is otherwise unused.
+int decrementKernel(const Tensor& tensor, int input) {
+  int decremented = input;
+  --decremented;
+  return decremented;
+}
+
+// Schema for _test::my_op: (Tensor dummy, int input) -> (int output).
+FunctionSchema opSchema(
+    "_test::my_op",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+// Looks up _test::my_op and asserts that calling it with 5 returns 6,
+// i.e. that incrementKernel is the kernel registered for `type_id`.
+void expectCallsIncrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(6, result[0].toInt());
+}
+
+// Same as above, but expects decrementKernel (5 -> 4) to be registered.
+void expectCallsDecrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(4, result[0].toInt());
+}
+
+// Basic smoke test: a single function-based kernel can be registered and called.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
+  expectCallsIncrement(TensorType1());
+}
+
+// Multiple (schema, dispatch key) entries in one registrar must not interfere:
+// only the incrementKernel registered for (my_op, TensorType1) may be called.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
+  auto registrar = RegisterOperators()
+      .op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()))
+      .op(opSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()))
+      .op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType1()))
+      .op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+// Same as above but with each entry in its own registrar object.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
+  auto registrar1 = RegisterOperators().op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
+  auto registrar2 = RegisterOperators().op(opSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
+  auto registrar3 = RegisterOperators().op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType1()));
+  auto registrar4 = RegisterOperators().op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+// Registration lifetime is RAII: destroying a registrar deregisters its
+// kernel, and destroying the last registrar removes the schema entirely.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
+  {
+    auto registrar1 = RegisterOperators().op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
+    {
+      auto registrar2 = RegisterOperators().op(opSchema, kernel<decltype(decrementKernel), &decrementKernel>(), dispatchKey(TensorType2()));
+
+      // assert that schema and cpu kernel are present
+      expectCallsIncrement(TensorType1());
+      expectCallsDecrement(TensorType2());
+    }
+
+    // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
+    expectCallsIncrement(TensorType1());
+    expectDoesntFindKernel("_test::my_op", TensorType2());
+  }
+
+  // now both registrars are destructed. Assert that the whole schema is gone
+  expectDoesntFindOperator("_test::my_op");
+}
+
+// Side-channel flag used by the output-less kernels below to prove they ran.
+bool was_called = false;
+
+// Kernel returning void: must push nothing onto the stack.
+void kernelWithoutOutput(const Tensor&) {
+  was_called = true;
+}
+
+FunctionSchema opWithoutOutputSchema(
+    "_test::no_return",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opWithoutOutputSchema, kernel<decltype(kernelWithoutOutput), &kernelWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
+  ASSERT_TRUE(op.has_value());
+  was_called = false;
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(was_called);
+  EXPECT_EQ(0, result.size());
+}
+
+// Kernel returning an empty tuple: semantically equivalent to void — the
+// tuple unpacking in push_outputs must push zero values.
+std::tuple<> kernelWithZeroOutputs(const Tensor&) {
+  was_called = true;
+  return std::make_tuple();
+}
+
+FunctionSchema opWithZeroOutputsSchema(
+    "_test::zero_outputs",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, kernel<decltype(kernelWithZeroOutputs), &kernelWithZeroOutputs>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
+  ASSERT_TRUE(op.has_value());
+  was_called = false;
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(was_called);
+  EXPECT_EQ(0, result.size());
+}
+
+// Kernel returning a single int: the sum of its two int arguments.
+int kernelWithIntOutput(Tensor, int a, int b) {
+  const int sum = a + b;
+  return sum;
+}
+
+// Schema for _test::int_output: (Tensor dummy, int a, int b) -> (int sum).
+FunctionSchema opWithIntOutputSchema(
+    "_test::int_output",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("a", IntType::get()),
+                           Argument("b", IntType::get())}),
+    (std::vector<Argument>{Argument("sum", IntType::get())}));
+
+// Verifies an int return value is pushed onto the stack as a single IValue.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntOutputSchema, kernel<decltype(kernelWithIntOutput), &kernelWithIntOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(9, result[0].toInt());
+}
+
+// Identity kernel: echoes its input tensor so the test can check the
+// returned tensor carries the dispatch key it was called with.
+Tensor kernelWithTensorOutput(const Tensor& input) {
+  return input;
+}
+
+FunctionSchema opWithTensorOutput(
+    "_test::returning_tensor",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output")}));
+
+// Registers the same kernel for two dispatch keys and verifies dispatch
+// selects per the input tensor's type id.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorOutput, kernel<decltype(kernelWithTensorOutput), &kernelWithTensorOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorOutput, kernel<decltype(kernelWithTensorOutput), &kernelWithTensorOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+// Kernel returning a std::vector<Tensor>: must be wrapped into a single
+// TensorList IValue on the stack.
+std::vector<Tensor> kernelWithTensorListOutput(const Tensor& input1, const Tensor& input2, const Tensor& input3) {
+  return {input1, input2, input3};
+}
+
+FunctionSchema opWithTensorListOutputSchema(
+    "_test::list_output",
+    "",
+    (std::vector<Argument>{Argument("input1"),
+                           Argument("input2"),
+                           Argument("input3")}),
+    (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListOutputSchema, kernel<decltype(kernelWithTensorListOutput), &kernelWithTensorListOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(3, result[0].toTensorListRef().size());
+  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id());
+  EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id());
+  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
+}
+
+// Kernel returning a std::vector<int64_t> built from its three int arguments;
+// the dispatcher glue must wrap it into a single IntList IValue.
+std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int input1, int input2, int input3) {
+  std::vector<int64_t> output;
+  output.reserve(3);
+  output.push_back(input1);
+  output.push_back(input2);
+  output.push_back(input3);
+  return output;
+}
+
+// Schema for _test::list_output: three int inputs -> (int[] output).
+FunctionSchema opWithIntListOutputSchema(
+    "_test::list_output",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input1", IntType::get()),
+                           Argument("input2", IntType::get()),
+                           Argument("input3", IntType::get())}),
+    (std::vector<Argument>{Argument("output", ListType::ofInts())}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListOutputSchema, kernel<decltype(kernelWithIntListOutput), &kernelWithIntListOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(3, result[0].toIntListRef().size());
+  EXPECT_EQ(2, result[0].toIntListRef()[0]);
+  EXPECT_EQ(4, result[0].toIntListRef()[1]);
+  EXPECT_EQ(6, result[0].toIntListRef()[2]);
+}
+
+// Kernel returning a tuple of mixed types: each tuple element must be pushed
+// onto the stack as a separate output IValue, in order.
+std::tuple<Tensor, int64_t, std::vector<Tensor>> kernelWithMultipleOutputs(Tensor) {
+  return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
+    dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
+  );
+}
+
+FunctionSchema opWithMultipleOutputsSchema(
+    "_test::multiple_outputs",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{Argument("output1"),
+                           Argument("output2", IntType::get()),
+                           Argument("output3", ListType::ofTensors())}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+     .op(opWithMultipleOutputsSchema, kernel<decltype(kernelWithMultipleOutputs), &kernelWithMultipleOutputs>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(3, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+  EXPECT_EQ(5, result[1].toInt());
+  EXPECT_EQ(2, result[2].toTensorListRef().size());
+  EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id());
+  EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
+}
+
+// Both kernels echo their input; they exist in two flavours to exercise
+// argument passing by const reference and by value through the IValue glue.
+Tensor kernelWithTensorInputByReferenceWithOutput(const Tensor& input1) {
+  return input1;
+}
+
+Tensor kernelWithTensorInputByValueWithOutput(Tensor input1) {
+  return input1;
+}
+
+FunctionSchema opWithTensorInputWithOutput(
+    "_test::tensor_input",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output")}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByReferenceWithOutput), &kernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByReferenceWithOutput), &kernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByValueWithOutput), &kernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByValueWithOutput), &kernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+// Side channel for the output-less tensor-input kernels below: the last
+// tensor each kernel received, inspected by the tests.
+Tensor captured_input;
+
+void kernelWithTensorInputByReferenceWithoutOutput(const Tensor& input1) {
+  captured_input = input1;
+}
+
+void kernelWithTensorInputByValueWithoutOutput(Tensor input1) {
+  captured_input = input1;
+}
+
+FunctionSchema opWithTensorInputWithoutOutput(
+    "_test::tensor_input",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByReferenceWithoutOutput), &kernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByReferenceWithoutOutput), &kernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType1(), captured_input.type_id());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType2(), captured_input.type_id());
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByValueWithoutOutput), &kernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByValueWithoutOutput), &kernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType1(), captured_input.type_id());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType2(), captured_input.type_id());
+}
+
+// Side channel for the int-input kernel without output below.
+int captured_int_input = 0;
+
+void kernelWithIntInputWithoutOutput(Tensor, int input1) {
+  captured_int_input = input1;
+}
+
+FunctionSchema opWithIntInputWithoutOutput(
+    "_test::int_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntInputWithoutOutput, kernel<decltype(kernelWithIntInputWithoutOutput), &kernelWithIntInputWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_int_input = 0;
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(3, captured_int_input);
+}
+
+// Same operator name but now with an int output (input + 1).
+int kernelWithIntInputWithOutput(Tensor, int input1) {
+  return input1 + 1;
+}
+
+FunctionSchema opWithIntInputWithOutput(
+    "_test::int_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntInputWithOutput, kernel<decltype(kernelWithIntInputWithOutput), &kernelWithIntInputWithOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(4, outputs[0].toInt());
+}
+
+// Side channel recording the length of the last list argument a kernel saw.
+int captured_input_list_size = 0;
+
+// Int lists arrive at the kernel as a non-owning ArrayRef<int64_t> view.
+void kernelWithIntListInputWithoutOutput(Tensor, ArrayRef<int64_t> input1) {
+  captured_input_list_size = input1.size();
+}
+
+FunctionSchema opWithIntListInputWithoutOutput(
+    "_test::int_list_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", ListType::ofInts())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListInputWithoutOutput, kernel<decltype(kernelWithIntListInputWithoutOutput), &kernelWithIntListInputWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_input_list_size = 0;
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(3, captured_input_list_size);
+}
+
+// Returns the length of the int-list argument as the operator's output.
+int kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
+  return input1.size();
+}
+
+FunctionSchema opWithIntListInputWithOutput(
+    "_test::int_list_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", ListType::ofInts())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListInputWithOutput, kernel<decltype(kernelWithIntListInputWithOutput), &kernelWithIntListInputWithOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(3, outputs[0].toInt());
+}
+
+// Tensor lists also arrive as non-owning ArrayRef views.
+void kernelWithTensorListInputWithoutOutput(ArrayRef<Tensor> input1) {
+  captured_input_list_size = input1.size();
+}
+
+FunctionSchema opWithTensorListInputWithoutOutput(
+    "_test::tensor_list_input",
+    "",
+    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListInputWithoutOutput, kernel<decltype(kernelWithTensorListInputWithoutOutput), &kernelWithTensorListInputWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_input_list_size = 0;
+  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(2, captured_input_list_size);
+}
+
+// Returns the length of the tensor-list argument as the operator's output.
+int kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
+  return input1.size();
+}
+
+FunctionSchema opWithTensorListInputWithOutput(
+    "_test::tensor_list_input",
+    "",
+    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListInputWithOutput, kernel<decltype(kernelWithTensorListInputWithOutput), &kernelWithTensorListInputWithOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(2, outputs[0].toInt());
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/kernel_functor.h b/aten/src/ATen/core/op_registration/kernel_functor.h
new file mode 100644 (file)
index 0000000..19014a9
--- /dev/null
@@ -0,0 +1,187 @@
+#pragma once
+
+#include <ATen/core/op_registration/kernel_stackbased.h>
+
+namespace c10 {
+/**
+ * Inherit from OperatorKernel to implement a c10 kernel.
+ *
+ * Example:
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ *
+ * The kernel class is allowed to have members to cache things between calls
+ * but it is not allowed to change behavior based on the cache.
+ * The cache is purely a performance optimization and the kernel must
+ * return the same outputs regardless of what's in the cache.
+ *
+ * See below for how to register this kernel with PyTorch.
+ */
+// Deriving from KernelCache lets the dispatcher store the functor instance
+// (including its cached members) in the per-kernel cache slot.
+class OperatorKernel : public KernelCache {};
+
+namespace detail {
+  // ivalue_to_arg_type<T>: Take an IValue that is an argument to a kernel and
+  // cast it to the type that should be passed to the kernel function.
+  // Examples: If the IValue contains a plain type like an int, return that.
+  //           If the IValue contains an IntList, return it as ArrayRef<int>.
+  template<class T>
+  struct ivalue_to_arg_type {
+    static T call(const IValue& v) {
+      // NOTE(review): v is a const lvalue reference, so std::move here is a
+      // no-op cast and the conversion still goes through the const overload
+      // (i.e. copies) — confirm whether a mutable IValue was intended.
+      return std::move(v).to<T>();
+    }
+  };
+  // Specialization for list arguments: kernels receive a non-owning ArrayRef
+  // view over the elements owned by the IValue's List object.
+  template<class T>
+  struct ivalue_to_arg_type<ArrayRef<T>> {
+    static ArrayRef<T> call(const IValue& v) {
+      return v.to<intrusive_ptr<ivalue::List<T>>>()->elements();
+    }
+  };
+
+  // Converts a kernel's return value into an IValue so it can be pushed
+  // onto the dispatcher stack.
+  template<class T>
+  IValue return_type_to_ivalue(T&& t) {
+    return IValue(std::forward<T>(t));
+  }
+
+  // Implementation detail: expands ivalue_arg_indices to convert each IValue
+  // argument to the functor's corresponding (cv/ref-stripped) parameter type
+  // and invokes the functor with the results.
+  template<class Functor, size_t... ivalue_arg_indices>
+  typename guts::infer_function_traits_t<Functor>::return_type call_functor_with_ivalue_args_(Functor* functor, ArrayRef<IValue> ivalue_args, guts::index_sequence<ivalue_arg_indices...>) {
+    using IValueArgTypes = typename guts::infer_function_traits_t<Functor>::parameter_types;
+    return (*functor)(ivalue_to_arg_type<guts::remove_cv_t<guts::remove_reference_t<guts::typelist::element_t<ivalue_arg_indices, IValueArgTypes>>>>::call(ivalue_args[ivalue_arg_indices])...);
+  }
+
+  // Calls `functor` with the given IValue arguments, checking the argument
+  // count against the functor's arity at runtime.
+  template<class Functor>
+  typename guts::infer_function_traits_t<Functor>::return_type call_functor_with_ivalue_args(Functor* functor, ArrayRef<IValue> ivalue_args) {
+    constexpr size_t num_ivalue_args = guts::infer_function_traits_t<Functor>::number_of_parameters;
+    AT_ASSERTM(num_ivalue_args == ivalue_args.size(), "Wrong number of ivalue arguments");
+    return call_functor_with_ivalue_args_<Functor>(functor, ivalue_args, guts::make_index_sequence<num_ivalue_args>());
+  }
+
+  // push_outputs: pushes a kernel's return value(s) onto the stack.
+  // The general case pushes a single value; the tuple specialization below
+  // pushes each tuple element as a separate output, in declaration order.
+  template<class OutputType>
+  struct push_outputs final {
+    static void call(OutputType&& output, Stack* stack) {
+      torch::jit::push(*stack, return_type_to_ivalue(std::move(output)));
+    }
+  };
+  template<class... OutputTypes>
+  struct push_outputs<std::tuple<OutputTypes...>> final {
+    static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
+      call_(std::move(output), stack, guts::make_index_sequence<sizeof...(OutputTypes)>());
+    }
+
+  private:
+    template<size_t... indices>
+    static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, guts::index_sequence<indices...>) {
+      (void)(stack); // silence compiler warning of weird compilers somehow thinking this parameter is unused.
+      // iterate over all outputs and push them
+      // (initializer_list trick: guarantees left-to-right evaluation in C++14,
+      //  where fold expressions aren't available yet)
+      (void)std::initializer_list<int>{(
+        torch::jit::push(*stack, return_type_to_ivalue(std::move(std::get<indices>(output))))
+      , 0)...};
+    }
+  };
+
+  // wrap_kernel_functor: adapts a typed kernel functor to the stack-based
+  // boxed calling convention. `call` pops the functor's arguments off the
+  // stack, invokes the functor (recovered from the KernelCache pointer), and
+  // pushes any outputs back. Two SFINAE variants: with and without a return.
+  template<class KernelFunctor, class Enable = void> struct wrap_kernel_functor final {};
+
+  // SFINAE version for kernels that return an output
+  template<class KernelFunctor>
+  struct wrap_kernel_functor<KernelFunctor, guts::enable_if_t<!std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
+    static void call(Stack* stack, KernelCache* cache) {
+      static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel");
+
+      constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
+      // The cache slot holds the functor instance created by KernelFactory.
+      KernelFunctor* functor = static_cast<KernelFunctor*>(cache);
+      auto output = call_functor_with_ivalue_args<KernelFunctor>(functor, torch::jit::last(*stack, num_inputs));
+      torch::jit::pop(*stack, num_inputs);
+      push_outputs<typename guts::infer_function_traits_t<KernelFunctor>::return_type>::call(std::move(output), stack);
+    }
+  };
+
+  // SFINAE version for kernels that don't return an output
+  template<class KernelFunctor>
+  struct wrap_kernel_functor<KernelFunctor, guts::enable_if_t<std::is_same<void, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value>> final {
+    static void call(Stack* stack, KernelCache* cache) {
+      static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel.");
+
+      constexpr size_t num_inputs = guts::infer_function_traits_t<KernelFunctor>::number_of_parameters;
+      KernelFunctor* functor = static_cast<KernelFunctor*>(cache);
+      call_functor_with_ivalue_args<KernelFunctor>(functor, torch::jit::last(*stack, num_inputs));
+      torch::jit::pop(*stack, num_inputs);
+    }
+  };
+
+  // KernelFactory: stores the constructor arguments given at registration
+  // time and, when invoked, builds a fresh KernelFunctor instance to be kept
+  // in the dispatcher's kernel cache.
+  template<class KernelFunctor, class... Args>
+  class KernelFactory final {
+  public:
+    static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Kernel functor must inherit from c10::OperatorKernel.");
+    static_assert(std::is_constructible<KernelFunctor, Args...>::value, "Wrong argument types for constructor of kernel functor.");
+
+    explicit constexpr KernelFactory(Args... args)
+    : constructor_parameters_(std::move(args)...) {}
+
+    // Constructs the functor from the stored parameters. Parameters are
+    // passed by const reference so the factory can be invoked repeatedly.
+    std::unique_ptr<KernelCache> operator()() const {
+      return guts::apply(
+        [] (const Args&... params) {return guts::make_unique<KernelFunctor>(params...); },
+        constructor_parameters_);
+    }
+
+  private:
+    std::tuple<Args...> constructor_parameters_;
+  };
+}
+
+/**
+ * Use this to register an operator whose kernel is implemented as a functor
+ *
+ * Example:
+ *
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel<my_kernel_cpu>(),
+ * >         c10::dispatchKey(CPUTensorId()));
+ *
+ * The functor constructor can take arguments to configure the kernel.
+ * The arguments are defined in the kernel registration.
+ * Example:
+ *
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
+ * >         : ... {...}
+ * >
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel<my_kernel_cpu>("some_configuration", 3, true),
+ * >         c10::dispatchKey(CPUTensorId()));
+ */
+template<class KernelFunctor, class... ConstructorParameters>
+inline constexpr auto kernel(ConstructorParameters&&... constructorParameters)
+// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
+// The trailing decltype mirrors the body below: it must stay textually in
+// sync with the return statement for the SFINAE to be correct.
+-> guts::enable_if_t<
+guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+decltype(kernel(
+    &detail::wrap_kernel_functor<KernelFunctor>::call,
+    detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
+))> {
+  static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "KernelFunctor cannot be constructed with the given arguments");
+
+  // Delegates to the stack-based kernel() overload with the boxing wrapper
+  // and a factory that will lazily construct the functor into the cache.
+  return kernel(
+      &detail::wrap_kernel_functor<KernelFunctor>::call,
+      detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
+  );
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp
new file mode 100644 (file)
index 0000000..3e638ff
--- /dev/null
@@ -0,0 +1,693 @@
+#include <gtest/gtest.h>
+#include <ATen/core/op_registration/test_helpers.h>
+
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/core/Tensor.h>
+
+using c10::RegisterOperators;
+using c10::FunctionSchema;
+using c10::OperatorKernel;
+using c10::Argument;
+using c10::IntType;
+using c10::ListType;
+using c10::kernel;
+using c10::dispatchKey;
+using c10::TensorTypeId;
+using c10::KernelCache;
+using c10::Stack;
+using c10::guts::make_unique;
+using c10::ivalue::TensorList;
+using c10::ivalue::IntList;
+using c10::intrusive_ptr;
+using c10::ArrayRef;
+using std::unique_ptr;
+using at::Tensor;
+
+namespace {
+
+C10_DECLARE_TENSOR_TYPE(TensorType1);
+C10_DEFINE_TENSOR_TYPE(TensorType1);
+C10_DECLARE_TENSOR_TYPE(TensorType2);
+C10_DEFINE_TENSOR_TYPE(TensorType2);
+
+struct ErrorKernel final : public OperatorKernel {
+  void operator()(const Tensor&) {
+    EXPECT_TRUE(false); // this kernel should never be called
+  }
+};
+
+FunctionSchema errorOpSchema(
+    "_test::error",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+struct IncrementKernel final : OperatorKernel {
+  int operator()(const Tensor& tensor, int input) {
+    return input + 1;
+  }
+};
+
+struct DecrementKernel final : OperatorKernel {
+  int operator()(const Tensor& tensor, int input) {
+    return input - 1;
+  }
+};
+
+FunctionSchema opSchema(
+    "_test::my_op",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+void expectCallsIncrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(6, result[0].toInt());
+}
+
+void expectCallsDecrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(4, result[0].toInt());
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
+  auto registrar = RegisterOperators()
+      .op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()))
+      .op(opSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()))
+      .op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType1()))
+      .op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
+  auto registrar1 = RegisterOperators().op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()));
+  auto registrar2 = RegisterOperators().op(opSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()));
+  auto registrar3 = RegisterOperators().op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType1()));
+  auto registrar4 = RegisterOperators().op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
+  {
+    auto registrar1 = RegisterOperators().op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()));
+    {
+      auto registrar2 = RegisterOperators().op(opSchema, kernel<DecrementKernel>(), dispatchKey(TensorType2()));
+
+      // assert that schema and cpu kernel are present
+      expectCallsIncrement(TensorType1());
+      expectCallsDecrement(TensorType2());
+    }
+
+    // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
+    expectCallsIncrement(TensorType1());
+    expectDoesntFindKernel("_test::my_op", TensorType2());
+  }
+
+  // now both registrars are destructed. Assert that the whole schema is gone
+  expectDoesntFindOperator("_test::my_op");
+}
+
+bool was_called = false;
+
+struct KernelWithoutOutput final : OperatorKernel {
+  void operator()(const Tensor&) {
+    was_called = true;
+  }
+};
+
+FunctionSchema opWithoutOutputSchema(
+    "_test::no_return",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opWithoutOutputSchema, kernel<KernelWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
+  ASSERT_TRUE(op.has_value());
+  was_called = false;
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(was_called);
+  EXPECT_EQ(0, result.size());
+}
+
+struct KernelWithZeroOutputs final : OperatorKernel {
+  std::tuple<> operator()(const Tensor&) {
+    was_called = true;
+    return std::make_tuple();
+  }
+};
+
+FunctionSchema opWithZeroOutputsSchema(
+    "_test::zero_outputs",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, kernel<KernelWithZeroOutputs>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
+  ASSERT_TRUE(op.has_value());
+  was_called = false;
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(was_called);
+  EXPECT_EQ(0, result.size());
+}
+
+struct KernelWithIntOutput final : OperatorKernel {
+  int operator()(Tensor, int a, int b) {
+    return a + b;
+  }
+};
+
+
+FunctionSchema opWithIntOutputSchema(
+    "_test::int_output",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("a", IntType::get()),
+                           Argument("b", IntType::get())}),
+    (std::vector<Argument>{Argument("sum", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntOutputSchema, kernel<KernelWithIntOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(9, result[0].toInt());
+}
+
+struct KernelWithTensorOutput final : OperatorKernel {
+  Tensor operator()(const Tensor& input) {
+    return input;
+  }
+};
+
+FunctionSchema opWithTensorOutput(
+    "_test::returning_tensor",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output")}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorOutput, kernel<KernelWithTensorOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorOutput, kernel<KernelWithTensorOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+struct KernelWithTensorListOutput final : OperatorKernel {
+  std::vector<Tensor> operator()(const Tensor& input1, const Tensor& input2, const Tensor& input3) {
+    return {input1, input2, input3};
+  }
+};
+
+FunctionSchema opWithTensorListOutputSchema(
+    "_test::list_output",
+    "",
+    (std::vector<Argument>{Argument("input1"),
+                           Argument("input2"),
+                           Argument("input3")}),
+    (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListOutputSchema, kernel<KernelWithTensorListOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(3, result[0].toTensorListRef().size());
+  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id());
+  EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id());
+  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
+}
+
+struct KernelWithIntListOutput final : OperatorKernel {
+  std::vector<int64_t> operator()(const Tensor&, int input1, int input2, int input3) {
+    return {input1, input2, input3};
+  }
+};
+
+FunctionSchema opWithIntListOutputSchema(
+    "_test::list_output",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input1", IntType::get()),
+                           Argument("input2", IntType::get()),
+                           Argument("input3", IntType::get())}),
+    (std::vector<Argument>{Argument("output", ListType::ofInts())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListOutputSchema, kernel<KernelWithIntListOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(3, result[0].toIntListRef().size());
+  EXPECT_EQ(2, result[0].toIntListRef()[0]);
+  EXPECT_EQ(4, result[0].toIntListRef()[1]);
+  EXPECT_EQ(6, result[0].toIntListRef()[2]);
+}
+
+struct KernelWithMultipleOutputs final : OperatorKernel {
+  std::tuple<Tensor, int64_t, std::vector<Tensor>> operator()(Tensor) {
+    return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
+      dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
+    );
+  }
+};
+
+FunctionSchema opWithMultipleOutputsSchema(
+    "_test::multiple_outputs",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{Argument("output1"),
+                           Argument("output2", IntType::get()),
+                           Argument("output3", ListType::ofTensors())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+     .op(opWithMultipleOutputsSchema, kernel<KernelWithMultipleOutputs>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(3, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+  EXPECT_EQ(5, result[1].toInt());
+  EXPECT_EQ(2, result[2].toTensorListRef().size());
+  EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id());
+  EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
+}
+
+struct KernelWithTensorInputByReferenceWithOutput final : OperatorKernel {
+  Tensor operator()(const Tensor& input1) {
+    return input1;
+  }
+};
+
+struct KernelWithTensorInputByValueWithOutput final : OperatorKernel {
+  Tensor operator()(Tensor input1) {
+    return input1;
+  }
+};
+
+FunctionSchema opWithTensorInputWithOutput(
+    "_test::tensor_input",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output")}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+Tensor captured_input;
+
+struct KernelWithTensorInputByReferenceWithoutOutput final : OperatorKernel {
+  void operator()(const Tensor& input1) {
+    captured_input = input1;
+  }
+};
+
+struct KernelWithTensorInputByValueWithoutOutput final : OperatorKernel {
+  void operator()(Tensor input1) {
+    captured_input = input1;
+  }
+};
+
+FunctionSchema opWithTensorInputWithoutOutput(
+    "_test::tensor_input",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType1(), captured_input.type_id());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType2(), captured_input.type_id());
+}
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType1(), captured_input.type_id());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType2(), captured_input.type_id());
+}
+
+int captured_int_input = 0;
+
+struct KernelWithIntInputWithoutOutput final : OperatorKernel {
+  void operator()(Tensor, int input1) {
+    captured_int_input = input1;
+  }
+};
+
+FunctionSchema opWithIntInputWithoutOutput(
+    "_test::int_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntInputWithoutOutput, kernel<KernelWithIntInputWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_int_input = 0;
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(3, captured_int_input);
+}
+
+struct KernelWithIntInputWithOutput final : OperatorKernel {
+  int operator()(Tensor, int input1) {
+    return input1 + 1;
+  }
+};
+
+FunctionSchema opWithIntInputWithOutput(
+    "_test::int_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntInputWithOutput, kernel<KernelWithIntInputWithOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(4, outputs[0].toInt());
+}
+
+int captured_input_list_size = 0;
+
+struct KernelWithIntListInputWithoutOutput final : OperatorKernel {
+  void operator()(Tensor, ArrayRef<int64_t> input1) {
+    captured_input_list_size = input1.size();
+  }
+};
+
+FunctionSchema opWithIntListInputWithoutOutput(
+    "_test::int_list_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", ListType::ofInts())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListInputWithoutOutput, kernel<KernelWithIntListInputWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_input_list_size = 0;
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(3, captured_input_list_size);
+}
+
+struct KernelWithIntListInputWithOutput final : OperatorKernel {
+  int operator()(Tensor, ArrayRef<int64_t> input1) {
+    return input1.size();
+  }
+};
+
+FunctionSchema opWithIntListInputWithOutput(
+    "_test::int_list_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", ListType::ofInts())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListInputWithOutput, kernel<KernelWithIntListInputWithOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(3, outputs[0].toInt());
+}
+
+struct KernelWithTensorListInputWithoutOutput final : OperatorKernel {
+  void operator()(ArrayRef<Tensor> input1) {
+    captured_input_list_size = input1.size();
+  }
+};
+
+FunctionSchema opWithTensorListInputWithoutOutput(
+    "_test::tensor_list_input",
+    "",
+    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListInputWithoutOutput, kernel<KernelWithTensorListInputWithoutOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_input_list_size = 0;
+  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(2, captured_input_list_size);
+}
+
+struct KernelWithTensorListInputWithOutput final : OperatorKernel {
+  int operator()(ArrayRef<Tensor> input1) {
+    return input1.size();
+  }
+};
+
+FunctionSchema opWithTensorListInputWithOutput(
+    "_test::tensor_list_input",
+    "",
+    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListInputWithOutput, kernel<KernelWithTensorListInputWithOutput>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(2, outputs[0].toInt());
+}
+
+class KernelWithCache final : public OperatorKernel { // Test kernel proving per-registration state (the "cache") survives across calls.
+public:
+  KernelWithCache(): counter(3) {} // counter starts at 3; first call pre-increments and returns 4
+
+  int64_t operator()(Tensor) {
+    return ++counter; // mutates instance state: each call returns one more than the previous
+  }
+private:
+  int64_t counter; // kept alive between calls because the dispatcher caches the functor instance
+};
+
+FunctionSchema opWithCacheSchema(
+    "_test::cache_op",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithCache_thenCacheIsKeptCorrectly) {
+  auto registrar = RegisterOperators()
+      .op(opWithCacheSchema, kernel<KernelWithCache>(), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::cache_op", "");
+  ASSERT_TRUE(op.has_value());
+
+  // expect first time calling returns a 4 (counter is initialized to 3 and pre-incremented)
+  auto stack = makeStack(dummyTensor(TensorType1()));
+  auto kernel = c10::Dispatcher::singleton().lookup(*op, &stack);
+  kernel.call(&stack);
+  EXPECT_EQ(1, stack.size());
+  EXPECT_EQ(4, stack[0].toInt());
+
+  // expect second time calling returns a 5
+  stack = makeStack(dummyTensor(TensorType1()));
+  kernel.call(&stack);
+  EXPECT_EQ(1, stack.size());
+  EXPECT_EQ(5, stack[0].toInt());
+
+  // expect third time calling returns a 6
+  stack = makeStack(dummyTensor(TensorType1()));
+  kernel.call(&stack);
+  EXPECT_EQ(1, stack.size());
+  EXPECT_EQ(6, stack[0].toInt());
+}
+
+class KernelWithConstructorArg final : public OperatorKernel {
+public:
+  explicit KernelWithConstructorArg(int64_t offset)
+  : offset_(offset) {}
+
+  int64_t operator()(const Tensor&, int64_t input) {
+    return input + offset_;
+  }
+
+private:
+  int64_t offset_;
+};
+
+FunctionSchema opWithConstructorArgsSchema(
+    "_test::offset_op",
+    "",
+    (std::vector<Argument>{Argument("tensor"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithConstructorArg_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithConstructorArgsSchema, kernel<KernelWithConstructorArg>(2), dispatchKey(TensorType1()))
+      .op(opWithConstructorArgsSchema, kernel<KernelWithConstructorArg>(4), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::offset_op", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 4);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(6, outputs[0].toInt());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()), 4);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(8, outputs[0].toInt());
+}
+
+class KernelWithMultipleConstructorArgs final : public OperatorKernel {
+public:
+  explicit KernelWithMultipleConstructorArgs(int64_t offset1, int64_t offset2)
+  : offset_(offset1 + offset2) {}
+
+  int64_t operator()(const Tensor&, int64_t input) {
+    return input + offset_;
+  }
+
+private:
+  int64_t offset_;
+};
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstructorArgs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithConstructorArgsSchema, kernel<KernelWithMultipleConstructorArgs>(2, 3), dispatchKey(TensorType1()))
+      .op(opWithConstructorArgsSchema, kernel<KernelWithMultipleConstructorArgs>(4, 5), dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::offset_op", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 4);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(9, outputs[0].toInt());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()), 4);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(13, outputs[0].toInt());
+}
+
+
+}
index 09750aa..412a5dd 100644 (file)
 namespace c10 {
 
 namespace detail {
+
+  template<class KernelCacheCreatorFunction_>
   struct KernelRegistrationConfigParameter final {
-    explicit constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func)
-    : kernel_func_(kernel_func), cache_creator_func_(std::move(cache_creator_func)) {
+    template<class KernelCacheCreatorFunction__>
+    constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func)
+    : kernel_func_(kernel_func), cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func)) {
     }
 
-    void apply(KernelRegistrationConfig* registration) const {
+    void apply(KernelRegistrationConfig* registration) const {
       registration->kernel_func = kernel_func_;
       registration->cache_creator_func = cache_creator_func_;
     }
 
+    void apply(KernelRegistrationConfig* registration) && {
+      registration->kernel_func = kernel_func_;
+      registration->cache_creator_func = std::move(cache_creator_func_);
+    }
+
   private:
     KernelFunction* kernel_func_;
-    KernelCacheCreatorFunction* cache_creator_func_;
+    KernelCacheCreatorFunction_ cache_creator_func_;
   };
 
-  static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+  static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
 }
 
 /**
@@ -52,8 +60,11 @@ namespace detail {
  * >         c10::kernel(&my_kernel_cpu, &my_cache_creator),
  * >         c10::dispatchKey(CPUTensorId()));
  */
-inline constexpr detail::KernelRegistrationConfigParameter kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator) {
-  return detail::KernelRegistrationConfigParameter(kernel_func, cache_creator);
+template<class KernelCacheCreatorFunction_>
+inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
+  static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+
+  return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator)};
 }
 
 }
index d6bd887..96e12ce 100644 (file)
@@ -8,6 +8,8 @@
 #include <ATen/core/op_registration/base.h>
 #include <ATen/core/op_registration/dispatch_key.h>
 #include <ATen/core/op_registration/kernel_stackbased.h>
+#include <ATen/core/op_registration/kernel_functor.h>
+#include <ATen/core/op_registration/kernel_function.h>
 
 namespace c10 {
 
@@ -63,15 +65,15 @@ public:
   guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
   op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
     detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
-    registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, config.cache_creator_func);
+    registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
     return std::move(*this);
   }
 
   // TODO error if dispatch key is not specified
-  // TODO Add functor, function and lambda based kernel APIs
+  // TODO Add deprecated function and lambda based kernel APIs
 
 private:
-  std::vector<c10::detail::OperatorRegistrar> registrars_;
+  std::vector<detail::OperatorRegistrar> registrars_;
 };
 
 }
index ed9d35a..595912c 100644 (file)
@@ -8,11 +8,11 @@
 #include <c10/core/CPUAllocator.h>
 
 template<class... Inputs>
-std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
+inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
   return {std::forward<Inputs>(inputs)...};
 }
 
-at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) {
+inline at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) {
   auto* allocator = c10::GetCPUAllocator();
   int64_t nelements = 1;
   auto dtype = caffe2::TypeMeta::Make<float>();
@@ -26,21 +26,21 @@ at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) {
 }
 
 template<class... Args>
-std::vector<c10::IValue> callOp(const c10::OperatorHandle& op, Args... args) {
+inline std::vector<c10::IValue> callOp(const c10::OperatorHandle& op, Args... args) {
   auto stack = makeStack(std::forward<Args>(args)...);
   auto kernel = c10::Dispatcher::singleton().lookup(op, &stack);
   kernel.call(&stack);
   return stack;
 }
 
-void expectDoesntFindKernel(const char* op_name, c10::TensorTypeId dispatch_key) {
+inline void expectDoesntFindKernel(const char* op_name, c10::TensorTypeId dispatch_key) {
   auto op = c10::Dispatcher::singleton().findSchema(op_name, "");
   EXPECT_ANY_THROW(
     callOp(*op, dummyTensor(dispatch_key), 5);
   );
 }
 
-void expectDoesntFindOperator(const char* op_name) {
+inline void expectDoesntFindOperator(const char* op_name) {
   auto op = c10::Dispatcher::singleton().findSchema(op_name, "");
   EXPECT_FALSE(op.has_value());
 }
index c1c375d..2d7345b 100644 (file)
@@ -183,7 +183,7 @@ template<typename... Ts> using void_t = typename make_void<Ts...>::type;
 
 template <class F, class Tuple>
 inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
-  return std::apply(c10::guts::forward<F>(f), c10::guts::forward<Tuple>(t));
+  return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
 }
 
 #else
index 8fa7724..60e4553 100644 (file)
@@ -44,6 +44,16 @@ struct function_traits<Result (Args...)> {
 };
 
 /**
+ * Evaluates to true_type, iff the given class is a Functor
+ * (i.e. has a call operator with some set of arguments)
+ */
+
+template<class Functor, class Enable = void>
+struct is_functor : std::false_type {};
+template<class Functor>
+struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};
+
+/**
  * infer_function_traits: creates a `function_traits` type for a simple
  * function (pointer) or functor (lambda/struct). Currently does not support
  * class methods.
index 262655b..80db17f 100644 (file)
@@ -59,23 +59,6 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c1
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/fc_cpu.cc)
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/enforce_finite_cpu.cc)
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/add_cpu.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/sigmoid.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/filler.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/expand_dims.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/mul.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/relu.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/stop_gradient.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/enforce_finite.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/cast.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/averaged_loss.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/batch_matmul.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/batch_gather.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/fc.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/concat.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/sparse_lengths_sum.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/add.cc)
-set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/flatten.cc)
 
 set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp})
 # exclude test files and gpu files
index 9ec1aa1..8c67e0d 100644 (file)
@@ -1,6 +1,6 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/operators/elementwise_ops_utils.h"
-#include "caffe2/operators/experimental/c10/schemas/add.h"
 #include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
@@ -69,11 +69,25 @@ void add_op_cpu_impl(
       C.mutable_data<DataType>(),
       static_cast<CPUContext*>(&context));
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Add",
+        "",
+        (std::vector<c10::Argument>{
+            c10::Argument("input1"),
+            c10::Argument("input2"),
+            c10::Argument("output"),
+            c10::Argument("legacy_broadcast", BoolType::get()),
+            c10::Argument("axis", IntType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<decltype(add_op_cpu_impl<float>), &add_op_cpu_impl<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Add)
-    .kernel<decltype(caffe2::add_op_cpu_impl<float>), &caffe2::add_op_cpu_impl<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Add",
+    C10Add_DontUseThisOpYet)
+
+} // namespace caffe2
index cc5823d..cd8f090 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/averaged_loss.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -10,45 +10,53 @@ using std::vector;
 namespace caffe2 {
 namespace {
 
-struct Cache final : public c10::KernelCache {
-  at::Tensor scratch = at::Tensor(C10Tensor(empty({}, CPU)));
-};
-
 template <class T, class Context>
-void averaged_loss_op_cpu_impl(
-    const at::Tensor& X_,
-    const at::Tensor& sum_,
-    Cache* state) {
-  Tensor X{C10Tensor(X_)};
-  Tensor sum{C10Tensor(sum_)};
-  CPUContext context;
-
-  sum.Resize(vector<int64_t>());
-
-  T* data = sum.template mutable_data<T>();
-
-  Tensor scratch(state->scratch);
-  caffe2::math::Sum<T, Context>(
-      X.numel(),
-      X.template data<T>(),
-      data,
-      static_cast<Context*>(&context),
-      &scratch);
-  if (X.numel() > 0) {
-    caffe2::math::Scale<T, T, Context>(
-        1,
-        static_cast<T>(1.) / X.numel(),
-        sum.template data<T>(),
+class averaged_loss_cpu final : public c10::OperatorKernel {
+ public:
+  void operator()(const at::Tensor& X_, const at::Tensor& sum_) {
+    Tensor X{C10Tensor(X_)};
+    Tensor sum{C10Tensor(sum_)};
+    CPUContext context;
+
+    sum.Resize(vector<int64_t>());
+
+    T* data = sum.template mutable_data<T>();
+
+    Tensor scratch(scratch_);
+    caffe2::math::Sum<T, Context>(
+        X.numel(),
+        X.template data<T>(),
         data,
-        static_cast<Context*>(&context));
+        static_cast<Context*>(&context),
+        &scratch);
+    if (X.numel() > 0) {
+      caffe2::math::Scale<T, T, Context>(
+          1,
+          static_cast<T>(1.) / X.numel(),
+          sum.template data<T>(),
+          data,
+          static_cast<Context*>(&context));
+    }
   }
-}
+
+ private:
+  at::Tensor scratch_ = at::Tensor(C10Tensor(empty({}, CPU)));
+};
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::AveragedLoss",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input"),
+                                    c10::Argument("output")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<averaged_loss_cpu<float, CPUContext>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss)
-    .withCache<caffe2::Cache>()
-    .kernel<decltype(caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::AveragedLoss",
+    C10AveragedLoss_DontUseThisOpYet)
+
+} // namespace caffe2
index a34c098..c41ead2 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/batch_gather.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -63,11 +63,22 @@ void batch_gather_op_cpu(const at::Tensor& data,
     default: throw std::runtime_error(string() + "Unsupported dtype: " + toString(data.scalar_type()));
   }
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::BatchGather",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("data"),
+                                    c10::Argument("indices"),
+                                    c10::Argument("output")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<decltype(batch_gather_op_cpu), &batch_gather_op_cpu>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
-    .kernel<decltype(caffe2::batch_gather_op_cpu), &caffe2::batch_gather_op_cpu>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::BatchGather",
+    C10BatchGather_DontUseThisOpYet)
+
+} // namespace caffe2
index dfe8523..4970f5e 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/batch_matmul.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -16,264 +16,277 @@ struct Cache final : public c10::KernelCache {
 };
 
 template <class T, class Context>
-void batch_matmul_op_cpu_impl(
-    const at::Tensor& A_,
-    const at::Tensor& B_,
-    const at::Tensor& Y_,
-    int64_t trans_a,
-    int64_t trans_b,
-    int64_t broadcast,
-    Cache* cache) {
-  Tensor A{C10Tensor(A_)};
-  Tensor B{C10Tensor(B_)};
-  Tensor Y{C10Tensor(Y_)};
-  CPUContext context;
-  using Engine = caffe2::DefaultEngine;
-
-  auto ndims_A = A.dim();
-  auto dims_A = A.sizes().vec();
-  auto ndims_B = B.dim();
-  auto dims_B = B.sizes().vec();
+class batch_matmul_cpu final : public c10::OperatorKernel {
+ public:
+  void operator()(
+      const at::Tensor& A_,
+      const at::Tensor& B_,
+      const at::Tensor& Y_,
+      int64_t trans_a,
+      int64_t trans_b,
+      int64_t broadcast) {
+    Tensor A{C10Tensor(A_)};
+    Tensor B{C10Tensor(B_)};
+    Tensor Y{C10Tensor(Y_)};
+    CPUContext context;
+    using Engine = caffe2::DefaultEngine;
 
-  auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
-    std::stringstream ss;
-    ss << "Inputs with dimensions A = ";
-    ss << dim1;
-    ss << " and B = ";
-    ss << dim2;
-    ss << " is not supported with broadcast=0. Did you forget to set the "
-          "broadcast flag?";
-    return ss.str();
-  };
+    auto ndims_A = A.dim();
+    auto dims_A = A.sizes().vec();
+    auto ndims_B = B.dim();
+    auto dims_B = B.sizes().vec();
 
-  // These should all be false if we're not broadcasting.
-  bool dimMismatch = ndims_A != ndims_B;
-  bool dimsLessThan1D = ndims_A < 2;
-  CAFFE_ENFORCE(
-      broadcast || (!dimMismatch && !dimsLessThan1D),
-      noBroadcastErrorMsg(ndims_A, ndims_B));
+    auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
+      std::stringstream ss;
+      ss << "Inputs with dimensions A = ";
+      ss << dim1;
+      ss << " and B = ";
+      ss << dim2;
+      ss << " is not supported with broadcast=0. Did you forget to set the "
+            "broadcast flag?";
+      return ss.str();
+    };
 
-  auto* data_A = A.template data<T>();
-  auto* data_B = B.template data<T>();
+    // These should all be false if we're not broadcasting.
+    bool dimMismatch = ndims_A != ndims_B;
+    bool dimsLessThan1D = ndims_A < 2;
+    CAFFE_ENFORCE(
+        broadcast || (!dimMismatch && !dimsLessThan1D),
+        noBroadcastErrorMsg(ndims_A, ndims_B));
 
-  auto dimMismatchErrorString = [](size_t dimnum1,
-                                   size_t dim1,
-                                   size_t dimnum2,
-                                   size_t dim2,
-                                   bool trans_a_,
-                                   bool trans_b_) {
-    std::stringstream ss;
-    ss << "Expected dimension ";
-    ss << dimnum1;
-    ss << " of tensor A with value ";
-    ss << dim1;
-    ss << " to match dimension ";
-    ss << dimnum2;
-    ss << " of tensor B with value ";
-    ss << dim2;
-    ss << ". trans_a = ";
-    ss << trans_a_;
-    ss << " trans_b = ";
-    ss << trans_b_;
-    return ss.str();
-  };
-
-  if (ndims_A == 1 && ndims_B == 1) {
-    // vector-vector
-    CAFFE_ENFORCE_EQ(
-        dims_A[0],
-        dims_B[0],
-        "Vector-vector product requires each of the vectors to "
-        "be the same size.");
-    Y.Resize(1);
-    math::Dot<T, Context>(
-        dims_A[0], data_A, data_B, Y.template mutable_data<T>(), static_cast<Context*>(&context));
-  } else {
-    bool A_broadcasted = false, B_broadcasted = false;
-    if (ndims_A == 1) {
-      dims_A.insert(dims_A.begin(), 1);
-      ndims_A = 2;
-      A_broadcasted = true;
-    }
-    if (ndims_B == 1) {
-      dims_B.push_back(1);
-      ndims_B = 2;
-      B_broadcasted = true;
-    }
-    // matrix-matrix with batches
-    // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
-    // In the event that A or B are one-dimensional, the trailing or leading
-    // 1 is not added to the output tensor's size.
+    auto* data_A = A.template data<T>();
+    auto* data_B = B.template data<T>();
 
-    // First step: partition the tensors into inner and outer blocks.
-    // Ignoring the last two dimensions of A and B, ensure that one of the
-    // tensors' dimensions is a suffix of the other. For example,
-    // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
-    // dimensions of size 2 and 3 will be broadcasted, so we partition into
-    // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
-    size_t num_inner_dims = std::min(ndims_A, ndims_B);
-    for (size_t i = 2; i < num_inner_dims; ++i) {
-      auto first_r_itr = dims_A.rbegin();
-      auto second_r_itr = dims_B.rbegin();
-      CAFFE_ENFORCE_EQ(
-          *(first_r_itr + i),
-          *(second_r_itr + i),
-          dimMismatchErrorString(
-              ndims_A - i - 1,
-              *(first_r_itr + i),
-              ndims_B - i - 1,
-              *(second_r_itr + i),
-              trans_a,
-              trans_b));
-    }
-    size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
+    auto dimMismatchErrorString = [](size_t dimnum1,
+                                     size_t dim1,
+                                     size_t dimnum2,
+                                     size_t dim2,
+                                     bool trans_a_,
+                                     bool trans_b_) {
+      std::stringstream ss;
+      ss << "Expected dimension ";
+      ss << dimnum1;
+      ss << " of tensor A with value ";
+      ss << dim1;
+      ss << " to match dimension ";
+      ss << dimnum2;
+      ss << " of tensor B with value ";
+      ss << dim2;
+      ss << ". trans_a = ";
+      ss << trans_a_;
+      ss << " trans_b = ";
+      ss << trans_b_;
+      return ss.str();
+    };
 
-    // Standard M, N, and K parameters respecting GEMM API and transpose
-    // flags
-    size_t M, N, K, K_dim;
-    if (trans_a) {
-      M = dims_A[ndims_A - 1];
-      K = dims_A[ndims_A - 2];
-      K_dim = ndims_A - 2;
-    } else {
-      M = dims_A[ndims_A - 2];
-      K = dims_A[ndims_A - 1];
-      K_dim = ndims_A - 1;
-    }
-    if (trans_b) {
-      N = dims_B[ndims_B - 2];
+    if (ndims_A == 1 && ndims_B == 1) {
+      // vector-vector
       CAFFE_ENFORCE_EQ(
-          K,
-          dims_B[ndims_B - 1],
-          dimMismatchErrorString(
-              K_dim,
-              K,
-              ndims_B - 1,
-              dims_B[ndims_B - 1],
-              trans_a,
-              trans_b));
+          dims_A[0],
+          dims_B[0],
+          "Vector-vector product requires each of the vectors to "
+          "be the same size.");
+      Y.Resize(1);
+      math::Dot<T, Context>(
+          dims_A[0],
+          data_A,
+          data_B,
+          Y.template mutable_data<T>(),
+          static_cast<Context*>(&context));
     } else {
-      N = dims_B[ndims_B - 1];
-      CAFFE_ENFORCE_EQ(
-          K,
-          dims_B[ndims_B - 2],
-          dimMismatchErrorString(
-              K_dim,
-              K,
-              ndims_B - 2,
-              dims_B[ndims_B - 2],
-              trans_a,
-              trans_b));
-    }
+      bool A_broadcasted = false, B_broadcasted = false;
+      if (ndims_A == 1) {
+        dims_A.insert(dims_A.begin(), 1);
+        ndims_A = 2;
+        A_broadcasted = true;
+      }
+      if (ndims_B == 1) {
+        dims_B.push_back(1);
+        ndims_B = 2;
+        B_broadcasted = true;
+      }
+      // matrix-matrix with batches
+      // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
+      // In the event that A or B are one-dimensional, the trailing or leading
+      // 1 is not added to the output tensor's size.
 
-    // Calculate output tensor shapes [B..., (M), (N)]
-    // Batch dimensions will be broadcasted out to those of the longer tensor
-    // A or B. Either M or N are optional if A or B, respectively are 1-D.
-    std::vector<int64_t> new_dims;
-    if (ndims_A >= ndims_B) {
-      new_dims.assign(dims_A.begin(), dims_A.end() - 2);
-    } else {
-      new_dims.assign(dims_B.begin(), dims_B.end() - 2);
-    }
-    if (!A_broadcasted) {
-      new_dims.push_back(M);
-    } else {
-      new_dims.push_back(1);
-    }
-    if (!B_broadcasted) {
-      new_dims.push_back(N);
-    } else {
-      new_dims.push_back(1);
-    }
+      // First step: partition the tensors into inner and outer blocks.
+      // Ignoring the last two dimensions of A and B, ensure that one of the
+      // tensors' dimensions is a suffix of the other. For example,
+      // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
+      // dimensions of size 2 and 3 will be broadcasted, so we partition into
+      // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
+      size_t num_inner_dims = std::min(ndims_A, ndims_B);
+      for (size_t i = 2; i < num_inner_dims; ++i) {
+        auto first_r_itr = dims_A.rbegin();
+        auto second_r_itr = dims_B.rbegin();
+        CAFFE_ENFORCE_EQ(
+            *(first_r_itr + i),
+            *(second_r_itr + i),
+            dimMismatchErrorString(
+                ndims_A - i - 1,
+                *(first_r_itr + i),
+                ndims_B - i - 1,
+                *(second_r_itr + i),
+                trans_a,
+                trans_b));
+      }
+      size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
 
-    // Calculate strides. Continuing our example above,
-    //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
-    // We calculate this as follows:
-    //   1) Treat the outer batch dimensions as flattened, i.e. view the B
-    //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same rea-
-    //      soning is analogous for the case where # dims A >= # dims B.
-    //   2) Perform this operation:
-    //        for i in range(6):
-    //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
-    size_t A_stride = 1; // How far to increment A pointer each itr
-    size_t B_stride = 1; // How far to increment B pointer each itr
-    size_t Y_stride = 1; // How far to increment Y pointer each itr
-    // How many "inner batches" we have. That is, the product of sizes for
-    // the slices excluding M, K, and N, for their respective matrices.
-    size_t num_sub_batches = 1;
-    if (ndims_A >= ndims_B) {
-      auto first_r_itr = dims_A.rbegin();
-      auto output_r_itr = new_dims.rbegin();
-      for (size_t i = 0; i < num_inner_dims; ++i) {
-        A_stride *= *(first_r_itr + i);
-        Y_stride *= *(output_r_itr + i);
-        if (i >= 2) {
-          num_sub_batches *= *(first_r_itr + i);
-        }
+      // Standard M, N, and K parameters respecting GEMM API and transpose
+      // flags
+      size_t M, N, K, K_dim;
+      if (trans_a) {
+        M = dims_A[ndims_A - 1];
+        K = dims_A[ndims_A - 2];
+        K_dim = ndims_A - 2;
+      } else {
+        M = dims_A[ndims_A - 2];
+        K = dims_A[ndims_A - 1];
+        K_dim = ndims_A - 1;
       }
-      B_stride = 0;
-    } else {
-      A_stride = 0;
-      auto second_r_itr = dims_B.rbegin();
-      auto output_r_itr = new_dims.rbegin();
-      for (size_t i = 0; i < num_inner_dims; ++i) {
-        B_stride *= *(second_r_itr + i);
-        Y_stride *= *(output_r_itr + i);
-        if (i >= 2) {
-          num_sub_batches *= *(second_r_itr + i);
+      if (trans_b) {
+        N = dims_B[ndims_B - 2];
+        CAFFE_ENFORCE_EQ(
+            K,
+            dims_B[ndims_B - 1],
+            dimMismatchErrorString(
+                K_dim, K, ndims_B - 1, dims_B[ndims_B - 1], trans_a, trans_b));
+      } else {
+        N = dims_B[ndims_B - 1];
+        CAFFE_ENFORCE_EQ(
+            K,
+            dims_B[ndims_B - 2],
+            dimMismatchErrorString(
+                K_dim, K, ndims_B - 2, dims_B[ndims_B - 2], trans_a, trans_b));
+      }
+
+      // Calculate output tensor shapes [B..., (M), (N)]
+      // Batch dimensions will be broadcasted out to those of the longer tensor
+      // A or B. Either M or N are optional if A or B, respectively are 1-D.
+      std::vector<int64_t> new_dims;
+      if (ndims_A >= ndims_B) {
+        new_dims.assign(dims_A.begin(), dims_A.end() - 2);
+      } else {
+        new_dims.assign(dims_B.begin(), dims_B.end() - 2);
+      }
+      if (!A_broadcasted) {
+        new_dims.push_back(M);
+      } else {
+        new_dims.push_back(1);
+      }
+      if (!B_broadcasted) {
+        new_dims.push_back(N);
+      } else {
+        new_dims.push_back(1);
+      }
+
+      // Calculate strides. Continuing our example above,
+      //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
+      // We calculate this as follows:
+      //   1) Treat the outer batch dimensions as flattened, i.e. view the B
+      //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same rea-
+      //      soning is analogous for the case where # dims A >= # dims B.
+      //   2) Perform this operation:
+      //        for i in range(6):
+      //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
+      size_t A_stride = 1; // How far to increment A pointer each itr
+      size_t B_stride = 1; // How far to increment B pointer each itr
+      size_t Y_stride = 1; // How far to increment Y pointer each itr
+      // How many "inner batches" we have. That is, the product of sizes for
+      // the slices excluding M, K, and N, for their respective matrices.
+      size_t num_sub_batches = 1;
+      if (ndims_A >= ndims_B) {
+        auto first_r_itr = dims_A.rbegin();
+        auto output_r_itr = new_dims.rbegin();
+        for (size_t i = 0; i < num_inner_dims; ++i) {
+          A_stride *= *(first_r_itr + i);
+          Y_stride *= *(output_r_itr + i);
+          if (i >= 2) {
+            num_sub_batches *= *(first_r_itr + i);
+          }
+        }
+        B_stride = 0;
+      } else {
+        A_stride = 0;
+        auto second_r_itr = dims_B.rbegin();
+        auto output_r_itr = new_dims.rbegin();
+        for (size_t i = 0; i < num_inner_dims; ++i) {
+          B_stride *= *(second_r_itr + i);
+          Y_stride *= *(output_r_itr + i);
+          if (i >= 2) {
+            num_sub_batches *= *(second_r_itr + i);
+          }
         }
       }
-    }
 
-    size_t num_outer_batches = 1;
-    for (size_t i = 0; i < num_outer_dims; ++i) {
-      num_outer_batches *= new_dims[i];
-    }
+      size_t num_outer_batches = 1;
+      for (size_t i = 0; i < num_outer_dims; ++i) {
+        num_outer_batches *= new_dims[i];
+      }
 
-    // Mutually exclusive since otherwise we would've taken the vector-vector
-    // path above
-    if (A_broadcasted) {
-      new_dims.erase(new_dims.end() - 2);
-    } else if (B_broadcasted) {
-      new_dims.erase(new_dims.end() - 1);
-    }
+      // Mutually exclusive since otherwise we would've taken the vector-vector
+      // path above
+      if (A_broadcasted) {
+        new_dims.erase(new_dims.end() - 2);
+      } else if (B_broadcasted) {
+        new_dims.erase(new_dims.end() - 1);
+      }
 
-    // Allocate output tensor
-    Y.Resize(new_dims);
-    auto* Y_data = Y.template mutable_data<T>();
+      // Allocate output tensor
+      Y.Resize(new_dims);
+      auto* Y_data = Y.template mutable_data<T>();
 
-    // Zero batch dimension indicates no elements
-    if (num_sub_batches == 0 || num_outer_batches == 0) {
-      return;
-    }
+      // Zero batch dimension indicates no elements
+      if (num_sub_batches == 0 || num_outer_batches == 0) {
+        return;
+      }
 
-    // TODO(T23893772): doing this in a loop is likely going to be slow on GPU
-    for (size_t p = 0; p < num_outer_batches; ++p) {
-      math::GemmStridedBatched<T, Context, Engine>(
-          trans_a ? CblasTrans : CblasNoTrans,
-          trans_b ? CblasTrans : CblasNoTrans,
-          num_sub_batches,
-          M,
-          N,
-          K,
-          1.0f,
-          data_A + p * A_stride,
-          M * K,
-          data_B + p * B_stride,
-          K * N,
-          0.0f,
-          Y_data + p * Y_stride,
-          M * N,
-          static_cast<Context*>(&context));
+      // TODO(T23893772): doing this in a loop is likely going to be slow on GPU
+      for (size_t p = 0; p < num_outer_batches; ++p) {
+        math::GemmStridedBatched<T, Context, Engine>(
+            trans_a ? CblasTrans : CblasNoTrans,
+            trans_b ? CblasTrans : CblasNoTrans,
+            num_sub_batches,
+            M,
+            N,
+            K,
+            1.0f,
+            data_A + p * A_stride,
+            M * K,
+            data_B + p * B_stride,
+            K * N,
+            0.0f,
+            Y_data + p * Y_stride,
+            M * N,
+            static_cast<Context*>(&context));
+      }
     }
   }
-}
+
+ private:
+  at::Tensor scratch = at::Tensor(C10Tensor(empty({}, CPU)));
+};
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::BatchMatmul",
+        "",
+        (std::vector<c10::Argument>{
+            c10::Argument("A"),
+            c10::Argument("B"),
+            c10::Argument("output"),
+            c10::Argument("trans_a", IntType::get()),
+            c10::Argument("trans_b", IntType::get()),
+            c10::Argument("broadcast", IntType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<batch_matmul_cpu<float, CPUContext>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::BatchMatmul)
-    .withCache<caffe2::Cache>()
-    .kernel<decltype(caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::batch_matmul_op_cpu_impl<float, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::BatchMatmul",
+    C10BatchMatMul_DontUseThisOpYet)
+
+} // namespace caffe2
index 66ffcf5..1fa560e 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/cast.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::CPUContext;
 using caffe2::Tensor;
@@ -85,11 +85,24 @@ void cast_op_cpu(
     default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
   }
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Cast",
+        "",
+        (std::vector<c10::Argument>{
+            c10::Argument("input"),
+            c10::Argument("output"),
+            c10::Argument("to_dtype", IntType::get()),
+        }),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<decltype(cast_op_cpu), &cast_op_cpu>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Cast)
-    .kernel<decltype(caffe2::cast_op_cpu), &caffe2::cast_op_cpu>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Cast",
+    C10Cast_DontUseThisOpYet)
+
+} // namespace caffe2
index 5a87d20..c23aff3 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/concat.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
 using caffe2::CPUContext;
@@ -103,11 +103,27 @@ void concat_op_cpu_impl(
     output_offset += axis_dim * after * input.itemsize();
   }
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Concat",
+        "",
+        (std::vector<c10::Argument>{
+            c10::Argument("inputs", ListType::ofTensors()),
+            c10::Argument("output"),
+            c10::Argument("split_info", FloatType::get()),
+            c10::Argument("add", IntType::get()),
+            c10::Argument("add_axis", IntType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(concat_op_cpu_impl<float, CPUContext>),
+        &concat_op_cpu_impl<float, CPUContext>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Concat)
-    .kernel<decltype(caffe2::concat_op_cpu_impl<float, CPUContext>), &caffe2::concat_op_cpu_impl<float, CPUContext>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Concat",
+    C10Concat_DontUseThisOpYet)
+
+} // namespace caffe2
index 2e9f608..d833a67 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/enforce_finite.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::CPUContext;
 using caffe2::Tensor;
@@ -23,11 +23,22 @@ void enforce_finite_op_impl_cpu(const at::Tensor& input_) {
         input_data[i]);
   }
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::EnforceFinite",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(enforce_finite_op_impl_cpu<float>),
+        &enforce_finite_op_impl_cpu<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::EnforceFinite)
-    .kernel<decltype(caffe2::enforce_finite_op_impl_cpu<float>), &caffe2::enforce_finite_op_impl_cpu<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::EnforceFinite",
+    C10EnforceFinite_DontUseThisOpYet)
+
+} // namespace caffe2
index a719500..81566e7 100644 (file)
@@ -1,65 +1,74 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/expand_dims.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::Tensor;
 
 namespace caffe2 {
 namespace {
 
-struct Cache final : public c10::KernelCache {
-  std::vector<int64_t> dims;
-  bool initialized = false;
-};
-
 template <class DataType>
-void expand_dims_op_cpu_impl(
-    const at::Tensor& input_,
-    const at::Tensor& output_,
-    ArrayRef<int64_t> dims,
-    Cache* cache) {
-  Tensor input{C10Tensor(input_)};
-  Tensor output{C10Tensor(output_)};
+class expand_dims_cpu final : public c10::OperatorKernel {
+ public:
+  void operator()(
+      const at::Tensor& input_,
+      const at::Tensor& output_,
+      ArrayRef<int64_t> dims) {
+    Tensor input{C10Tensor(input_)};
+    Tensor output{C10Tensor(output_)};
 
-  if (!cache->initialized) {
-    cache->dims = dims.vec();
-    auto originalSize = cache->dims.size();
-    CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
-    std::sort(cache->dims.begin(), cache->dims.end());
-    cache->dims.erase(
-        std::unique(cache->dims.begin(), cache->dims.end()), cache->dims.end());
-    if (cache->dims.size() < originalSize) {
-      LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
+    if (!initialized_) {
+      dims_ = dims.vec();
+      auto originalSize = dims_.size();
+      CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
+      std::sort(dims_.begin(), dims_.end());
+      dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
+      if (dims_.size() < originalSize) {
+        LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
+      }
+      CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
+      initialized_ = true;
     }
-    CAFFE_ENFORCE(
-        cache->dims.front() >= 0, "Dimension ids must be non-negative.");
-    cache->initialized = true;
-  }
 
-  output.CopyFrom(input);
-  if (cache->dims.empty()) {
-    return;
-  }
+    output.CopyFrom(input);
+    if (dims_.empty()) {
+      return;
+    }
 
-  auto newDims = input.sizes().vec();
-  CAFFE_ENFORCE_GE(
-      input.sizes().size() + cache->dims.size(),
-      cache->dims.back() + 1,
-      "Input needs at least ",
-      (1 + cache->dims.back() - cache->dims.size()),
-      " dimensions given `dims`.");
-  for (const auto dim : cache->dims) {
-    newDims.insert(newDims.begin() + dim, 1);
+    auto newDims = input.sizes().vec();
+    CAFFE_ENFORCE_GE(
+        input.sizes().size() + dims_.size(),
+        dims_.back() + 1,
+        "Input needs at least ",
+        (1 + dims_.back() - dims_.size()),
+        " dimensions given `dims`.");
+    for (const auto dim : dims_) {
+      newDims.insert(newDims.begin() + dim, 1);
+    }
+    output.Reshape(newDims);
   }
-  output.Reshape(newDims);
-}
+
+ private:
+  std::vector<int64_t> dims_;
+  bool initialized_ = false;
+};
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::ExpandDims",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input"),
+                                    c10::Argument("output"),
+                                    c10::Argument("dims", ListType::ofInts())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<expand_dims_cpu<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::ExpandDims)
-    .withCache<caffe2::Cache>()
-    .kernel<decltype(caffe2::expand_dims_op_cpu_impl<float>), &caffe2::expand_dims_op_cpu_impl<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::ExpandDims",
+    C10ExpandDims_DontUseThisOpYet)
+
+} // namespace caffe2
index 99e0545..51d034a 100644 (file)
@@ -1,10 +1,10 @@
+#include <ATen/core/op_registration/op_registration.h>
 #include "caffe2/core/context.h"
-#include <ATen/core/dispatch/KernelRegistration.h>
 #include "caffe2/core/operator.h"
-#include "caffe2/operators/experimental/c10/schemas/fc.h"
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/core/tensor.h"
 #include "caffe2/utils/conversions.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/core/tensor.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -12,124 +12,140 @@ using caffe2::Tensor;
 namespace caffe2 {
 namespace {
 
-struct Cache final : public c10::KernelCache {
-  vector<int64_t> Y_shape_cache_;
-  at::Tensor bias_multiplier_ = at::Tensor(C10Tensor(Tensor()));
-};
-
 template <class DataType, class Context>
-void fc_op_cpu_impl(
-    const at::Tensor& X_,
-    const at::Tensor& W_,
-    const at::Tensor& b_,
-    const at::Tensor& Y_,
-    int64_t axis,
-    int64_t axis_w,
-    Cache* cache) {
-  Tensor X{C10Tensor(X_)};
-  Tensor W{C10Tensor(W_)};
-  Tensor b{C10Tensor(b_)};
-  Tensor Y{C10Tensor(Y_)};
-  CPUContext context;
-
-  constexpr bool TransposeWeight = true;
-
-  CAFFE_ENFORCE(b.dim() == 1, b.dim());
-  // batch size
-  const auto canonical_axis = X.canonical_axis_index(axis);
-  const auto M = X.size_to_dim(canonical_axis);
-  const auto K = X.size_from_dim(canonical_axis);
-  const auto canonical_axis_w = W.canonical_axis_index(axis_w);
-  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
-                                : W.size_from_dim(canonical_axis_w);
-
-  auto dimErrorString = [&]() {
-    return c10::str(
-        "Dimension mismatch: ",
-        "X: ",
-        X.sizes(),
-        ", W: ",
-        W.sizes(),
-        ", b: ",
-        b.sizes(),
-        ", axis: ",
-        axis,
-        ", M: ",
+class fc_op_cpu final : public c10::OperatorKernel {
+ public:
+  void operator()(
+      const at::Tensor& X_,
+      const at::Tensor& W_,
+      const at::Tensor& b_,
+      const at::Tensor& Y_,
+      int64_t axis,
+      int64_t axis_w) {
+    Tensor X{C10Tensor(X_)};
+    Tensor W{C10Tensor(W_)};
+    Tensor b{C10Tensor(b_)};
+    Tensor Y{C10Tensor(Y_)};
+    CPUContext context;
+
+    constexpr bool TransposeWeight = true;
+
+    CAFFE_ENFORCE(b.dim() == 1, b.dim());
+    // batch size
+    const auto canonical_axis = X.canonical_axis_index(axis);
+    const auto M = X.size_to_dim(canonical_axis);
+    const auto K = X.size_from_dim(canonical_axis);
+    const auto canonical_axis_w = W.canonical_axis_index(axis_w);
+    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
+                                  : W.size_from_dim(canonical_axis_w);
+
+    auto dimErrorString = [&]() {
+      return c10::str(
+          "Dimension mismatch: ",
+          "X: ",
+          X.sizes(),
+          ", W: ",
+          W.sizes(),
+          ", b: ",
+          b.sizes(),
+          ", axis: ",
+          axis,
+          ", M: ",
+          M,
+          ", N: ",
+          N,
+          ", K: ",
+          K);
+    };
+
+    // Error checking
+    CAFFE_ENFORCE(M == X.numel() / K, dimErrorString());
+    CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
+    CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
+    CAFFE_ENFORCE(N == b.numel(), dimErrorString());
+
+    Y_shape_cache_ = X.sizes().vec();
+    // This is an invariant of canonical_axis, so we can DCHECK.
+    DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
+    Y_shape_cache_.resize(canonical_axis + 1);
+    Y_shape_cache_[canonical_axis] = N;
+    Y.Resize(Y_shape_cache_);
+    CAFFE_ENFORCE(M * N == Y.numel(), dimErrorString());
+
+    if (X.numel() == 0) {
+      // skip the rest of the computation if X is empty
+      Y.template mutable_data<DataType>();
+      return;
+    }
+
+    // default to FLOAT as math.h does.
+    caffe2::TensorProto::DataType math_type =
+        caffe2::TensorProto_DataType_FLOAT;
+    if (caffe2::fp16_type<DataType>()) {
+      math_type = caffe2::TensorProto_DataType_FLOAT16;
+    }
+
+    // W * x
+    caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
+        CblasNoTrans,
+        TransposeWeight ? CblasTrans : CblasNoTrans,
+        M,
+        N,
+        K,
+        1,
+        X.template data<DataType>(),
+        W.template data<DataType>(),
+        0,
+        Y.template mutable_data<DataType>(),
+        static_cast<Context*>(&context),
+        math_type);
+    // Add bias term
+    Tensor bias_multiplier(bias_multiplier_);
+    ReinitializeTensor(
+        &bias_multiplier, {M}, at::dtype<DataType>().device(CPU));
+    caffe2::math::Set<DataType, Context>(
+        M,
+        caffe2::convert::To<float, DataType>(1),
+        bias_multiplier.template mutable_data<DataType>(),
+        static_cast<Context*>(&context));
+    caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
+        CblasNoTrans,
+        CblasNoTrans,
         M,
-        ", N: ",
         N,
-        ", K: ",
-        K);
-  };
-
-  // Error checking
-  CAFFE_ENFORCE(M == X.numel() / K, dimErrorString());
-  CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
-  CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
-  CAFFE_ENFORCE(N == b.numel(), dimErrorString());
-
-  cache->Y_shape_cache_ = X.sizes().vec();
-  // This is an invariant of canonical_axis, so we can DCHECK.
-  DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size());
-  cache->Y_shape_cache_.resize(canonical_axis + 1);
-  cache->Y_shape_cache_[canonical_axis] = N;
-  Y.Resize(cache->Y_shape_cache_);
-  CAFFE_ENFORCE(M * N == Y.numel(), dimErrorString());
-
-  if (X.numel() == 0) {
-    // skip the rest of the computation if X is empty
-    Y.template mutable_data<DataType>();
-    return;
+        1,
+        1,
+        bias_multiplier.template data<DataType>(),
+        b.template data<DataType>(),
+        1,
+        Y.template mutable_data<DataType>(),
+        static_cast<Context*>(&context),
+        math_type);
   }
 
-  // default to FLOAT as math.h does.
-  caffe2::TensorProto::DataType math_type = caffe2::TensorProto_DataType_FLOAT;
-  if (caffe2::fp16_type<DataType>()) {
-    math_type = caffe2::TensorProto_DataType_FLOAT16;
-  }
+ private:
+  vector<int64_t> Y_shape_cache_;
+  at::Tensor bias_multiplier_ = at::Tensor(C10Tensor(Tensor()));
+};
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::FullyConnected",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("X"),
+                                    c10::Argument("W"),
+                                    c10::Argument("b"),
+                                    c10::Argument("output"),
+                                    c10::Argument("axis", IntType::get()),
+                                    c10::Argument("axis_w", IntType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<fc_op_cpu<float, CPUContext>>(),
+    c10::dispatchKey(CPUTensorId()));
 
-  // W * x
-  caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
-      CblasNoTrans,
-      TransposeWeight ? CblasTrans : CblasNoTrans,
-      M,
-      N,
-      K,
-      1,
-      X.template data<DataType>(),
-      W.template data<DataType>(),
-      0,
-      Y.template mutable_data<DataType>(),
-      static_cast<Context*>(&context),
-      math_type);
-  // Add bias term
-  Tensor bias_multiplier(cache->bias_multiplier_);
-  ReinitializeTensor(&bias_multiplier, {M}, at::dtype<DataType>().device(CPU));
-  caffe2::math::Set<DataType, Context>(
-      M,
-      caffe2::convert::To<float, DataType>(1),
-      bias_multiplier.template mutable_data<DataType>(),
-      static_cast<Context*>(&context));
-  caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
-      CblasNoTrans,
-      CblasNoTrans,
-      M,
-      N,
-      1,
-      1,
-      bias_multiplier.template data<DataType>(),
-      b.template data<DataType>(),
-      1,
-      Y.template mutable_data<DataType>(),
-      static_cast<Context*>(&context),
-      math_type);
-}
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::FullyConnected)
-    .withCache<caffe2::Cache>()
-    .kernel<decltype(caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::FullyConnected",
+    C10FC_DontUseThisOpYet)
+
+} // namespace caffe2
index 6db6685..55e225d 100644 (file)
@@ -1,8 +1,8 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/filler.h"
-#include "caffe2/utils/math.h"
-#include "caffe2/core/tensor.h"
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/core/Tensor.h>
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::CPUContext;
 using caffe2::Tensor;
@@ -143,27 +143,107 @@ void uniform_fill_op_cpu_impl(
       output.template mutable_data<float>(),
       static_cast<CPUContext*>(&context));
 }
-} // namespace
-} // namespace caffe2
-
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::ConstantFill)
-    .kernel<decltype(caffe2::constant_fill_op_cpu_impl), &caffe2::constant_fill_op_cpu_impl>()
-    .dispatchKey(CPUTensorId());
 
-C10_REGISTER_KERNEL(caffe2::ops::UniformFill)
-    .kernel<decltype(caffe2::uniform_fill_op_cpu_impl), &caffe2::uniform_fill_op_cpu_impl>()
-    .dispatchKey(CPUTensorId());
+static auto registry =
+    c10::RegisterOperators()
+        .op(FunctionSchema(
+                "_c10_experimental::ConstantFill",
+                "",
+                (std::vector<c10::Argument>{
+                    c10::Argument("inputs", ListType::ofTensors()),
+                    c10::Argument("output"),
+                    c10::Argument("shape", ListType::ofInts()),
+                    c10::Argument("extra_shape", ListType::ofInts()),
+                    c10::Argument("input_as_shape", BoolType::get()),
+                    c10::Argument("dtype", IntType::get()),
+                    c10::Argument("value", NumberType::get())}),
+                (std::vector<c10::Argument>{})),
+            c10::kernel<
+                decltype(constant_fill_op_cpu_impl),
+                &constant_fill_op_cpu_impl>(),
+            c10::dispatchKey(CPUTensorId()))
+        .op(FunctionSchema(
+                "_c10_experimental::UniformFill",
+                "",
+                (std::vector<c10::Argument>{
+                    c10::Argument("inputs", ListType::ofTensors()),
+                    c10::Argument("output"),
+                    c10::Argument("shape", ListType::ofInts()),
+                    c10::Argument("extra_shape", ListType::ofInts()),
+                    c10::Argument("input_as_shape", BoolType::get()),
+                    c10::Argument("min", FloatType::get()),
+                    c10::Argument("max", FloatType::get())}),
+                (std::vector<c10::Argument>{})),
+            c10::kernel<
+                decltype(uniform_fill_op_cpu_impl),
+                &uniform_fill_op_cpu_impl>(),
+            c10::dispatchKey(CPUTensorId()))
+        .op(FunctionSchema(
+                "_c10_experimental::GivenTensorFill",
+                "",
+                (std::vector<c10::Argument>{
+                    c10::Argument("inputs", ListType::ofTensors()),
+                    c10::Argument("output"),
+                    c10::Argument("shape", ListType::ofInts()),
+                    c10::Argument("extra_shape", ListType::ofInts()),
+                    c10::Argument("input_as_shape", BoolType::get()),
+                    c10::Argument("values"),
+                }),
+                (std::vector<c10::Argument>{})),
+            c10::kernel<
+                decltype(given_tensor_fill_op_cpu_impl<float, CPUContext>),
+                &given_tensor_fill_op_cpu_impl<float, CPUContext>>(),
+            c10::dispatchKey(CPUTensorId()))
+        .op(FunctionSchema(
+                "_c10_experimental::GivenTensorIntFill",
+                "",
+                (std::vector<c10::Argument>{
+                    c10::Argument("inputs", ListType::ofTensors()),
+                    c10::Argument("output"),
+                    c10::Argument("shape", ListType::ofInts()),
+                    c10::Argument("extra_shape", ListType::ofInts()),
+                    c10::Argument("input_as_shape", BoolType::get()),
+                    c10::Argument("values"),
+                }),
+                (std::vector<c10::Argument>{})),
+            c10::kernel<
+                decltype(given_tensor_fill_op_cpu_impl<int, CPUContext>),
+                &given_tensor_fill_op_cpu_impl<int, CPUContext>>(),
+            c10::dispatchKey(CPUTensorId()))
+        .op(FunctionSchema(
+                "_c10_experimental::GivenTensorInt64Fill",
+                "",
+                (std::vector<c10::Argument>{
+                    c10::Argument("inputs", ListType::ofTensors()),
+                    c10::Argument("output"),
+                    c10::Argument("shape", ListType::ofInts()),
+                    c10::Argument("extra_shape", ListType::ofInts()),
+                    c10::Argument("input_as_shape", BoolType::get()),
+                    c10::Argument("values"),
+                }),
+                (std::vector<c10::Argument>{})),
+            c10::kernel<
+                decltype(given_tensor_fill_op_cpu_impl<int, CPUContext>),
+                &given_tensor_fill_op_cpu_impl<int, CPUContext>>(),
+            c10::dispatchKey(CPUTensorId()));
 
-C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill)
-    .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
+} // namespace
 
-C10_REGISTER_KERNEL(caffe2::ops::GivenTensorIntFill)
-    .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::ConstantFill",
+    C10ConstantFill_DontUseThisOpYet)
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::UniformFill",
+    C10UniformFill_DontUseThisOpYet)
+
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::GivenTensorFill",
+    C10GivenTensorFill_DontUseThisOpYet)
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::GivenTensorIntFill",
+    C10GivenTensorIntFill_DontUseThisOpYet)
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::GivenTensorInt64Fill",
+    C10GivenTensorInt64Fill_DontUseThisOpYet)
 
-C10_REGISTER_KERNEL(caffe2::ops::GivenTensorInt64Fill)
-    .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+} // namespace caffe2
index 23bbaf3..6c153de 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/flatten.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -25,11 +25,24 @@ void flatten_op_cpu_impl(
       input.raw_data(),
       output.raw_mutable_data(input.dtype()));
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Flatten",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input"),
+                                    c10::Argument("output"),
+                                    c10::Argument("axis", IntType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(flatten_op_cpu_impl<float, CPUContext>),
+        &flatten_op_cpu_impl<float, CPUContext>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Flatten)
-    .kernel<decltype(caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Flatten",
+    C10Flatten_DontUseThisOpYet)
+
+} // namespace caffe2
index 247e1bb..ef1e6d9 100644 (file)
@@ -1,8 +1,8 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/core/tensor.h"
 #include "caffe2/operators/elementwise_ops_utils.h"
-#include "caffe2/operators/experimental/c10/schemas/mul.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/core/tensor.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -70,11 +70,25 @@ void mul_op_cpu_impl(
       C.mutable_data<DataType>(),
       static_cast<CPUContext*>(&context));
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Mul",
+        "",
+        (std::vector<c10::Argument>{
+            c10::Argument("input1"),
+            c10::Argument("input2"),
+            c10::Argument("output"),
+            c10::Argument("legacy_broadcast", BoolType::get()),
+            c10::Argument("axis", IntType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<decltype(mul_op_cpu_impl<float>), &mul_op_cpu_impl<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Mul)
-    .kernel<decltype(caffe2::mul_op_cpu_impl<float>), &caffe2::mul_op_cpu_impl<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Mul",
+    C10Mul_DontUseThisOpYet)
+
+} // namespace caffe2
index 67c7ee2..c29a677 100644 (file)
@@ -1,8 +1,8 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/relu.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/core/tensor.h"
 #include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/core/tensor.h"
 
 using caffe2::Tensor;
 
@@ -39,11 +39,21 @@ void relu_op_cpu_impl(
   }
   */
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Relu",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input"),
+                                    c10::Argument("output")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<decltype(relu_op_cpu_impl<float>), &relu_op_cpu_impl<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Relu)
-    .kernel<decltype(caffe2::relu_op_cpu_impl<float>), &caffe2::relu_op_cpu_impl<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Relu",
+    C10Relu_DontUseThisOpYet)
+
+} // namespace caffe2
index 470d633..17d96a5 100644 (file)
@@ -1,8 +1,8 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/sigmoid.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/core/tensor.h"
 #include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/core/tensor.h"
 
 using caffe2::Tensor;
 
@@ -22,11 +22,23 @@ void sigmoid_op_cpu_impl(
       output.mutable_data<DataType>(), input.numel()) =
       1. / (1. + (-xM).exp());
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::Sigmoid",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input"),
+                                    c10::Argument("output")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(sigmoid_op_cpu_impl<float>),
+        &sigmoid_op_cpu_impl<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::Sigmoid)
-    .kernel<decltype(caffe2::sigmoid_op_cpu_impl<float>), &caffe2::sigmoid_op_cpu_impl<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::Sigmoid",
+    C10Sigmoid_DontUseThisOpYet)
+
+} // namespace caffe2
index 4255e2a..e23f6f8 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::Tensor;
 
@@ -69,11 +69,27 @@ void sigmoid_cross_entropy_with_logits_op_cpu_impl(
     out_ptr[i] = -value / inner_size;
   }
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::SigmoidCrossEntropyWithLogits",
+        "",
+        (std::vector<c10::Argument>{
+            c10::Argument("input1"),
+            c10::Argument("input2"),
+            c10::Argument("output"),
+            c10::Argument("log_D_trick", BoolType::get()),
+            c10::Argument("unjoined_lr_loss", BoolType::get())}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(sigmoid_cross_entropy_with_logits_op_cpu_impl),
+        &sigmoid_cross_entropy_with_logits_op_cpu_impl>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits)
-    .kernel<decltype(caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl), &caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::SigmoidCrossEntropyWithLogits",
+    C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet)
+
+} // namespace caffe2
index 1d8b180..6bbc107 100644 (file)
@@ -1,8 +1,8 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
+#include "caffe2/core/tensor.h"
 #include "caffe2/perfkernels/embedding_lookup.h"
 #include "caffe2/utils/math.h"
-#include "caffe2/core/tensor.h"
 
 using caffe2::Tensor;
 
@@ -81,11 +81,24 @@ void sparse_lengths_sum_op_cpu(
   }
 }
 
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::SparseLengthsSum",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("data"),
+                                    c10::Argument("indices"),
+                                    c10::Argument("lengths"),
+                                    c10::Argument("output")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(sparse_lengths_sum_op_cpu),
+        &sparse_lengths_sum_op_cpu>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
-    .kernel<decltype(caffe2::sparse_lengths_sum_op_cpu), &caffe2::sparse_lengths_sum_op_cpu>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::SparseLengthsSum",
+    C10SparseLengthsSum_DontUseThisOpYet)
+
+} // namespace caffe2
index 4d3a981..f103719 100644 (file)
@@ -1,7 +1,7 @@
-#include <ATen/core/dispatch/KernelRegistration.h>
-#include "caffe2/operators/experimental/c10/schemas/stop_gradient.h"
-#include "caffe2/utils/math.h"
+#include <ATen/core/op_registration/op_registration.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include "caffe2/core/tensor.h"
+#include "caffe2/utils/math.h"
 
 using caffe2::BaseContext;
 using caffe2::Tensor;
@@ -18,11 +18,23 @@ void stop_gradient_op_cpu_impl(
     output.CopyFrom(input);
   }
 }
+
+static auto registry = c10::RegisterOperators().op(
+    FunctionSchema(
+        "_c10_experimental::StopGradient",
+        "",
+        (std::vector<c10::Argument>{c10::Argument("input"),
+                                    c10::Argument("output")}),
+        (std::vector<c10::Argument>{})),
+    c10::kernel<
+        decltype(stop_gradient_op_cpu_impl<float>),
+        &stop_gradient_op_cpu_impl<float>>(),
+    c10::dispatchKey(CPUTensorId()));
+
 } // namespace
-} // namespace caffe2
 
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::ops::StopGradient)
-    .kernel<decltype(caffe2::stop_gradient_op_cpu_impl<float>), &caffe2::stop_gradient_op_cpu_impl<float>>()
-    .dispatchKey(CPUTensorId());
-} // namespace c10
+REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
+    "_c10_experimental::StopGradient",
+    C10StopGradient_DontUseThisOpYet)
+
+} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/add.cc b/caffe2/operators/experimental/c10/schemas/add.cc
deleted file mode 100644 (file)
index 63ecc97..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/add.h"
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Add,
-    FunctionSchema(
-        "_c10_experimental::Add",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("input1"),
-            c10::Argument("input2"),
-            c10::Argument("output"),
-            c10::Argument("legacy_broadcast", BoolType::get()),
-            c10::Argument("axis", IntType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Add",
-    C10Add_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/add.h b/caffe2/operators/experimental/c10/schemas/add.h
deleted file mode 100644 (file)
index 4dfa99a..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Add);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.cc b/caffe2/operators/experimental/c10/schemas/averaged_loss.cc
deleted file mode 100644 (file)
index dfc41a6..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/averaged_loss.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    AveragedLoss,
-    FunctionSchema(
-        "_c10_experimental::AveragedLoss",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input"),
-                                    c10::Argument("output")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::AveragedLoss",
-    C10AveragedLoss_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/averaged_loss.h b/caffe2/operators/experimental/c10/schemas/averaged_loss.h
deleted file mode 100644 (file)
index 548bd07..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(AveragedLoss);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.cc b/caffe2/operators/experimental/c10/schemas/batch_gather.cc
deleted file mode 100644 (file)
index 9fb84e5..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/batch_gather.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    BatchGather,
-    FunctionSchema(
-        "_c10_experimental::BatchGather",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("data"),
-                                    c10::Argument("indices"),
-                                    c10::Argument("output")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::BatchGather",
-    C10BatchGather_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/batch_gather.h b/caffe2/operators/experimental/c10/schemas/batch_gather.h
deleted file mode 100644 (file)
index 214c67f..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(BatchGather);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.cc b/caffe2/operators/experimental/c10/schemas/batch_matmul.cc
deleted file mode 100644 (file)
index addb95e..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/batch_matmul.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    BatchMatmul,
-    FunctionSchema(
-        "_c10_experimental::BatchMatmul",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("A"),
-            c10::Argument("B"),
-            c10::Argument("output"),
-            c10::Argument("trans_a", IntType::get()),
-            c10::Argument("trans_b", IntType::get()),
-            c10::Argument("broadcast", IntType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::BatchMatmul",
-    C10BatchMatMul_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/batch_matmul.h b/caffe2/operators/experimental/c10/schemas/batch_matmul.h
deleted file mode 100644 (file)
index 191e0e6..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(BatchMatmul);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/cast.cc b/caffe2/operators/experimental/c10/schemas/cast.cc
deleted file mode 100644 (file)
index c1133ce..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/cast.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-#include "caffe2/utils/cast.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Cast,
-    FunctionSchema(
-        "_c10_experimental::Cast",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("input"),
-            c10::Argument("output"),
-            c10::Argument("to_dtype", IntType::get()),
-        }),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Cast",
-    C10Cast_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/cast.h b/caffe2/operators/experimental/c10/schemas/cast.h
deleted file mode 100644 (file)
index 979637b..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Cast);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/concat.cc b/caffe2/operators/experimental/c10/schemas/concat.cc
deleted file mode 100644 (file)
index d9a7e33..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/concat.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Concat,
-    FunctionSchema(
-        "_c10_experimental::Concat",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("inputs", ListType::ofTensors()),
-            c10::Argument("output"),
-            c10::Argument("split_info", FloatType::get()),
-            c10::Argument("add", IntType::get()),
-            c10::Argument("add_axis", IntType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Concat",
-    C10Concat_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/concat.h b/caffe2/operators/experimental/c10/schemas/concat.h
deleted file mode 100644 (file)
index aecaf40..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Concat);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.cc b/caffe2/operators/experimental/c10/schemas/enforce_finite.cc
deleted file mode 100644 (file)
index c8d58de..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/enforce_finite.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    EnforceFinite,
-    FunctionSchema(
-        "_c10_experimental::EnforceFinite",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::EnforceFinite",
-    C10EnforceFinite_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/enforce_finite.h b/caffe2/operators/experimental/c10/schemas/enforce_finite.h
deleted file mode 100644 (file)
index 704136c..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(EnforceFinite);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.cc b/caffe2/operators/experimental/c10/schemas/expand_dims.cc
deleted file mode 100644 (file)
index e2c4c75..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/expand_dims.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-using c10::intrusive_ptr;
-using c10::ivalue::IntList;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    ExpandDims,
-    FunctionSchema(
-        "_c10_experimental::ExpandDims",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input"),
-                                    c10::Argument("output"),
-                                    c10::Argument("dims", ListType::ofInts())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::ExpandDims",
-    C10ExpandDims_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/expand_dims.h b/caffe2/operators/experimental/c10/schemas/expand_dims.h
deleted file mode 100644 (file)
index fa3ab8f..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(ExpandDims);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/fc.cc b/caffe2/operators/experimental/c10/schemas/fc.cc
deleted file mode 100644 (file)
index 773964e..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/fc.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    FullyConnected,
-    FunctionSchema(
-        "_c10_experimental::FullyConnected",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("X"),
-                                    c10::Argument("W"),
-                                    c10::Argument("b"),
-                                    c10::Argument("output"),
-                                    c10::Argument("axis", IntType::get()),
-                                    c10::Argument("axis_w", IntType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::FullyConnected",
-    C10FC_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/fc.h b/caffe2/operators/experimental/c10/schemas/fc.h
deleted file mode 100644 (file)
index 1aed0eb..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(FullyConnected);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/filler.cc b/caffe2/operators/experimental/c10/schemas/filler.cc
deleted file mode 100644 (file)
index 8fe8707..0000000
+++ /dev/null
@@ -1,103 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/filler.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-#include "caffe2/utils/cast.h"
-
-using caffe2::CPUContext;
-using c10::C10Tensor;
-using c10::ivalue::IntList;
-using c10::intrusive_ptr;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema strings instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    ConstantFill,
-    FunctionSchema(
-        "_c10_experimental::ConstantFill",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("inputs", ListType::ofTensors()),
-            c10::Argument("output"),
-            c10::Argument("shape", ListType::ofInts()),
-            c10::Argument("extra_shape", ListType::ofInts()),
-            c10::Argument("input_as_shape", BoolType::get()),
-            c10::Argument("dtype", IntType::get()),
-            c10::Argument("value", NumberType::get())}),
-        (std::vector<c10::Argument>{})));
-C10_DEFINE_OP_SCHEMA(
-    UniformFill,
-    FunctionSchema(
-        "_c10_experimental::UniformFill",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("inputs", ListType::ofTensors()),
-            c10::Argument("output"),
-            c10::Argument("shape", ListType::ofInts()),
-            c10::Argument("extra_shape", ListType::ofInts()),
-            c10::Argument("input_as_shape", BoolType::get()),
-            c10::Argument("min", FloatType::get()),
-            c10::Argument("max", FloatType::get())}),
-        (std::vector<c10::Argument>{})));
-C10_DEFINE_OP_SCHEMA(
-    GivenTensorFill,
-    FunctionSchema(
-        "_c10_experimental::GivenTensorFill",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("inputs", ListType::ofTensors()),
-            c10::Argument("output"),
-            c10::Argument("shape", ListType::ofInts()),
-            c10::Argument("extra_shape", ListType::ofInts()),
-            c10::Argument("input_as_shape", BoolType::get()),
-            c10::Argument("values"),
-        }),
-        (std::vector<c10::Argument>{})));
-C10_DEFINE_OP_SCHEMA(
-    GivenTensorIntFill,
-    FunctionSchema(
-        "_c10_experimental::GivenTensorIntFill",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("inputs", ListType::ofTensors()),
-            c10::Argument("output"),
-            c10::Argument("shape", ListType::ofInts()),
-            c10::Argument("extra_shape", ListType::ofInts()),
-            c10::Argument("input_as_shape", BoolType::get()),
-            c10::Argument("values"),
-        }),
-        (std::vector<c10::Argument>{})));
-C10_DEFINE_OP_SCHEMA(
-    GivenTensorInt64Fill,
-    FunctionSchema(
-        "_c10_experimental::GivenTensorInt64Fill",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("inputs", ListType::ofTensors()),
-            c10::Argument("output"),
-            c10::Argument("shape", ListType::ofInts()),
-            c10::Argument("extra_shape", ListType::ofInts()),
-            c10::Argument("input_as_shape", BoolType::get()),
-            c10::Argument("values"),
-        }),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::ConstantFill",
-    C10ConstantFill_DontUseThisOpYet)
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::UniformFill",
-    C10UniformFill_DontUseThisOpYet)
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::GivenTensorFill",
-    C10GivenTensorFill_DontUseThisOpYet)
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::GivenTensorIntFill",
-    C10GivenTensorIntFill_DontUseThisOpYet)
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::GivenTensorInt64Fill",
-    C10GivenTensorInt64Fill_DontUseThisOpYet)
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/filler.h b/caffe2/operators/experimental/c10/schemas/filler.h
deleted file mode 100644 (file)
index 616893d..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(GivenTensorFill);
-C10_DECLARE_OP_SCHEMA(GivenTensorIntFill);
-C10_DECLARE_OP_SCHEMA(GivenTensorInt64Fill);
-C10_DECLARE_OP_SCHEMA(ConstantFill);
-C10_DECLARE_OP_SCHEMA(UniformFill);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/flatten.cc b/caffe2/operators/experimental/c10/schemas/flatten.cc
deleted file mode 100644 (file)
index 42353b2..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/flatten.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Flatten,
-    FunctionSchema(
-        "_c10_experimental::Flatten",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input"),
-                                    c10::Argument("output"),
-                                    c10::Argument("axis", IntType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Flatten",
-    C10Flatten_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/flatten.h b/caffe2/operators/experimental/c10/schemas/flatten.h
deleted file mode 100644 (file)
index 9c53462..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Flatten);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/mul.cc b/caffe2/operators/experimental/c10/schemas/mul.cc
deleted file mode 100644 (file)
index af7a7b7..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/mul.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Mul,
-    FunctionSchema(
-        "_c10_experimental::Mul",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("input1"),
-            c10::Argument("input2"),
-            c10::Argument("output"),
-            c10::Argument("legacy_broadcast", BoolType::get()),
-            c10::Argument("axis", IntType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Mul",
-    C10Mul_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/mul.h b/caffe2/operators/experimental/c10/schemas/mul.h
deleted file mode 100644 (file)
index 54b64f4..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Mul);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/relu.cc b/caffe2/operators/experimental/c10/schemas/relu.cc
deleted file mode 100644 (file)
index 43528d9..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/relu.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Relu,
-    FunctionSchema(
-        "_c10_experimental::Relu",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input"),
-                                    c10::Argument("output")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Relu",
-    C10Relu_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/relu.h b/caffe2/operators/experimental/c10/schemas/relu.h
deleted file mode 100644 (file)
index ea0aa89..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Relu);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.cc b/caffe2/operators/experimental/c10/schemas/sigmoid.cc
deleted file mode 100644 (file)
index 2261a19..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/sigmoid.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    Sigmoid,
-    FunctionSchema(
-        "_c10_experimental::Sigmoid",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input"),
-                                    c10::Argument("output")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::Sigmoid",
-    C10Sigmoid_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid.h b/caffe2/operators/experimental/c10/schemas/sigmoid.h
deleted file mode 100644 (file)
index 5d5ff41..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(Sigmoid);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.cc
deleted file mode 100644 (file)
index d1be6b9..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    SigmoidCrossEntropyWithLogits,
-    FunctionSchema(
-        "_c10_experimental::SigmoidCrossEntropyWithLogits",
-        "",
-        (std::vector<c10::Argument>{
-            c10::Argument("input1"),
-            c10::Argument("input2"),
-            c10::Argument("output"),
-            c10::Argument("log_D_trick", BoolType::get()),
-            c10::Argument("unjoined_lr_loss", BoolType::get())}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::SigmoidCrossEntropyWithLogits",
-    C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h b/caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h
deleted file mode 100644 (file)
index 671c2e2..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(SigmoidCrossEntropyWithLogits);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.cc
deleted file mode 100644 (file)
index f26abcf..0000000
+++ /dev/null
@@ -1,27 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    SparseLengthsSum,
-    FunctionSchema(
-        "_c10_experimental::SparseLengthsSum",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("data"),
-                                    c10::Argument("indices"),
-                                    c10::Argument("lengths"),
-                                    c10::Argument("output")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::SparseLengthsSum",
-    C10SparseLengthsSum_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h b/caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h
deleted file mode 100644 (file)
index a4054e1..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(SparseLengthsSum);
-
-} // namespace ops
-} // namespace caffe2
diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.cc b/caffe2/operators/experimental/c10/schemas/stop_gradient.cc
deleted file mode 100644 (file)
index 3305845..0000000
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "caffe2/operators/experimental/c10/schemas/stop_gradient.h"
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include "caffe2/core/operator_c10wrapper.h"
-
-using caffe2::CPUContext;
-
-namespace caffe2 {
-namespace ops {
-// TODO Parse schema string instead of creating FunctionSchema manually
-C10_DEFINE_OP_SCHEMA(
-    StopGradient,
-    FunctionSchema(
-        "_c10_experimental::StopGradient",
-        "",
-        (std::vector<c10::Argument>{c10::Argument("input"),
-                                    c10::Argument("output")}),
-        (std::vector<c10::Argument>{})));
-}
-}
-
-namespace caffe2 {
-REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
-    "_c10_experimental::StopGradient",
-    C10StopGradient_DontUseThisOpYet)
-}
diff --git a/caffe2/operators/experimental/c10/schemas/stop_gradient.h b/caffe2/operators/experimental/c10/schemas/stop_gradient.h
deleted file mode 100644 (file)
index bb130e2..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-#pragma once
-
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-
-namespace caffe2 {
-namespace ops {
-
-C10_DECLARE_OP_SCHEMA(StopGradient);
-
-} // namespace ops
-} // namespace caffe2