From 24752eb7b827607ad0ff65e89d6c25b5bf48896a Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Sat, 30 Mar 2019 00:03:46 -0700 Subject: [PATCH] Report better errors when kernel or dispatch key are missing (#18302) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18302 These might be use cases we want to support in the future, but they don't work yet. Let's at least report an error instead of doing segfaults or worse. Reviewed By: dzhulgakov Differential Revision: D14572346 fbshipit-source-id: 49262ce131493bc887defe2978d8b22f202cd8cc --- aten/src/ATen/core/dispatch/DispatchTable.h | 13 ++++- .../ATen/core/op_registration/op_registration.cpp | 13 +++-- .../core/op_registration/op_registration_test.cpp | 58 ++++++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 aten/src/ATen/core/op_registration/op_registration_test.cpp diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 15865cc..19cedf9 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -100,7 +100,18 @@ class ThreadsafeOperatorTable_ final { } static std::string dispatch_key_to_string(TensorTypeId id) { - return std::string(toString(tensorTypeIdToBackend(id))) + "[" + toString(id) + "]"; + // TODO Find better way to stringify tensor type ids without relying on backend + std::string name = ""; + try { + name = toString(tensorTypeIdToBackend(id)); + } catch (const std::exception&) { + // This can fail if the tensor type id is not one of the preregistered backends. + // However, dispatch_key_to_string is used to generate error reports, that + // means an error already has happened when entering this function. + // We don't want inner errors during generation of a report for an + // outer error. Just report an empty name instead. + } + return name + "[" + toString(id) + "]"; } LeftRight> map_; diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index cb43c55..d0236b7 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -40,9 +40,16 @@ private: }; void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) { - // TODO Allow this for registering the schema without a kernel? - AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel"); - + // TODO We need to support registrations without a dispatch key, + // at least for the deprecated APIs. Maybe also for new ones. + AT_CHECK(config.dispatch_key.has_value(), + "Tried to register an operator with function schema ", toString(schema), + ", but didn't specify a dispatch key. Please add a c10::dispatchKey(...) parameter to the registration call."); + + // TODO Should we allow this and only register a schema without a kernel? + AT_CHECK(config.kernel_func != nullptr, + "Tried to register an operator with function schema ", toString(schema), + ", but didn't specify a kernel. Please add a c10::kernel<...>(...) parameter to the registration call."); // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else. AT_ASSERT(static_cast(config.cache_creator_func)); diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp new file mode 100644 index 0000000..aeab038 --- /dev/null +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -0,0 +1,58 @@ +/** + * This file contains some general registration test cases. + * More detailed test cases containing different APIs for registering kernels + * are found in other files in this directory. + */ + +#include +#include + +#include +#include + +using c10::RegisterOperators; +using c10::OperatorKernel; +using c10::FunctionSchema; +using c10::Argument; +using c10::kernel; +using c10::dispatchKey; +using at::Tensor; + +namespace { + +C10_DECLARE_TENSOR_TYPE(TensorType1); +C10_DEFINE_TENSOR_TYPE(TensorType1); + +struct DummyKernel final : OperatorKernel { + void operator()(Tensor) {} +}; + +FunctionSchema dummySchema( + "_test::dummy", + "", + (std::vector{Argument("dummy")}), + (std::vector{})); + +TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutKernel_thenFails) { + // make sure it crashes when kernel is absent + EXPECT_THROW( + c10::RegisterOperators().op(dummySchema, dispatchKey(TensorType1())), + c10::Error + ); + + // but make sure it doesn't crash when kernel is present + c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); +} + +TEST(OperatorRegistrationTest, whenTryingToRegisterWithoutDispatchKey_thenFails) { + // make sure it crashes when dispatch key is absent + EXPECT_THROW( + c10::RegisterOperators().op(dummySchema, kernel()), + c10::Error + ); + + // but make sure it doesn't crash when dispatch key is present + c10::RegisterOperators().op(dummySchema, kernel(), dispatchKey(TensorType1())); +} + +} -- 2.7.4