From: Sebastian Messmer Date: Thu, 18 Apr 2019 09:00:49 +0000 (-0700) Subject: Allow registering ops without specifying the full schema (#19286) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~169 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5ca22cce69c97999555d31624b7ba2fc1a4e4424;p=platform%2Fupstream%2Fpytorch.git Allow registering ops without specifying the full schema (#19286) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19286 The operator registration API now allows registering an operator by only giving the operator name and not the full operator schema, as long as the operator schema can be inferred from the kernel function. Reviewed By: dzhulgakov Differential Revision: D14931921 fbshipit-source-id: 3776ce43d4ce67bb5a3ea3d07c37de96eebe08ba --- diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index ea9d02d..41e8565 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -84,6 +84,11 @@ inline bool operator==(const Argument& lhs, const Argument& rhs) { && lhs.alias_info() == rhs.alias_info(); } +struct OperatorName final { + std::string name; + std::string overload_name; +}; + struct FunctionSchema { FunctionSchema( std::string name, @@ -92,8 +97,7 @@ struct FunctionSchema { std::vector returns, bool is_vararg = false, bool is_varret = false) - : name_(std::move(name)), - overload_name_(std::move(overload_name)), + : name_({std::move(name), std::move(overload_name)}), arguments_(std::move(arguments)), returns_(std::move(returns)), is_vararg_(is_vararg), @@ -116,8 +120,7 @@ struct FunctionSchema { is_varret) {} private: - const std::string name_; - const std::string overload_name_; + OperatorName name_; const std::vector arguments_; const std::vector returns_; // if true then this schema takes an arbitrary number of additional arguments @@ -130,10 +133,10 @@ private: public: const std::string& name() const { - return name_; + return name_.name; } const std::string& overload_name() const { - return overload_name_; + return name_.overload_name; } const std::vector& arguments() const { return arguments_; @@ -149,7 +152,7 @@ public: } bool is_mutable() const { // see [custom operator aliasing] - const auto kind = Symbol::fromQualString(name_); + const auto kind = Symbol::fromQualString(name_.name); const auto is_custom_op = !kind.is_aten() && !kind.is_prim(); return is_custom_op || std::any_of( diff --git a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp index b0856b9..afd67e5 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp @@ -3,6 +3,7 @@ #include #include +#include /** * This file tests the legacy function-based API for registering kernels. @@ -455,6 +456,20 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenFallbackKernelWith EXPECT_EQ(4, outputs[0].toInt()); } +std::tuple kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef arg3) { + return {}; +} + +TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { + auto registrar = RegisterOperators() + .op("_test::no_schema_specified", &kernelForSchemaInference); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", ""); + ASSERT_TRUE(op.has_value()); + + c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema()); +} + template struct kernel_func final { static Return func(Args...) { return {}; } }; diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp index 1311f7d..deef45c 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_test.cpp @@ -3,6 +3,7 @@ #include #include +#include using c10::RegisterOperators; using c10::kernel; @@ -457,6 +458,20 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenFallbackKernelWithoutTen EXPECT_EQ(4, outputs[0].toInt()); } +std::tuple kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef arg3) { + return {}; +} + +TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { + auto registrar = RegisterOperators() + .op("_test::no_schema_specified", kernel()); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", ""); + ASSERT_TRUE(op.has_value()); + + c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema()); +} + template struct kernel_func final { static Return func(Args...) { return {}; } }; diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp index cb5277f..ea2aa34 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp @@ -3,6 +3,7 @@ #include #include +#include using c10::RegisterOperators; using c10::OperatorKernel; @@ -600,6 +601,22 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenFallbackKernelWithoutTens EXPECT_EQ(4, outputs[0].toInt()); } +struct KernelForSchemaInference final : OperatorKernel { + std::tuple operator()(Tensor arg1, int64_t arg2, ArrayRef arg3) { + return {}; + } +}; + +TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { + auto registrar = RegisterOperators() + .op("_test::no_schema_specified", kernel()); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", ""); + ASSERT_TRUE(op.has_value()); + + c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema()); +} + template struct KernelFunc final : OperatorKernel{ Return operator()(Args...) { return {}; } }; diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp index 2efbc9f..b8c65b5 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp @@ -3,6 +3,7 @@ #include #include +#include /** * This file tests the legacy lambda-based API for registering kernels: @@ -406,6 +407,16 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenFallbackKernelWithou EXPECT_EQ(4, outputs[0].toInt()); } +TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { + auto registrar = RegisterOperators() + .op("_test::no_schema_specified", [] (Tensor arg1, int64_t arg2, ArrayRef arg3) -> std::tuple {return {};}); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", ""); + ASSERT_TRUE(op.has_value()); + + c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema()); +} + TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp index a175ad9..673855b 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp @@ -3,6 +3,7 @@ #include #include +#include using c10::RegisterOperators; using c10::kernel; @@ -419,6 +420,16 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenFallbackKernelWithoutTenso EXPECT_EQ(4, outputs[0].toInt()); } +TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { + auto registrar = RegisterOperators() + .op("_test::no_schema_specified", kernel([] (Tensor arg1, int64_t arg2, ArrayRef arg3) -> std::tuple {return {};})); + + auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", ""); + ASSERT_TRUE(op.has_value()); + + c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema()); +} + TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() diff --git a/aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp b/aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp index 2c9211a..0fc2060 100644 --- a/aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp @@ -3,6 +3,7 @@ #include #include +#include using c10::RegisterOperators; using c10::kernel; @@ -140,6 +141,15 @@ TEST(OperatorRegistrationTest_StackBasedKernel, givenFallbackKernelWithoutTensor EXPECT_EQ(4, outputs[0].toInt()); } +void kernelForSchemaInference(Stack* stack, KernelCache* cache) { +} + +TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenFailsBecauseItCannotInferFromStackBasedKernel) { + expectThrows([] { + RegisterOperators().op("_test::no_schema_specified", kernel(&kernelForSchemaInference, &noCache)); + }, "Cannot infer schema from this kernel function. Please explicitly specify the operator schema."); +} + struct Cache final : KernelCache { int last_value = 4; }; diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index 39dc342..637f515 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -56,8 +56,33 @@ private: bool owns_registration_; }; -void RegisterOperators::registerOp_(const std::string& schemaStr, detail::KernelRegistrationConfig&& config) { - registerOp_(torch::jit::parseSchema(schemaStr), std::move(config)); +void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNameStr, detail::KernelRegistrationConfig&& config) { + either schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr); + if (schemaOrName.is_right()) { + // schema was explicitly specified. Check it matches the inferred one and register the op. + checkSchemaAndRegisterOp_(std::move(schemaOrName).right(), std::move(config)); + } else { + // schema wasn't explicitly specified. Take the inferred schema for registering the op. + AT_ASSERTM(nullptr != config.inferred_function_schema.get(), "Cannot infer schema from this kernel function. Please explicitly specify the operator schema."); + OperatorName name = std::move(schemaOrName).left(); + FunctionSchema inferredSchema( + std::move(name.name), + std::move(name.overload_name), + config.inferred_function_schema->arguments(), + config.inferred_function_schema->returns(), + config.inferred_function_schema->is_vararg(), + config.inferred_function_schema->is_varret() + ); + registerOp_(std::move(inferredSchema), std::move(config)); + } +} + +void RegisterOperators::checkSchemaAndRegisterOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) { + if (config.inferred_function_schema.get() != nullptr) { + assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); + } + + registerOp_(std::move(schema), std::move(config)); } void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) { @@ -68,10 +93,6 @@ void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegis // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else. AT_ASSERT((config.kernel_func != nullptr) == static_cast(config.cache_creator_func)); - if (config.inferred_function_schema.get() != nullptr) { - assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); - } - registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func)); } diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 883e888..95c7a2f 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -63,12 +63,12 @@ public: * > c10::dispatchKey(CPUTensorId())); */ template - RegisterOperators op(const std::string& schema, ConfigParameters&&... configParameters) && { + RegisterOperators op(const std::string& schemaOrName, ConfigParameters&&... configParameters) && { static_assert(guts::conjunction>...>::value, "Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel " " and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators."); - op_(schema, std::forward(configParameters)...); + op_(schemaOrName, std::forward(configParameters)...); return std::move(*this); } @@ -89,18 +89,18 @@ public: C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \ "Please use RegisterOperators().op(...) instead.") // enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction. - explicit RegisterOperators(guts::enable_if_t::value && !std::is_same::value, const std::string&> schema, FuncType* func) + explicit RegisterOperators(guts::enable_if_t::value && !std::is_same::value, const std::string&> schemaOrName, FuncType* func) : RegisterOperators() { - legacyAPIOp_(schema, func); + legacyAPIOp_(schemaOrName, func); } template C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \ "Please use RegisterOperators().op(...) instead.") // enable_if: only enable it if FuncType is actually a functor - explicit RegisterOperators(guts::enable_if_t::value, const std::string&> schema, FuncType&& func) + explicit RegisterOperators(guts::enable_if_t::value, const std::string&> schemaOrName, FuncType&& func) : RegisterOperators() { - legacyAPIOp_(schema, std::forward(func)); + legacyAPIOp_(schemaOrName, std::forward(func)); } /** @@ -139,12 +139,12 @@ public: "Please use the new c10::kernel() based API instead.") // enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction. guts::enable_if_t::value && !std::is_same::value, RegisterOperators> - op(const std::string& schema, FuncType* func, OtherArgs...) && { + op(const std::string& schemaOrName, FuncType* func, OtherArgs...) && { // We intentionally don't extend this deprecated API to support dispatch keys // and the like to push people towards using the new API. static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead."); - legacyAPIOp_(schema, func); + legacyAPIOp_(schemaOrName, func); return std::move(*this); } @@ -178,14 +178,14 @@ public: "Please use the new c10::kernel() based API instead.") // enable_if: only enable it if FuncType is actually a functor guts::enable_if_t::value, RegisterOperators> - op(const std::string& schema, FuncType&& func, OtherArgs...) && { + op(const std::string& schemaOrName, FuncType&& func, OtherArgs...) && { // We intentionally don't extend this deprecated API to support dispatch keys // and the like to push people towards using the new API. static_assert(sizeof...(OtherArgs) == 0, "The deprecated lambda based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead."); static_assert(!std::is_base_of::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead."); - legacyAPIOp_(schema, std::forward(func)); + legacyAPIOp_(schemaOrName, std::forward(func)); return std::move(*this); } @@ -194,20 +194,22 @@ public: private: template void op_(FunctionSchema&& schema, ConfigParameters&&... configParameters) { - registerOp_(std::move(schema), detail::make_registration_config(std::forward(configParameters)...)); + checkSchemaAndRegisterOp_(std::move(schema), detail::make_registration_config(std::forward(configParameters)...)); } template - void op_(const std::string& schema, ConfigParameters&&... configParameters) { - registerOp_(schema, detail::make_registration_config(std::forward(configParameters)...)); + void op_(const std::string& schemaOrName, ConfigParameters&&... configParameters) { + checkSchemaAndRegisterOp_(schemaOrName, detail::make_registration_config(std::forward(configParameters)...)); } template - void legacyAPIOp_(const std::string& schema, FuncType&& func) { - op_(schema, kernel>>(std::forward(func))); + void legacyAPIOp_(const std::string& schemaOrName, FuncType&& func) { + op_(schemaOrName, kernel>>(std::forward(func))); } + void checkSchemaAndRegisterOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config); + void checkSchemaAndRegisterOp_(const std::string& schemaOrName, detail::KernelRegistrationConfig&& config); + void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config); - void registerOp_(const std::string& schema, detail::KernelRegistrationConfig&& config); class OperatorRegistrar; diff --git a/torch/csrc/jit/script/function_schema_parser.cpp b/torch/csrc/jit/script/function_schema_parser.cpp index 639a848..0191398 100644 --- a/torch/csrc/jit/script/function_schema_parser.cpp +++ b/torch/csrc/jit/script/function_schema_parser.cpp @@ -9,9 +9,13 @@ #include using c10::FunctionSchema; +using c10::OperatorName; using c10::Argument; using c10::IValue; using c10::ListType; +using c10::either; +using c10::make_right; +using c10::make_left; using at::TypeKind; namespace torch { @@ -23,16 +27,15 @@ struct SchemaParser { SchemaParser(const std::string& str) : L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {} - FunctionSchema parseDeclaration() { - std::string name = L.expect(TK_IDENT).text(); - if (L.nextIf(':')) { - L.expect(':'); - name = name + "::" + L.expect(TK_IDENT).text(); - } - std::string overload_name = ""; - if (L.nextIf('.')) { - overload_name = L.expect(TK_IDENT).text(); + either parseDeclaration() { + OperatorName name = parseName(); + + // If there is no parentheses coming, then this is just the operator name + // without an argument list + if (L.cur().kind != '(') { + return make_left(std::move(name)); } + std::vector arguments; std::vector returns; bool kwarg_only = false; @@ -62,12 +65,25 @@ struct SchemaParser { returns.push_back( parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false)); } - return FunctionSchema{ - std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), is_vararg, false}; + return make_right( + std::move(name.name), std::move(name.overload_name), std::move(arguments), std::move(returns), is_vararg, false); } - std::vector parseDeclarations() { - std::vector results; + c10::OperatorName parseName() { + std::string name = L.expect(TK_IDENT).text(); + if (L.nextIf(':')) { + L.expect(':'); + name = name + "::" + L.expect(TK_IDENT).text(); + } + std::string overload_name = ""; + if (L.nextIf('.')) { + overload_name = L.expect(TK_IDENT).text(); + } + return {name, overload_name}; + } + + std::vector> parseDeclarations() { + std::vector> results; do { results.push_back(parseDeclaration()); } while (L.nextIf(TK_NEWLINE)); @@ -256,8 +272,14 @@ struct SchemaParser { } // namespace } // namespace script +C10_EXPORT either parseSchemaOrName(const std::string& schemaOrName) { + return script::SchemaParser(schemaOrName).parseDeclarations().at(0); +} + C10_EXPORT FunctionSchema parseSchema(const std::string& schema) { - return script::SchemaParser(schema).parseDeclarations().at(0); + auto parsed = parseSchemaOrName(schema); + AT_CHECK(parsed.is_right(), "Tried to parse a function schema but only the operator name was given"); + return parsed.right(); } } // namespace jit diff --git a/torch/csrc/jit/script/function_schema_parser.h b/torch/csrc/jit/script/function_schema_parser.h index 71e9f8c..e1872d8 100644 --- a/torch/csrc/jit/script/function_schema_parser.h +++ b/torch/csrc/jit/script/function_schema_parser.h @@ -2,12 +2,14 @@ #include #include +#include #include namespace torch { namespace jit { -CAFFE2_API ::c10::FunctionSchema parseSchema(const std::string& schema); +CAFFE2_API c10::either parseSchemaOrName(const std::string& schemaOrName); +CAFFE2_API c10::FunctionSchema parseSchema(const std::string& schema); } // namespace jit } // namespace torch