Allow registering ops without specifying the full schema (#19286)
authorSebastian Messmer <messmer@fb.com>
Thu, 18 Apr 2019 09:00:49 +0000 (02:00 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 09:04:46 +0000 (02:04 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19286

The operator registration API now allows registering an operator by only giving the operator name and not the full operator schema,
as long as the operator schema can be inferred from the kernel function.

Reviewed By: dzhulgakov

Differential Revision: D14931921

fbshipit-source-id: 3776ce43d4ce67bb5a3ea3d07c37de96eebe08ba

aten/src/ATen/core/function_schema.h
aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp
aten/src/ATen/core/op_registration/kernel_function_test.cpp
aten/src/ATen/core/op_registration/kernel_functor_test.cpp
aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp
aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
aten/src/ATen/core/op_registration/kernel_stackbased_test.cpp
aten/src/ATen/core/op_registration/op_registration.cpp
aten/src/ATen/core/op_registration/op_registration.h
torch/csrc/jit/script/function_schema_parser.cpp
torch/csrc/jit/script/function_schema_parser.h

index ea9d02d..41e8565 100644 (file)
@@ -84,6 +84,11 @@ inline bool operator==(const Argument& lhs, const Argument& rhs) {
           && lhs.alias_info() == rhs.alias_info();
 }
 
+struct OperatorName final {
+  std::string name;
+  std::string overload_name;
+};
+
 struct FunctionSchema {
   FunctionSchema(
       std::string name,
@@ -92,8 +97,7 @@ struct FunctionSchema {
       std::vector<Argument> returns,
       bool is_vararg = false,
       bool is_varret = false)
-      : name_(std::move(name)),
-        overload_name_(std::move(overload_name)),
+      : name_({std::move(name), std::move(overload_name)}),
         arguments_(std::move(arguments)),
         returns_(std::move(returns)),
         is_vararg_(is_vararg),
@@ -116,8 +120,7 @@ struct FunctionSchema {
             is_varret) {}
 
 private:
-  const std::string name_;
-  const std::string overload_name_;
+  OperatorName name_;
   const std::vector<Argument> arguments_;
   const std::vector<Argument> returns_;
   // if true then this schema takes an arbitrary number of additional arguments
@@ -130,10 +133,10 @@ private:
 
 public:
   const std::string& name() const {
-    return name_;
+    return name_.name;
   }
   const std::string& overload_name() const {
-    return overload_name_;
+    return name_.overload_name;
   }
   const std::vector<Argument>& arguments() const {
     return arguments_;
@@ -149,7 +152,7 @@ public:
   }
   bool is_mutable() const {
     // see [custom operator aliasing]
-    const auto kind = Symbol::fromQualString(name_);
+    const auto kind = Symbol::fromQualString(name_.name);
     const auto is_custom_op = !kind.is_aten() && !kind.is_prim();
     return is_custom_op ||
         std::any_of(
index b0856b9..afd67e5 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
 
 /**
  * This file tests the legacy function-based API for registering kernels.
@@ -455,6 +456,20 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenFallbackKernelWith
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
+std::tuple<int64_t, Tensor> kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
+  return {};
+}
+
+TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", &kernelForSchemaInference);
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
 template<class Return, class... Args> struct kernel_func final {
   static Return func(Args...) { return {}; }
 };
index 1311f7d..deef45c 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
 
 using c10::RegisterOperators;
 using c10::kernel;
@@ -457,6 +458,20 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenFallbackKernelWithoutTen
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
+std::tuple<int64_t, Tensor> kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
+  return {};
+}
+
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", kernel<decltype(kernelForSchemaInference), &kernelForSchemaInference>());
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
 template<class Return, class... Args> struct kernel_func final {
   static Return func(Args...) { return {}; }
 };
index cb5277f..ea2aa34 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
 
 using c10::RegisterOperators;
 using c10::OperatorKernel;
@@ -600,6 +601,22 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenFallbackKernelWithoutTens
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
+struct KernelForSchemaInference final : OperatorKernel {
+  std::tuple<int64_t, Tensor> operator()(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
+    return {};
+  }
+};
+
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", kernel<KernelForSchemaInference>());
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
 template<class Return, class... Args> struct KernelFunc final : OperatorKernel{
   Return operator()(Args...) { return {}; }
 };
index 2efbc9f..b8c65b5 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
 
 /**
  * This file tests the legacy lambda-based API for registering kernels:
@@ -406,6 +407,16 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenFallbackKernelWithou
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
+TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", [] (Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) -> std::tuple<int64_t, Tensor> {return {};});
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
 TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
   // assert this does not fail because it matches
   RegisterOperators()
index a175ad9..673855b 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
 
 using c10::RegisterOperators;
 using c10::kernel;
@@ -419,6 +420,16 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenFallbackKernelWithoutTenso
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", kernel([] (Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) -> std::tuple<int64_t, Tensor> {return {};}));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
   // assert this does not fail because it matches
   RegisterOperators()
index 2c9211a..0fc2060 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
 
 using c10::RegisterOperators;
 using c10::kernel;
@@ -140,6 +141,15 @@ TEST(OperatorRegistrationTest_StackBasedKernel, givenFallbackKernelWithoutTensor
   EXPECT_EQ(4, outputs[0].toInt());
 }
 
+void kernelForSchemaInference(Stack* stack, KernelCache* cache) {
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenFailsBecauseItCannotInferFromStackBasedKernel) {
+  expectThrows<c10::Error>([] {
+      RegisterOperators().op("_test::no_schema_specified", kernel(&kernelForSchemaInference, &noCache));
+  }, "Cannot infer schema from this kernel function. Please explicitly specify the operator schema.");
+}
+
 struct Cache final : KernelCache {
   int last_value = 4;
 };
index 39dc342..637f515 100644 (file)
@@ -56,8 +56,33 @@ private:
   bool owns_registration_;
 };
 
-void RegisterOperators::registerOp_(const std::string& schemaStr, detail::KernelRegistrationConfig&& config) {
-  registerOp_(torch::jit::parseSchema(schemaStr), std::move(config));
+void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNameStr, detail::KernelRegistrationConfig&& config) {
+  either<OperatorName, FunctionSchema> schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr);
+  if (schemaOrName.is_right()) {
+    // schema was explicitly specified. Check it matches the inferred one and register the op.
+    checkSchemaAndRegisterOp_(std::move(schemaOrName).right(), std::move(config));
+  } else {
+    // schema wasn't explicitly specified. Take the inferred schema for registering the op.
+    AT_ASSERTM(nullptr != config.inferred_function_schema.get(), "Cannot infer schema from this kernel function. Please explicitly specify the operator schema.");
+    OperatorName name = std::move(schemaOrName).left();
+    FunctionSchema inferredSchema(
+      std::move(name.name),
+      std::move(name.overload_name),
+      config.inferred_function_schema->arguments(),
+      config.inferred_function_schema->returns(),
+      config.inferred_function_schema->is_vararg(),
+      config.inferred_function_schema->is_varret()
+    );
+    registerOp_(std::move(inferredSchema), std::move(config));
+  }
+}
+
+void RegisterOperators::checkSchemaAndRegisterOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
+  if (config.inferred_function_schema.get() != nullptr) {
+    assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
+  }
+
+  registerOp_(std::move(schema), std::move(config));
 }
 
 void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
@@ -68,10 +93,6 @@ void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegis
   // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else.
   AT_ASSERT((config.kernel_func != nullptr) == static_cast<bool>(config.cache_creator_func));
 
-  if (config.inferred_function_schema.get() != nullptr) {
-    assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
-  }
-
   registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
 }
 
index 883e888..95c7a2f 100644 (file)
@@ -63,12 +63,12 @@ public:
    * >         c10::dispatchKey(CPUTensorId()));
    */
   template<class... ConfigParameters>
-  RegisterOperators op(const std::string& schema, ConfigParameters&&... configParameters) && {
+  RegisterOperators op(const std::string& schemaOrName, ConfigParameters&&... configParameters) && {
     static_assert(guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value,
       "Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
       " and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators.");
 
-    op_(schema, std::forward<ConfigParameters>(configParameters)...);
+    op_(schemaOrName, std::forward<ConfigParameters>(configParameters)...);
     return std::move(*this);
   }
 
@@ -89,18 +89,18 @@ public:
   C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
                          "Please use RegisterOperators().op(...) instead.")
   // enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
-  explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, const std::string&> schema, FuncType* func)
+  explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, const std::string&> schemaOrName, FuncType* func)
   : RegisterOperators() {
-    legacyAPIOp_(schema, func);
+    legacyAPIOp_(schemaOrName, func);
   }
 
   template<class FuncType>
   C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
                          "Please use RegisterOperators().op(...) instead.")
   // enable_if: only enable it if FuncType is actually a functor
-  explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, const std::string&> schema, FuncType&& func)
+  explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, const std::string&> schemaOrName, FuncType&& func)
   : RegisterOperators() {
-    legacyAPIOp_(schema, std::forward<FuncType>(func));
+    legacyAPIOp_(schemaOrName, std::forward<FuncType>(func));
   }
 
   /**
@@ -139,12 +139,12 @@ public:
                           "Please use the new c10::kernel() based API instead.")
    // enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
    guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, RegisterOperators>
-   op(const std::string& schema, FuncType* func, OtherArgs...) && {
+   op(const std::string& schemaOrName, FuncType* func, OtherArgs...) && {
      // We intentionally don't extend this deprecated API to support dispatch keys
      // and the like to push people towards using the new API.
      static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
 
-     legacyAPIOp_(schema, func);
+     legacyAPIOp_(schemaOrName, func);
      return std::move(*this);
    }
 
@@ -178,14 +178,14 @@ public:
                            "Please use the new c10::kernel() based API instead.")
     // enable_if: only enable it if FuncType is actually a functor
     guts::enable_if_t<guts::is_functor<FuncType>::value, RegisterOperators>
-    op(const std::string& schema, FuncType&& func, OtherArgs...) && {
+    op(const std::string& schemaOrName, FuncType&& func, OtherArgs...) && {
       // We intentionally don't extend this deprecated API to support dispatch keys
       // and the like to push people towards using the new API.
       static_assert(sizeof...(OtherArgs) == 0, "The deprecated lambda based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
 
       static_assert(!std::is_base_of<OperatorKernel, FuncType>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead.");
 
-      legacyAPIOp_(schema, std::forward<FuncType>(func));
+      legacyAPIOp_(schemaOrName, std::forward<FuncType>(func));
       return std::move(*this);
     }
 
@@ -194,20 +194,22 @@ public:
 private:
   template<class... ConfigParameters>
   void op_(FunctionSchema&& schema, ConfigParameters&&... configParameters) {
-    registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+    checkSchemaAndRegisterOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
   }
   template<class... ConfigParameters>
-  void op_(const std::string& schema, ConfigParameters&&... configParameters) {
-    registerOp_(schema, detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+  void op_(const std::string& schemaOrName, ConfigParameters&&... configParameters) {
+    checkSchemaAndRegisterOp_(schemaOrName, detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
   }
 
   template<class FuncType>
-  void legacyAPIOp_(const std::string& schema, FuncType&& func) {
-    op_(schema, kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
+  void legacyAPIOp_(const std::string& schemaOrName, FuncType&& func) {
+    op_(schemaOrName, kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
   }
 
+  void checkSchemaAndRegisterOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
+  void checkSchemaAndRegisterOp_(const std::string& schemaOrName, detail::KernelRegistrationConfig&& config);
+
   void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
-  void registerOp_(const std::string& schema, detail::KernelRegistrationConfig&& config);
 
   class OperatorRegistrar;
 
index 639a848..0191398 100644 (file)
@@ -9,9 +9,13 @@
 #include <vector>
 
 using c10::FunctionSchema;
+using c10::OperatorName;
 using c10::Argument;
 using c10::IValue;
 using c10::ListType;
+using c10::either;
+using c10::make_right;
+using c10::make_left;
 using at::TypeKind;
 
 namespace torch {
@@ -23,16 +27,15 @@ struct SchemaParser {
   SchemaParser(const std::string& str)
       : L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
 
-  FunctionSchema parseDeclaration() {
-    std::string name = L.expect(TK_IDENT).text();
-    if (L.nextIf(':')) {
-      L.expect(':');
-      name = name + "::" + L.expect(TK_IDENT).text();
-    }
-    std::string overload_name = "";
-    if (L.nextIf('.')) {
-      overload_name = L.expect(TK_IDENT).text();
+  either<OperatorName, FunctionSchema> parseDeclaration() {
+    OperatorName name = parseName();
+
+    // If there is no parentheses coming, then this is just the operator name
+    // without an argument list
+    if (L.cur().kind != '(') {
+      return make_left<OperatorName, FunctionSchema>(std::move(name));
     }
+
     std::vector<Argument> arguments;
     std::vector<Argument> returns;
     bool kwarg_only = false;
@@ -62,12 +65,25 @@ struct SchemaParser {
       returns.push_back(
           parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
     }
-    return FunctionSchema{
-        std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), is_vararg, false};
+    return make_right<OperatorName, FunctionSchema>(
+        std::move(name.name), std::move(name.overload_name), std::move(arguments), std::move(returns), is_vararg, false);
   }
 
-  std::vector<FunctionSchema> parseDeclarations() {
-    std::vector<FunctionSchema> results;
+  c10::OperatorName parseName() {
+    std::string name = L.expect(TK_IDENT).text();
+    if (L.nextIf(':')) {
+      L.expect(':');
+      name = name + "::" + L.expect(TK_IDENT).text();
+    }
+    std::string overload_name = "";
+    if (L.nextIf('.')) {
+      overload_name = L.expect(TK_IDENT).text();
+    }
+    return {name, overload_name};
+  }
+
+  std::vector<either<OperatorName, FunctionSchema>> parseDeclarations() {
+    std::vector<either<OperatorName, FunctionSchema>> results;
     do {
       results.push_back(parseDeclaration());
     } while (L.nextIf(TK_NEWLINE));
@@ -256,8 +272,14 @@ struct SchemaParser {
 } // namespace
 } // namespace script
 
+C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) {
+  return script::SchemaParser(schemaOrName).parseDeclarations().at(0);
+}
+
 C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
-  return script::SchemaParser(schema).parseDeclarations().at(0);
+  auto parsed = parseSchemaOrName(schema);
+  AT_CHECK(parsed.is_right(), "Tried to parse a function schema but only the operator name was given");
+  return parsed.right();
 }
 
 } // namespace jit
index 71e9f8c..e1872d8 100644 (file)
@@ -2,12 +2,14 @@
 
 #include <ATen/core/function_schema.h>
 #include <ATen/core/Macros.h>
+#include <c10/util/either.h>
 #include <string>
 
 namespace torch {
 namespace jit {
 
-CAFFE2_API ::c10::FunctionSchema parseSchema(const std::string& schema);
+CAFFE2_API c10::either<c10::OperatorName, c10::FunctionSchema> parseSchemaOrName(const std::string& schemaOrName);
+CAFFE2_API c10::FunctionSchema parseSchema(const std::string& schema);
 
 } // namespace jit
 } // namespace torch