}
static std::string dispatch_key_to_string(TensorTypeId id) {
- return std::string(toString(tensorTypeIdToBackend(id))) + "[" + toString(id) + "]";
+ // TODO Find better way to stringify tensor type ids without relying on backend
+ std::string name = "";
+ try {
+ name = toString(tensorTypeIdToBackend(id));
+ } catch (const std::exception&) {
+ // This can fail if the tensor type id is not one of the preregistered backends.
+ // However, dispatch_key_to_string is used to generate error reports, that
+ // means an error already has happened when entering this function.
+ // We don't want inner errors during generation of a report for an
+ // outer error. Just report an empty name instead.
+ }
+ return name + "[" + toString(id) + "]";
}
LeftRight<ska::flat_hash_map<TensorTypeId, DispatchTableEntry>> map_;
};
void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
- // TODO Allow this for registering the schema without a kernel?
- AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel");
-
+ // TODO We need to support registrations without a dispatch key,
+ // at least for the deprecated APIs. Maybe also for new ones.
+ AT_CHECK(config.dispatch_key.has_value(),
+ "Tried to register an operator with function schema ", toString(schema),
+ ", but didn't specify a dispatch key. Please add a c10::dispatchKey(...) parameter to the registration call.");
+
+ // TODO Should we allow this and only register a schema without a kernel?
+ AT_CHECK(config.kernel_func != nullptr,
+ "Tried to register an operator with function schema ", toString(schema),
+ ", but didn't specify a kernel. Please add a c10::kernel<...>(...) parameter to the registration call.");
// if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else.
AT_ASSERT(static_cast<bool>(config.cache_creator_func));
--- /dev/null
+/**
+ * This file contains some general registration test cases.
+ * More detailed test cases containing different APIs for registering kernels
+ * are found in other files in this directory.
+ */
+
+#include <gtest/gtest.h>
+#include <ATen/core/op_registration/test_helpers.h>
+
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/core/Tensor.h>
+
+using c10::RegisterOperators;
+using c10::OperatorKernel;
+using c10::FunctionSchema;
+using c10::Argument;
+using c10::kernel;
+using c10::dispatchKey;
+using at::Tensor;
+
+namespace {
+
+C10_DECLARE_TENSOR_TYPE(TensorType1);
+C10_DEFINE_TENSOR_TYPE(TensorType1);
+
+struct DummyKernel final : OperatorKernel {
+ void operator()(Tensor) {}
+};
+
+FunctionSchema dummySchema(
+ "_test::dummy",
+ "",
+ (std::vector<Argument>{Argument("dummy")}),
+ (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) {
+ // make sure it crashes when kernel is absent
+ EXPECT_THROW(
+ c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())),
+ c10::Error
+ );
+
+ // but make sure it doesn't crash when kernel is present
+ c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+}
+
+TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutDispatchKey_thenFails) {
+ // make sure it crashes when dispatch key is absent
+ EXPECT_THROW(
+ c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>()),
+ c10::Error
+ );
+
+ // but make sure it doesn't crash when dispatch key is present
+ c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+}
+
+}