&& lhs.alias_info() == rhs.alias_info();
}
+// Full identity of an operator: the qualified name (e.g. "aten::add") plus an
+// overload name, which is the empty string when the operator has no overload.
+struct OperatorName final {
+  std::string name;
+  std::string overload_name;
+};
+
struct FunctionSchema {
FunctionSchema(
std::string name,
std::vector<Argument> returns,
bool is_vararg = false,
bool is_varret = false)
- : name_(std::move(name)),
- overload_name_(std::move(overload_name)),
+ : name_({std::move(name), std::move(overload_name)}),
arguments_(std::move(arguments)),
returns_(std::move(returns)),
is_vararg_(is_vararg),
is_varret) {}
private:
- const std::string name_;
- const std::string overload_name_;
+ OperatorName name_;
const std::vector<Argument> arguments_;
const std::vector<Argument> returns_;
// if true then this schema takes an arbitrary number of additional arguments
public:
const std::string& name() const {
- return name_;
+ return name_.name;
}
const std::string& overload_name() const {
- return overload_name_;
+ return name_.overload_name;
}
const std::vector<Argument>& arguments() const {
return arguments_;
}
bool is_mutable() const {
// see [custom operator aliasing]
- const auto kind = Symbol::fromQualString(name_);
+ const auto kind = Symbol::fromQualString(name_.name);
const auto is_custom_op = !kind.is_aten() && !kind.is_prim();
return is_custom_op ||
std::any_of(
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
/**
* This file tests the legacy function-based API for registering kernels.
EXPECT_EQ(4, outputs[0].toInt());
}
+// Dummy kernel; only its C++ signature matters — it drives schema inference.
+std::tuple<int64_t, Tensor> kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
+  return {};
+}
+
+// Registering with only an operator name (no argument list) should infer the
+// full schema from the kernel function's signature.
+TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", &kernelForSchemaInference);
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  // The inferred schema must be signature-equal to the hand-written one.
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
+
template<class Return, class... Args> struct kernel_func final {
static Return func(Args...) { return {}; }
};
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
using c10::RegisterOperators;
using c10::kernel;
EXPECT_EQ(4, outputs[0].toInt());
}
+// Dummy kernel; only its C++ signature matters — it drives schema inference.
+std::tuple<int64_t, Tensor> kernelForSchemaInference(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
+  return {};
+}
+
+// Registering with only an operator name (no argument list) should infer the
+// full schema from the kernel function's signature.
+TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", kernel<decltype(kernelForSchemaInference), &kernelForSchemaInference>());
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  // The inferred schema must be signature-equal to the hand-written one.
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
+
template<class Return, class... Args> struct kernel_func final {
static Return func(Args...) { return {}; }
};
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
using c10::RegisterOperators;
using c10::OperatorKernel;
EXPECT_EQ(4, outputs[0].toInt());
}
+// Dummy functor kernel; only its operator() signature matters — it drives
+// schema inference.
+struct KernelForSchemaInference final : OperatorKernel {
+  std::tuple<int64_t, Tensor> operator()(Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) {
+    return {};
+  }
+};
+
+// Registering with only an operator name (no argument list) should infer the
+// full schema from the functor's call-operator signature.
+TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", kernel<KernelForSchemaInference>());
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  // The inferred schema must be signature-equal to the hand-written one.
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
+
template<class Return, class... Args> struct KernelFunc final : OperatorKernel{
Return operator()(Args...) { return {}; }
};
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
/**
* This file tests the legacy lambda-based API for registering kernels:
EXPECT_EQ(4, outputs[0].toInt());
}
+// Registering a lambda with only an operator name should infer the full
+// schema from the lambda's signature.
+TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", [] (Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) -> std::tuple<int64_t, Tensor> {return {};});
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  // The inferred schema must be signature-equal to the hand-written one.
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
+
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
using c10::RegisterOperators;
using c10::kernel;
EXPECT_EQ(4, outputs[0].toInt());
}
+// Registering a lambda (via kernel()) with only an operator name should infer
+// the full schema from the lambda's signature.
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) {
+  auto registrar = RegisterOperators()
+      .op("_test::no_schema_specified", kernel([] (Tensor arg1, int64_t arg2, ArrayRef<Tensor> arg3) -> std::tuple<int64_t, Tensor> {return {};}));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_schema_specified", "");
+  ASSERT_TRUE(op.has_value());
+
+  // The inferred schema must be signature-equal to the hand-written one.
+  c10::assertSchemasHaveSameSignature(torch::jit::parseSchema("_test::no_schema_specified(Tensor arg1, int arg2, Tensor[] arg3) -> (int, Tensor)"), op->schema());
+}
+
+
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/Tensor.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
using c10::RegisterOperators;
using c10::kernel;
EXPECT_EQ(4, outputs[0].toInt());
}
+// Stack-based kernels expose no typed signature, so nothing can be inferred.
+void kernelForSchemaInference(Stack* stack, KernelCache* cache) {
+}
+
+// Registering a stack-based kernel with only an operator name must fail with
+// a clear error, since schema inference is impossible here.
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenFailsBecauseItCannotInferFromStackBasedKernel) {
+  expectThrows<c10::Error>([] {
+    RegisterOperators().op("_test::no_schema_specified", kernel(&kernelForSchemaInference, &noCache));
+  }, "Cannot infer schema from this kernel function. Please explicitly specify the operator schema.");
+}
+
+
struct Cache final : KernelCache {
int last_value = 4;
};
bool owns_registration_;
};
-void RegisterOperators::registerOp_(const std::string& schemaStr, detail::KernelRegistrationConfig&& config) {
-  registerOp_(torch::jit::parseSchema(schemaStr), std::move(config));
+// Accepts either a full schema string or just an operator name. With a full
+// schema, the explicit schema is checked against the one inferred from the
+// kernel; with only a name, the inferred schema is used directly (and
+// registration fails if the kernel type did not allow inference).
+void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNameStr, detail::KernelRegistrationConfig&& config) {
+  either<OperatorName, FunctionSchema> schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr);
+  if (schemaOrName.is_right()) {
+    // schema was explicitly specified. Check it matches the inferred one and register the op.
+    checkSchemaAndRegisterOp_(std::move(schemaOrName).right(), std::move(config));
+  } else {
+    // schema wasn't explicitly specified. Take the inferred schema for registering the op.
+    AT_ASSERTM(nullptr != config.inferred_function_schema.get(), "Cannot infer schema from this kernel function. Please explicitly specify the operator schema.");
+    OperatorName name = std::move(schemaOrName).left();
+    // Rebuild a FunctionSchema with the user-supplied name but the inferred
+    // arguments/returns/vararg flags.
+    FunctionSchema inferredSchema(
+      std::move(name.name),
+      std::move(name.overload_name),
+      config.inferred_function_schema->arguments(),
+      config.inferred_function_schema->returns(),
+      config.inferred_function_schema->is_vararg(),
+      config.inferred_function_schema->is_varret()
+    );
+    registerOp_(std::move(inferredSchema), std::move(config));
+  }
+}
+
+
+// Verifies an explicitly specified schema against the schema inferred from
+// the kernel (when inference was possible), then registers the op.
+void RegisterOperators::checkSchemaAndRegisterOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
+  if (config.inferred_function_schema.get() != nullptr) {
+    assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
+  }
+
+  registerOp_(std::move(schema), std::move(config));
}
void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
  // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else.
  AT_ASSERT((config.kernel_func != nullptr) == static_cast<bool>(config.cache_creator_func));
-  if (config.inferred_function_schema.get() != nullptr) {
-    assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
-  }
-
+  // Schema-vs-inference validation moved to checkSchemaAndRegisterOp_; by the
+  // time we get here the schema is final.
  registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
}
* > c10::dispatchKey(CPUTensorId()));
*/
template<class... ConfigParameters>
- RegisterOperators op(const std::string& schema, ConfigParameters&&... configParameters) && {
+ RegisterOperators op(const std::string& schemaOrName, ConfigParameters&&... configParameters) && {
static_assert(guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value,
"Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
" and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators.");
- op_(schema, std::forward<ConfigParameters>(configParameters)...);
+ op_(schemaOrName, std::forward<ConfigParameters>(configParameters)...);
return std::move(*this);
}
C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
"Please use RegisterOperators().op(...) instead.")
// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
- explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, const std::string&> schema, FuncType* func)
+ explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, const std::string&> schemaOrName, FuncType* func)
: RegisterOperators() {
- legacyAPIOp_(schema, func);
+ legacyAPIOp_(schemaOrName, func);
}
template<class FuncType>
C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
"Please use RegisterOperators().op(...) instead.")
// enable_if: only enable it if FuncType is actually a functor
- explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, const std::string&> schema, FuncType&& func)
+ explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, const std::string&> schemaOrName, FuncType&& func)
: RegisterOperators() {
- legacyAPIOp_(schema, std::forward<FuncType>(func));
+ legacyAPIOp_(schemaOrName, std::forward<FuncType>(func));
}
/**
"Please use the new c10::kernel() based API instead.")
// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, RegisterOperators>
- op(const std::string& schema, FuncType* func, OtherArgs...) && {
+ op(const std::string& schemaOrName, FuncType* func, OtherArgs...) && {
// We intentionally don't extend this deprecated API to support dispatch keys
// and the like to push people towards using the new API.
static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
- legacyAPIOp_(schema, func);
+ legacyAPIOp_(schemaOrName, func);
return std::move(*this);
}
"Please use the new c10::kernel() based API instead.")
// enable_if: only enable it if FuncType is actually a functor
guts::enable_if_t<guts::is_functor<FuncType>::value, RegisterOperators>
- op(const std::string& schema, FuncType&& func, OtherArgs...) && {
+ op(const std::string& schemaOrName, FuncType&& func, OtherArgs...) && {
// We intentionally don't extend this deprecated API to support dispatch keys
// and the like to push people towards using the new API.
static_assert(sizeof...(OtherArgs) == 0, "The deprecated lambda based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
static_assert(!std::is_base_of<OperatorKernel, FuncType>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead.");
- legacyAPIOp_(schema, std::forward<FuncType>(func));
+ legacyAPIOp_(schemaOrName, std::forward<FuncType>(func));
return std::move(*this);
}
private:
+  // Parsed-schema entry point: validate against inferred schema, then register.
  template<class... ConfigParameters>
  void op_(FunctionSchema&& schema, ConfigParameters&&... configParameters) {
-    registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+    checkSchemaAndRegisterOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
  }
+  // String entry point: the string may be a full schema or just an operator name.
  template<class... ConfigParameters>
-  void op_(const std::string& schema, ConfigParameters&&... configParameters) {
-    registerOp_(schema, detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+  void op_(const std::string& schemaOrName, ConfigParameters&&... configParameters) {
+    checkSchemaAndRegisterOp_(schemaOrName, detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
  }
+  // Adapts a legacy function/functor/lambda kernel to the kernel()-based API.
  template<class FuncType>
-  void legacyAPIOp_(const std::string& schema, FuncType&& func) {
-    op_(schema, kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
+  void legacyAPIOp_(const std::string& schemaOrName, FuncType&& func) {
+    op_(schemaOrName, kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
  }
+  // Validation + registration helpers (defined in the .cpp file).
+  void checkSchemaAndRegisterOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
+  void checkSchemaAndRegisterOp_(const std::string& schemaOrName, detail::KernelRegistrationConfig&& config);
+
  void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
-  void registerOp_(const std::string& schema, detail::KernelRegistrationConfig&& config);
class OperatorRegistrar;
#include <vector>
using c10::FunctionSchema;
+using c10::OperatorName;
using c10::Argument;
using c10::IValue;
using c10::ListType;
+using c10::either;
+using c10::make_right;
+using c10::make_left;
using at::TypeKind;
namespace torch {
SchemaParser(const std::string& str)
: L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
- FunctionSchema parseDeclaration() {
- std::string name = L.expect(TK_IDENT).text();
- if (L.nextIf(':')) {
- L.expect(':');
- name = name + "::" + L.expect(TK_IDENT).text();
- }
- std::string overload_name = "";
- if (L.nextIf('.')) {
- overload_name = L.expect(TK_IDENT).text();
+ either<OperatorName, FunctionSchema> parseDeclaration() {
+ OperatorName name = parseName();
+
+ // If there is no parentheses coming, then this is just the operator name
+ // without an argument list
+ if (L.cur().kind != '(') {
+ return make_left<OperatorName, FunctionSchema>(std::move(name));
}
+
std::vector<Argument> arguments;
std::vector<Argument> returns;
bool kwarg_only = false;
returns.push_back(
parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
}
- return FunctionSchema{
- std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), is_vararg, false};
+ return make_right<OperatorName, FunctionSchema>(
+ std::move(name.name), std::move(name.overload_name), std::move(arguments), std::move(returns), is_vararg, false);
}
- std::vector<FunctionSchema> parseDeclarations() {
- std::vector<FunctionSchema> results;
+  // Parses an operator name of the form `name`, `ns::name`, optionally
+  // followed by `.overload`; the overload name is empty when absent.
+  c10::OperatorName parseName() {
+    std::string name = L.expect(TK_IDENT).text();
+    if (L.nextIf(':')) {
+      L.expect(':');
+      name = name + "::" + L.expect(TK_IDENT).text();
+    }
+    std::string overload_name = "";
+    if (L.nextIf('.')) {
+      overload_name = L.expect(TK_IDENT).text();
+    }
+    // Move the locals into the result instead of copying both strings.
+    return {std::move(name), std::move(overload_name)};
+  }
+
+ std::vector<either<OperatorName, FunctionSchema>> parseDeclarations() {
+ std::vector<either<OperatorName, FunctionSchema>> results;
do {
results.push_back(parseDeclaration());
} while (L.nextIf(TK_NEWLINE));
} // namespace
} // namespace script
+// Parses a string holding either a bare operator name (left) or a full
+// function schema with an argument list (right). Only the first declaration
+// in the string is returned.
+C10_EXPORT either<OperatorName, FunctionSchema> parseSchemaOrName(const std::string& schemaOrName) {
+  return script::SchemaParser(schemaOrName).parseDeclarations().at(0);
+}
+
+
C10_EXPORT FunctionSchema parseSchema(const std::string& schema) {
-  return script::SchemaParser(schema).parseDeclarations().at(0);
+  // A full schema (with argument list) is required here; a bare operator
+  // name is rejected.
+  auto parsed = parseSchemaOrName(schema);
+  AT_CHECK(parsed.is_right(), "Tried to parse a function schema but only the operator name was given");
+  // Use the rvalue-qualified right() to move the FunctionSchema out instead
+  // of copying it (same pattern as the registration code).
+  return std::move(parsed).right();
}
} // namespace jit
#include <ATen/core/function_schema.h>
#include <ATen/core/Macros.h>
+#include <c10/util/either.h>
#include <string>
namespace torch {
namespace jit {
-CAFFE2_API ::c10::FunctionSchema parseSchema(const std::string& schema);
+// Parses a string containing either a full function schema or just an
+// operator name; the left alternative holds a bare name.
+CAFFE2_API c10::either<c10::OperatorName, c10::FunctionSchema> parseSchemaOrName(const std::string& schemaOrName);
+// Parses a full function schema; throws if only an operator name was given.
+CAFFE2_API c10::FunctionSchema parseSchema(const std::string& schema);
} // namespace jit
} // namespace torch