New operator registration MVP (#18161)
author Sebastian Messmer <messmer@fb.com>
Sat, 30 Mar 2019 07:03:43 +0000 (00:03 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 07:07:16 +0000 (00:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18161

This introduces version 0 for the new operator registration.

For now, it only works with kernels that are defined as stack-based functions.
This is not the intended public API for defining kernels, but it is the basis on which the public APIs are built (see the diffs stacked on top of this one),
and it is also the mechanism used for exposing caffe2 operators.

This diff also switches the mechanism for exposing caffe2 operators to the new mechanism.
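
As a taste of the new API, here is a minimal sketch of a stack-based registration (my_kernel_cpu, noCache and my_op_schema are placeholder names; the doc comments and kernel_stackbased_test.cpp below are the authoritative examples):

#include <ATen/core/op_registration/op_registration.h>

namespace {
// A stack-based kernel pops its arguments from the stack and pushes its
// results back onto it.
void my_kernel_cpu(c10::Stack* stack, c10::KernelCache* cache) {
  int64_t input = torch::jit::pop(*stack).toInt();
  torch::jit::pop(*stack); // pop the dummy tensor argument
  torch::jit::push(*stack, input + 1);
}

// This kernel keeps no state between invocations.
std::unique_ptr<c10::KernelCache> noCache() {
  return nullptr;
}
}

// my_op_schema stands in for a c10::FunctionSchema constructed elsewhere.
// The registration lives as long as the registry object; its destructor
// deregisters the kernel and the schema again.
static auto registry = c10::RegisterOperators().op(
    my_op_schema,
    c10::kernel(&my_kernel_cpu, &noCache),
    c10::dispatchKey(c10::CPUTensorId()));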

Reviewed By: dzhulgakov

Differential Revision: D14514231

fbshipit-source-id: 454ab7b5b46a10203aa27b175400d23f818dd1df

aten/src/ATen/core/dispatch/README.md
aten/src/ATen/core/op_registration/base.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/dispatch_key.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_stackbased.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp [new file with mode: 0644]
aten/src/ATen/core/op_registration/op_registration.cpp [new file with mode: 0644]
aten/src/ATen/core/op_registration/op_registration.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/test_helpers.h [new file with mode: 0644]
caffe2/core/c10_operator.h

diff --git a/aten/src/ATen/core/dispatch/README.md b/aten/src/ATen/core/dispatch/README.md
index 1cb2af9..e29c93c 100644 (file)
@@ -8,5 +8,5 @@ This folder contains the following files:
 - DispatchTable.h: Implementation of the actual dispatch mechanism. Hash table with kernels, lookup, ...
 - KernelCache.h: An interface operator kernels can use to inherit from if they need to keep around a cache between invocations
 - KernelFunction.h: The core interface (i.e. function pointer) for calling a kernel
-- OpSchemaRegistration.h: The mechanisms to register new operators with the c10 dispatcher
-- KernelRegistration.h: The mechanisms to register kernels with the c10 dispatcher
+- OpSchemaRegistration.h (deprecated): The mechanisms to register new operators with the c10 dispatcher
+- KernelRegistration.h (deprecated): The mechanisms to register kernels with the c10 dispatcher
diff --git a/aten/src/ATen/core/op_registration/base.h b/aten/src/ATen/core/op_registration/base.h
new file mode 100644 (file)
index 0000000..8a9430a
--- /dev/null
@@ -0,0 +1,100 @@
+#pragma once
+
+/**
+ * This file sets up the basics for operator registration like the
+ * c10::RegisterOperators() class.
+ *
+ * You probably don't want to include this file directly but include
+ * op_registration.h instead since that adds more functionality you'll
+ * likely need to register your operators.
+ */
+
+#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace c10 {
+
+namespace detail {
+
+  // OperatorRegistrar in its constructor registers an operator in the dispatch
+  // table and deregisters it in the destructor. The intent is that this class is
+  // constructed at static initialization time so that operators automatically
+  // get registered when a dlopen() occurs.
+  // You shouldn't call this directly; instead, use the RegisterOperators class.
+  class OperatorRegistrar final {
+  public:
+    explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction* cache_creator)
+    : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
+      Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, cache_creator);
+    }
+
+    OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
+    :  op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) {
+      rhs.owns_registration_ = false;
+    }
+
+    // not needed and would break RAII if defaulted.
+    OperatorRegistrar& operator=(OperatorRegistrar&& rhs) noexcept = delete;
+    OperatorRegistrar(const OperatorRegistrar& rhs) = delete;
+    OperatorRegistrar& operator=(const OperatorRegistrar& rhs) = delete;
+
+    ~OperatorRegistrar() {
+      if (owns_registration_) {
+        Dispatcher::singleton().deregisterKernel(op_, dispatch_key_);
+        Dispatcher::singleton().deregisterSchema(op_);
+      }
+    }
+
+  private:
+    const OperatorHandle op_;
+    const TensorTypeId dispatch_key_;
+    bool owns_registration_;
+  };
+
+  // KernelRegistrationConfig accumulates all information from the config
+  // parameters passed to a RegisterOperators::op() call into one object.
+  struct KernelRegistrationConfig final {
+    TensorTypeId dispatch_key;
+    KernelFunction* kernel_func = nullptr;
+    KernelCacheCreatorFunction* cache_creator_func = nullptr;
+  };
+
+  // is_registration_config_parameter is a concept that returns true_type iff its argument is
+  // a valid parameter to be passed to c10::RegisterOperators().op(parameters...)
+  // That is, it must have an apply method that takes a KernelRegistrationConfig*.
+  template<class ConfigParameter, class Enable = void>
+  struct is_registration_config_parameter : std::false_type {
+    static_assert(std::is_same<ConfigParameter, guts::decay_t<ConfigParameter>>::value, "is_registration_config_parameter doesn't work with reference types");
+  };
+  template<class ConfigParameter>
+  struct is_registration_config_parameter<ConfigParameter, guts::void_t<decltype(
+    std::declval<ConfigParameter>().apply(std::declval<KernelRegistrationConfig*>()),
+    std::declval<const ConfigParameter&>().apply(std::declval<KernelRegistrationConfig*>())
+  )>> : std::true_type {
+    static_assert(std::is_same<ConfigParameter, guts::decay_t<ConfigParameter>>::value, "is_registration_config_parameter doesn't work with reference types");
+  };
+  static_assert(!is_registration_config_parameter<KernelRegistrationConfig>::value, "For classes that aren't registration parameters, this concept should return false");
+  // note: the corresponding asserts that the concept returns true are next to the definition of the corresponding classes
+
+  // Take a list of configuration parameters and return a
+  // KernelRegistrationConfig accumulating all their configurations.
+  template<class... ConfigParameters>
+  KernelRegistrationConfig make_registration_config(ConfigParameters&&... configParameters) {
+    static_assert(guts::conjunction<is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, "One of the parameters isn't a valid registration config parameter.");
+
+    KernelRegistrationConfig config;
+
+    // apply all configParameters
+    (void)std::initializer_list<int>{(std::forward<ConfigParameters>(configParameters).apply(&config), 0)...};
+
+    // TODO Allow this for just registering the schema?
+    AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel");
+
+    // if kernel_func is set, cache_creator_func must be set too;
+    // the API shouldn't allow anything else.
+    AT_ASSERT(static_cast<bool>(config.cache_creator_func));
+
+    return config;
+  }
+}
+
+}
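
Aside: the heart of make_registration_config above is the initializer_list pack-expansion idiom, the pre-C++17 substitute for a fold expression. A self-contained sketch with hypothetical stand-in types (Config and SetDispatchKey are illustrative only):

#include <initializer_list>
#include <utility>

// Hypothetical stand-ins for KernelRegistrationConfig and a config parameter.
struct Config {
  int dispatch_key = 0;
};

struct SetDispatchKey {
  int key;
  void apply(Config* config) const { config->dispatch_key = key; }
};

template <class... ConfigParameters>
Config make_config(ConfigParameters&&... params) {
  Config config;
  // Evaluates params.apply(&config) once per parameter, left to right
  // (braced-init-lists guarantee evaluation order), discarding the dummy
  // zeros. C++17 could write this as a fold: (params.apply(&config), ...);
  (void)std::initializer_list<int>{
      (std::forward<ConfigParameters>(params).apply(&config), 0)...};
  return config;
}

// Usage: make_config(SetDispatchKey{3}).dispatch_key == 3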
diff --git a/aten/src/ATen/core/op_registration/dispatch_key.h b/aten/src/ATen/core/op_registration/dispatch_key.h
new file mode 100644 (file)
index 0000000..a4ced36
--- /dev/null
@@ -0,0 +1,59 @@
+#pragma once
+
+/**
+ * This file implements c10::dispatchKey() which is used in the kernel
+ * registration API to set the dispatch key for a registered kernel.
+ *
+ * You probably don't want to include this file directly but include
+ * op_registration.h instead since that adds more functionality you'll
+ * likely need to register your operators.
+ */
+
+#include <ATen/core/op_registration/base.h>
+
+namespace c10 {
+
+namespace detail {
+  struct DispatchKeyConfigParameter final {
+    explicit constexpr DispatchKeyConfigParameter(TensorTypeId dispatch_key)
+    : dispatch_key_(dispatch_key) {}
+
+    void apply(KernelRegistrationConfig* registration) const {
+      registration->dispatch_key = dispatch_key_;
+    }
+
+  private:
+    TensorTypeId dispatch_key_;
+  };
+  static_assert(is_registration_config_parameter<DispatchKeyConfigParameter>::value, "DispatchKeyConfigParameter must fulfill the registration config parameter concept");
+}
+
+/**
+ * Use this to register an operator with a kernel for a certain dispatch key.
+ *
+ * Example:
+ *
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * >   class my_kernel_cuda final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel<my_kernel_cpu>(),
+ * >         c10::dispatchKey(CPUTensorId()))
+ * >     .op("my_op",
+ * >         c10::kernel<my_kernel_cuda>(),
+ * >         c10::dispatchKey(CUDATensorId()));
+ */
+inline constexpr detail::DispatchKeyConfigParameter dispatchKey(TensorTypeId dispatch_key) {
+  return detail::DispatchKeyConfigParameter(dispatch_key);
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/kernel_stackbased.h b/aten/src/ATen/core/op_registration/kernel_stackbased.h
new file mode 100644 (file)
index 0000000..09750aa
--- /dev/null
@@ -0,0 +1,59 @@
+#pragma once
+
+/**
+ * This file implements c10::kernel(stack_based_kernel), which is used in the
+ * kernel registration API to set the kernel function for a registered
+ * operator. You probably don't want to use this API; stack-based kernels are
+ * internal only. There are other, better kernel APIs built on top of this one.
+ *
+ * You probably don't want to include this file directly but include
+ * op_registration.h instead since that adds more functionality you'll
+ * likely need to register your operators.
+ */
+
+#include <ATen/core/op_registration/base.h>
+
+namespace c10 {
+
+namespace detail {
+  struct KernelRegistrationConfigParameter final {
+    explicit constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func)
+    : kernel_func_(kernel_func), cache_creator_func_(cache_creator_func) {
+    }
+
+    void apply(KernelRegistrationConfig* registration) const {
+      registration->kernel_func = kernel_func_;
+      registration->cache_creator_func = cache_creator_func_;
+    }
+
+  private:
+    KernelFunction* kernel_func_;
+    KernelCacheCreatorFunction* cache_creator_func_;
+  };
+
+  static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
+}
+
+/**
+ * Use this to register an operator whose kernel is implemented by a stack
+ * based function. This is meant to be used internally, for example for writing
+ * wrappers for other ways of writing operators. This is not part of the
+ * public API.
+ *
+ * Example:
+ *
+ * > namespace {
+ * >   void my_kernel_cpu(Stack* stack, KernelCache* cache) {...}
+ * >   unique_ptr<KernelCache> my_cache_creator() {...}
+ * > }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel(&my_kernel_cpu, &my_cache_creator),
+ * >         c10::dispatchKey(CPUTensorId()));
+ */
+inline constexpr detail::KernelRegistrationConfigParameter kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator) {
+  return detail::KernelRegistrationConfigParameter(kernel_func, cache_creator);
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp b/aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp
new file mode 100644 (file)
index 0000000..847c273
--- /dev/null
@@ -0,0 +1,166 @@
+#include <gtest/gtest.h>
+#include <ATen/core/op_registration/test_helpers.h>
+
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/core/Tensor.h>
+
+using c10::RegisterOperators;
+using c10::FunctionSchema;
+using c10::Argument;
+using c10::IntType;
+using c10::kernel;
+using c10::dispatchKey;
+using c10::TensorTypeId;
+using c10::KernelCache;
+using c10::Stack;
+using c10::guts::make_unique;
+using std::unique_ptr;
+
+namespace {
+
+C10_DECLARE_TENSOR_TYPE(TensorType1);
+C10_DEFINE_TENSOR_TYPE(TensorType1);
+C10_DECLARE_TENSOR_TYPE(TensorType2);
+C10_DEFINE_TENSOR_TYPE(TensorType2);
+
+std::unique_ptr<c10::KernelCache> noCache() {
+  return nullptr;
+}
+
+void errorKernel(Stack* stack, KernelCache* cache) {
+  EXPECT_TRUE(false); // this kernel should never be called
+}
+
+FunctionSchema errorOpSchema(
+    "_test::error",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+void incrementKernel(Stack* stack, KernelCache* cache) {
+  int input = torch::jit::pop(*stack).toInt();
+  torch::jit::pop(*stack); // pop the dummy tensor
+  torch::jit::push(*stack, input + 1);
+}
+
+void decrementKernel(Stack* stack, KernelCache* cache) {
+  int input = torch::jit::pop(*stack).toInt();
+  torch::jit::pop(*stack); // pop the dummy tensor
+  torch::jit::push(*stack, input - 1);
+}
+
+FunctionSchema opSchema(
+    "_test::my_op",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+void expectCallsIncrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(6, result[0].toInt());
+}
+
+void expectCallsDecrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(4, result[0].toInt());
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
+  auto registrar = RegisterOperators()
+      .op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()))
+      .op(opSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()))
+      .op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType1()))
+      .op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
+  auto registrar1 = RegisterOperators().op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
+  auto registrar2 = RegisterOperators().op(opSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
+  auto registrar3 = RegisterOperators().op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType1()));
+  auto registrar4 = RegisterOperators().op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
+  {
+    auto registrar1 = RegisterOperators().op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
+    {
+      auto registrar2 = RegisterOperators().op(opSchema, kernel(&decrementKernel, &noCache), dispatchKey(TensorType2()));
+
+      // assert that schema and cpu kernel are present
+      expectCallsIncrement(TensorType1());
+      expectCallsDecrement(TensorType2());
+    }
+
+    // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
+    expectCallsIncrement(TensorType1());
+    expectDoesntFindKernel("_test::my_op", TensorType2());
+  }
+
+  // now both registrars are destructed. Assert that the whole schema is gone
+  expectDoesntFindOperator("_test::my_op");
+}
+
+struct Cache final : KernelCache {
+  int last_value = 4;
+};
+
+unique_ptr<KernelCache> make_cache() {
+  return make_unique<Cache>();
+}
+
+void increment_sequence_kernel(Stack* stack, KernelCache* cache) {
+  torch::jit::pop(*stack); // pop dummy tensor
+  EXPECT_EQ(0, stack->size());
+  torch::jit::push(*stack, static_cast<Cache*>(cache)->last_value++);
+}
+
+FunctionSchema incrementSequenceOpSchema(
+    "_test::increment_sequence",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernelWithCache_whenCalled_thenCacheIsHandledCorrectly) {
+  auto registrar = RegisterOperators().op(incrementSequenceOpSchema, kernel(&increment_sequence_kernel, &make_cache), dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::increment_sequence", "");
+  ASSERT_TRUE(op.has_value());
+
+  // expect first time calling returns a 4 (4 is the initial value in the cache)
+  auto stack = makeStack(dummyTensor(TensorType1()));
+  auto kernel = c10::Dispatcher::singleton().lookup(*op, &stack);
+  kernel.call(&stack);
+  EXPECT_EQ(1, stack.size());
+  EXPECT_EQ(4, stack[0].toInt());
+
+  // expect second time calling returns a 5
+  stack = makeStack(dummyTensor(TensorType1()));
+  kernel.call(&stack);
+  EXPECT_EQ(1, stack.size());
+  EXPECT_EQ(5, stack[0].toInt());
+
+  // expect third time calling returns a 6
+  stack = makeStack(dummyTensor(TensorType1()));
+  kernel.call(&stack);
+  EXPECT_EQ(1, stack.size());
+  EXPECT_EQ(6, stack[0].toInt());
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp
new file mode 100644 (file)
index 0000000..5294dc1
--- /dev/null
@@ -0,0 +1,10 @@
+#include <ATen/core/op_registration/op_registration.h>
+
+namespace c10 {
+
+RegisterOperators::RegisterOperators() = default;
+RegisterOperators::~RegisterOperators() = default;
+RegisterOperators::RegisterOperators(RegisterOperators&&) noexcept = default;
+RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) = default;
+
+}
diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h
new file mode 100644 (file)
index 0000000..d6bd887
--- /dev/null
@@ -0,0 +1,77 @@
+#pragma once
+
+/**
+ * Include this file if you want to register operators. It includes all
+ * functionality needed to do so for you.
+ */
+
+#include <ATen/core/op_registration/base.h>
+#include <ATen/core/op_registration/dispatch_key.h>
+#include <ATen/core/op_registration/kernel_stackbased.h>
+
+namespace c10 {
+
+/**
+ * An instance of this class handles the registration for one or more operators.
+ * Make sure you keep the RegisterOperators instance around since it will
+ * deregister the operator it's responsible for in its destructor.
+ *
+ * Example:
+ *
+ * > namespace {
+ * >   class my_kernel_cpu final : public c10::OperatorKernel {
+ * >   public:
+ * >     Tensor operator()(Tensor a, Tensor b) {...}
+ * >   };
+ * > }
+ * >
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel<my_kernel_cpu>(),
+ * >         c10::dispatchKey(CPUTensorId()));
+ */
+class C10_API RegisterOperators final {
+public:
+  RegisterOperators();
+  ~RegisterOperators();
+
+  RegisterOperators(const RegisterOperators&) = delete;
+  RegisterOperators& operator=(const RegisterOperators&) = delete;
+  RegisterOperators(RegisterOperators&&) noexcept;
+  RegisterOperators& operator=(RegisterOperators&&);
+
+
+  /**
+   * Register an operator based on a function schema and a set of configuration
+   * parameters (i.e. kernel function, dispatch key, ...).
+   *
+   * Example:
+   *
+   * > namespace {
+   * >   class my_kernel_cpu final : public c10::OperatorKernel {
+   * >   public:
+   * >     Tensor operator()(Tensor a, Tensor b) {...}
+   * >   };
+   * > }
+   * >
+   * > static auto registry = c10::RegisterOperators()
+   * >     .op("my_op",
+   * >         c10::kernel<my_kernel_cpu>(),
+   * >         c10::dispatchKey(CPUTensorId()));
+   */
+  template<class... ConfigParameters>
+  guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
+  op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
+    detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
+    registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, config.cache_creator_func);
+    return std::move(*this);
+  }
+
+  // TODO error if dispatch key is not specified
+  // TODO Add functor, function and lambda based kernel APIs
+
+private:
+  std::vector<c10::detail::OperatorRegistrar> registrars_;
+};
+
+}
diff --git a/aten/src/ATen/core/op_registration/test_helpers.h b/aten/src/ATen/core/op_registration/test_helpers.h
new file mode 100644 (file)
index 0000000..ed9d35a
--- /dev/null
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <gtest/gtest.h>
+
+#include <ATen/core/Tensor.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/ivalue.h>
+#include <c10/core/CPUAllocator.h>
+
+template<class... Inputs>
+std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
+  return {std::forward<Inputs>(inputs)...};
+}
+
+at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) {
+  auto* allocator = c10::GetCPUAllocator();
+  int64_t nelements = 1;
+  auto dtype = caffe2::TypeMeta::Make<float>();
+  auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
+    dtype,
+    nelements,
+    allocator->allocate(nelements * dtype.itemsize()),
+    allocator,
+    /*resizable=*/true);
+  return at::detail::make_tensor<c10::TensorImpl>(storage_impl, dispatch_key, false);
+}
+
+template<class... Args>
+std::vector<c10::IValue> callOp(const c10::OperatorHandle& op, Args... args) {
+  auto stack = makeStack(std::forward<Args>(args)...);
+  auto kernel = c10::Dispatcher::singleton().lookup(op, &stack);
+  kernel.call(&stack);
+  return stack;
+}
+
+void expectDoesntFindKernel(const char* op_name, c10::TensorTypeId dispatch_key) {
+  auto op = c10::Dispatcher::singleton().findSchema(op_name, "");
+  EXPECT_ANY_THROW(
+    callOp(*op, dummyTensor(dispatch_key), 5);
+  );
+}
+
+void expectDoesntFindOperator(const char* op_name) {
+  auto op = c10::Dispatcher::singleton().findSchema(op_name, "");
+  EXPECT_FALSE(op.has_value());
+}
diff --git a/caffe2/core/c10_operator.h b/caffe2/core/c10_operator.h
index e23259b..240a16b 100644 (file)
@@ -1,9 +1,8 @@
 #pragma once
 
-#include <vector>
-#include <ATen/core/dispatch/OpSchemaRegistration.h>
-#include <ATen/core/dispatch/KernelRegistration.h>
 #include <ATen/core/function_schema.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <vector>
 
 namespace caffe2 {
 namespace detail {
@@ -80,12 +79,11 @@ inline void _call_caffe2_op_from_c10(
   //                might reuse one of the preallocated tensors but doesn't have to.
 }
 
-template <const c10::OperatorHandle& (*OpHandle)(), class Caffe2Operator>
+template <const c10::FunctionSchema& (*Schema)(), class Caffe2Operator>
 void call_caffe2_op_from_c10(
     c10::Stack* stack,
     c10::KernelCache* cache) { // TODO Pass in correct cache type
-  _call_caffe2_op_from_c10(
-      stack, OpHandle().schema(), &_call_caffe2_op<Caffe2Operator>);
+  _call_caffe2_op_from_c10(stack, Schema(), &_call_caffe2_op<Caffe2Operator>);
 }
 
 inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName, std::vector<c10::Argument> inputs, std::vector<c10::Argument> outputs) {
@@ -105,6 +103,9 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName
       std::move(outputs));
 }
 
+inline std::unique_ptr<c10::KernelCache> noCache() {
+  return nullptr;
+}
 }
 }
 
@@ -154,56 +155,64 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName
  *   input an input of type TensorList. There must be no other tensor inputs.
  */
 #ifndef C10_MOBILE
-#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \
-  namespace caffe2 {                              \
-  namespace _c10_ops {                            \
-  C10_DECLARE_OP_SCHEMA(OperatorName);            \
-  }                                               \
+#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName)                  \
+  namespace caffe2 {                                               \
+  namespace _c10_ops {                                             \
+  CAFFE2_API const ::c10::FunctionSchema& schema_##OperatorName(); \
+  }                                                                \
   }
 
 // TODO This macro should take a JIT schema string instead of a vector of inputs and outputs.
-#define C10_REGISTER_CAFFE2_OPERATOR_CPU(                                     \
-    OperatorName, Inputs, Outputs, OperatorClass)                             \
-  /* Register the op schema with the c10 dispatcher */                        \
-  namespace caffe2 {                                                          \
-  namespace _c10_ops {                                                        \
-  C10_DEFINE_OP_SCHEMA(                                                       \
-      OperatorName,                                                           \
-      caffe2::detail::make_function_schema_for_c10(                           \
-          #OperatorName,                                                      \
-          Inputs,                                                             \
-          Outputs));                                                          \
-  }                                                                           \
-  }                                                                           \
-  /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */  \
-  namespace c10 {                                                             \
-  C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache<Cache>()*/ \
-      .kernel<&caffe2::detail::call_caffe2_op_from_c10<                       \
-          ::caffe2::_c10_ops::OperatorName,                                   \
-          OperatorClass>>()                                                   \
-      .dispatchKey(CPUTensorId());                                            \
-  }
-
-#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass)        \
-  namespace c10 {                                                             \
-  C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache<Cache>()*/ \
-      .kernel<&caffe2::detail::call_caffe2_op_from_c10<                       \
-          ::caffe2::_c10_ops::OperatorName,                                   \
-          OperatorClass>>()                                                   \
-      .dispatchKey(CUDATensorId());                                           \
-  }
+#define C10_REGISTER_CAFFE2_OPERATOR_CPU(                                    \
+    OperatorName, Inputs, Outputs, OperatorClass)                            \
+  /* Register the op schema with the c10 dispatcher */                       \
+  namespace caffe2 {                                                         \
+  namespace _c10_ops {                                                       \
+  C10_EXPORT const ::c10::FunctionSchema& schema_##OperatorName() {          \
+    static ::c10::FunctionSchema schema =                                    \
+        ::caffe2::detail::make_function_schema_for_c10(                      \
+            #OperatorName, Inputs, Outputs);                                 \
+    return schema;                                                           \
+  }                                                                          \
+  }                                                                          \
+  }                                                                          \
+  /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
+  static auto registry_##OperatorName##_##__COUNTER__ =                      \
+      ::c10::RegisterOperators().op(                                         \
+          ::caffe2::_c10_ops::schema_##OperatorName(),                       \
+          ::c10::kernel(                                                     \
+              &::caffe2::detail::call_caffe2_op_from_c10<                    \
+                  ::caffe2::_c10_ops::schema_##OperatorName,                 \
+                  OperatorClass>,                                            \
+              &::caffe2::detail::noCache),                                   \
+          ::c10::dispatchKey(::c10::CPUTensorId()));
+
+#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass)       \
+  /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
+  static auto registry_##OperatorName##_##__COUNTER__ =                      \
+      ::c10::RegisterOperators().op(                                         \
+          ::caffe2::_c10_ops::schema_##OperatorName(),                       \
+          ::c10::kernel(                                                     \
+              &::caffe2::detail::call_caffe2_op_from_c10<                    \
+                  ::caffe2::_c10_ops::schema_##OperatorName,                 \
+                  OperatorClass>,                                            \
+              &::caffe2::detail::noCache),                                   \
+          ::c10::dispatchKey(::c10::CUDATensorId()));
 
 // You should never manually call the C10_REGISTER_CAFFE2_OPERATOR_HIP macro.
 // The C10_REGISTER_CAFFE2_OPERATOR_CUDA macro from above will be automatically
 // rewritten to C10_REGISTER_CAFFE2_OPERATOR_HIP by hipify.
-#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass)         \
-  namespace c10 {                                                             \
-  C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) /*.withCache<Cache>()*/ \
-      .kernel<&caffe2::detail::call_caffe2_op_from_c10<                       \
-          ::caffe2::_c10_ops::OperatorName,                                   \
-          OperatorClass>>()                                                   \
-      .dispatchKey(HIPTensorId());                                            \
-  }
+#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass)        \
+  /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
+  static auto registry_##OperatorName##_##__COUNTER__ =                      \
+      ::c10::RegisterOperators().op(                                         \
+          ::caffe2::_c10_ops::schema_##OperatorName(),                       \
+          ::c10::kernel(                                                     \
+              &::caffe2::detail::call_caffe2_op_from_c10<                    \
+                  ::caffe2::_c10_ops::schema_##OperatorName,                 \
+                  OperatorClass>,                                            \
+              &::caffe2::detail::noCache),                                   \
+          ::c10::dispatchKey(::c10::HIPTensorId()));
 
 #else
 // Don't use c10 dispatcher on mobile because of binary size