#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
using c10::RegisterOperators;
-using c10::FunctionSchema;
-using c10::Argument;
-using c10::IntType;
-using c10::FloatType;
-using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
return 0;
}
-FunctionSchema errorOpSchema(
- "_test::error",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
int64_t incrementKernel(const Tensor& tensor, int64_t input) {
return input + 1;
}
return input - 1;
}
-FunctionSchema opSchema(
- "_test::my_op",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opSchema, &incrementKernel);
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel);
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegisteredInConstructor_thenCanBeCalled) {
- auto registrar = RegisterOperators(opSchema, &incrementKernel);
+ auto registrar = RegisterOperators("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel);
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
- .op(opSchema, &incrementKernel)
- .op(errorOpSchema, &errorKernel);
+ .op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel)
+ .op("_test::error(Tensor dummy, int input) -> int", &errorKernel);
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
- auto registrar1 = RegisterOperators().op(opSchema, &incrementKernel);
- auto registrar2 = RegisterOperators().op(errorOpSchema, &errorKernel);
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel);
+ auto registrar2 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", &errorKernel);
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
- auto registrar = RegisterOperators().op(opSchema, &incrementKernel);
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel);
expectCallsIncrement(TensorType1());
}
was_called = true;
}
-FunctionSchema opWithoutOutputSchema(
- "_test::no_return",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithoutOutputSchema, &kernelWithoutOutput);
+ auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", &kernelWithoutOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
ASSERT_TRUE(op.has_value());
return std::make_tuple();
}
-FunctionSchema opWithZeroOutputsSchema(
- "_test::zero_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, &kernelWithZeroOutputs);
+ auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", &kernelWithZeroOutputs);
auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
ASSERT_TRUE(op.has_value());
return a + b;
}
-FunctionSchema opWithIntOutputSchema(
- "_test::int_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("a", IntType::get()),
- Argument("b", IntType::get())}),
- (std::vector<Argument>{Argument("sum", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntOutputSchema, &kernelWithIntOutput);
+ .op("_test::int_output(Tensor dummy, int a, int b) -> int", &kernelWithIntOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
ASSERT_TRUE(op.has_value());
return input;
}
-FunctionSchema opWithTensorOutput(
- "_test::returning_tensor",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorOutput, &kernelWithTensorOutput);
+ .op("_test::returning_tensor(Tensor input) -> Tensor", &kernelWithTensorOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
ASSERT_TRUE(op.has_value());
return {input1, input2, input3};
}
-FunctionSchema opWithTensorListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("input1"),
- Argument("input2"),
- Argument("input3")}),
- (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListOutputSchema, &kernelWithTensorListOutput);
+ .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", &kernelWithTensorListOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
return {input1, input2, input3};
}
-FunctionSchema opWithIntListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input1", IntType::get()),
- Argument("input2", IntType::get()),
- Argument("input3", IntType::get())}),
- (std::vector<Argument>{Argument("output", ListType::ofInts())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListOutputSchema, &kernelWithIntListOutput);
+ .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", &kernelWithIntListOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
);
}
-FunctionSchema opWithMultipleOutputsSchema(
- "_test::multiple_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{Argument("output1"),
- Argument("output2", IntType::get()),
- Argument("output3", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithMultipleOutputsSchema, &kernelWithMultipleOutputs);
+ .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[])", &kernelWithMultipleOutputs);
auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
ASSERT_TRUE(op.has_value());
Tensor kernelWithTensorInputByValueWithOutput(Tensor input1) {
return input1;
}
-
-FunctionSchema opWithTensorInputWithOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, &kernelWithTensorInputByReferenceWithOutput);
+ .op("_test::tensor_input(Tensor input) -> Tensor", &kernelWithTensorInputByReferenceWithOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, &kernelWithTensorInputByValueWithOutput);
+ .op("_test::tensor_input(Tensor input) -> Tensor", &kernelWithTensorInputByValueWithOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
captured_input = input1;
}
-FunctionSchema opWithTensorInputWithoutOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, &kernelWithTensorInputByReferenceWithoutOutput);
+ .op("_test::tensor_input(Tensor input) -> ()", &kernelWithTensorInputByReferenceWithoutOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, &kernelWithTensorInputByValueWithoutOutput);
+ .op("_test::tensor_input(Tensor input) -> ()", &kernelWithTensorInputByValueWithoutOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
captured_int_input = input1;
}
-FunctionSchema opWithIntInputWithoutOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithoutOutput, &kernelWithIntInputWithoutOutput);
+ .op("_test::int_input(Tensor dummy, int input) -> ()", &kernelWithIntInputWithoutOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
return input1 + 1;
}
-FunctionSchema opWithIntInputWithOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithOutput, &kernelWithIntInputWithOutput);
+ .op("_test::int_input(Tensor dummy, int input) -> int", &kernelWithIntInputWithOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
captured_input_list_size = input1.size();
}
-FunctionSchema opWithIntListInputWithoutOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithoutOutput, &kernelWithIntListInputWithoutOutput);
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", &kernelWithIntListInputWithoutOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
return input1.size();
}
-FunctionSchema opWithIntListInputWithOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithOutput, &kernelWithIntListInputWithOutput);
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> int", &kernelWithIntListInputWithOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
captured_input_list_size = input1.size();
}
-FunctionSchema opWithTensorListInputWithoutOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithoutOutput, &kernelWithTensorListInputWithoutOutput);
+ .op("_test::tensor_list_input(Tensor[] input) -> ()", &kernelWithTensorListInputWithoutOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
return input1.size();
}
-FunctionSchema opWithTensorListInputWithOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithOutput, &kernelWithTensorListInputWithOutput);
+ .op("_test::tensor_list_input(Tensor[] input) -> int", &kernelWithTensorListInputWithOutput);
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> int", &kernel_func<int64_t, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", &kernel_func<int64_t, Tensor>::func);
}, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func);
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", &kernel_func<void, Tensor, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{}),
- (std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func);
+ .op("_test::mismatch() -> ()", &kernel_func<void, Tensor, Tensor>::func);
}, "The number of arguments is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> ()", &kernel_func<void, Tensor, Tensor>::func);
}, "The number of arguments is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
- (std::vector<Argument>{})
- ), &kernel_func<void, Tensor, Tensor>::func);
+ .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", &kernel_func<void, Tensor, Tensor>::func);
}, "The number of arguments is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor, int64_t>::func);
+ .op("_test::mismatch(Tensor arg1, int arg2) -> int", &kernel_func<int64_t, Tensor, int64_t>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor, int64_t>::func);
+ .op("_test::mismatch(Tensor arg1, float arg2) -> int", &kernel_func<int64_t, Tensor, int64_t>::func);
}, "Type mismatch in argument 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor, int64_t>::func);
+ .op("_test::mismatch(int arg1, int arg2) -> int", &kernel_func<int64_t, Tensor, int64_t>::func);
}, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> int", &kernel_func<int64_t, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> ()", &kernel_func<int64_t, Tensor>::func);
}, "The number of returns is different. Specified 0 but inferred 1"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()),
- Argument("ret2", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (int, int)", &kernel_func<int64_t, Tensor>::func);
}, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), &kernel_func<void, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> ()", &kernel_func<void, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), &kernel_func<void, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> Tensor", &kernel_func<void, Tensor>::func);
}, "The number of returns is different. Specified 1 but inferred 0"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), &kernel_func<void, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", &kernel_func<void, Tensor>::func);
}, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> ()", &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
}, "The number of returns is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1")})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> Tensor", &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
}, "The number of returns is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func);
}, "The number of returns is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> int", &kernel_func<int64_t, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> Tensor", &kernel_func<int64_t, Tensor>::func);
}, "Type mismatch in return 1: specified Tensor but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), &kernel_func<int64_t, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> float", &kernel_func<int64_t, Tensor>::func);
}, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), &kernel_func<Tensor, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> Tensor", &kernel_func<Tensor, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), &kernel_func<Tensor, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> float", &kernel_func<Tensor, Tensor>::func);
}, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
- ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (Tensor, int)", &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (Tensor, float)", &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
}, "Type mismatch in return 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
+ .op("_test::mismatch(Tensor arg) -> (int, int)", &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func);
}, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
#include <ATen/core/Tensor.h>
using c10::RegisterOperators;
-using c10::FunctionSchema;
-using c10::Argument;
-using c10::IntType;
-using c10::FloatType;
-using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
return 0;
}
-FunctionSchema errorOpSchema(
- "_test::error",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
int64_t incrementKernel(const Tensor& tensor, int64_t input) {
return input + 1;
}
return input - 1;
}
-FunctionSchema opSchema(
- "_test::my_op",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
- .op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()))
- .op(opSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()))
- .op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType1()))
- .op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()))
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType1()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
- auto registrar1 = RegisterOperators().op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
- auto registrar2 = RegisterOperators().op(opSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
- auto registrar3 = RegisterOperators().op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType1()));
- auto registrar4 = RegisterOperators().op(errorOpSchema, kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
+ auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType1()));
+ auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel<decltype(errorKernel), &errorKernel>(), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
- auto registrar1 = RegisterOperators().op(opSchema, kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(incrementKernel), &incrementKernel>(), dispatchKey(TensorType1()));
{
- auto registrar2 = RegisterOperators().op(opSchema, kernel<decltype(decrementKernel), &decrementKernel>(), dispatchKey(TensorType2()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<decltype(decrementKernel), &decrementKernel>(), dispatchKey(TensorType2()));
// assert that schema and cpu kernel are present
expectCallsIncrement(TensorType1());
was_called = true;
}
-FunctionSchema opWithoutOutputSchema(
- "_test::no_return",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithoutOutputSchema, kernel<decltype(kernelWithoutOutput), &kernelWithoutOutput>(), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", kernel<decltype(kernelWithoutOutput), &kernelWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
ASSERT_TRUE(op.has_value());
return std::make_tuple();
}
-FunctionSchema opWithZeroOutputsSchema(
- "_test::zero_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, kernel<decltype(kernelWithZeroOutputs), &kernelWithZeroOutputs>(), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", kernel<decltype(kernelWithZeroOutputs), &kernelWithZeroOutputs>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
ASSERT_TRUE(op.has_value());
return a + b;
}
-FunctionSchema opWithIntOutputSchema(
- "_test::int_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("a", IntType::get()),
- Argument("b", IntType::get())}),
- (std::vector<Argument>{Argument("sum", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntOutputSchema, kernel<decltype(kernelWithIntOutput), &kernelWithIntOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_output(Tensor dummy, int a, int b) -> int", kernel<decltype(kernelWithIntOutput), &kernelWithIntOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
ASSERT_TRUE(op.has_value());
return input;
}
-FunctionSchema opWithTensorOutput(
- "_test::returning_tensor",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorOutput, kernel<decltype(kernelWithTensorOutput), &kernelWithTensorOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorOutput, kernel<decltype(kernelWithTensorOutput), &kernelWithTensorOutput>(), dispatchKey(TensorType2()));
+ .op("_test::returning_tensor(Tensor input) -> Tensor", kernel<decltype(kernelWithTensorOutput), &kernelWithTensorOutput>(), dispatchKey(TensorType1()))
+ .op("_test::returning_tensor(Tensor input) -> Tensor", kernel<decltype(kernelWithTensorOutput), &kernelWithTensorOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
ASSERT_TRUE(op.has_value());
return {input1, input2, input3};
}
-FunctionSchema opWithTensorListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("input1"),
- Argument("input2"),
- Argument("input3")}),
- (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListOutputSchema, kernel<decltype(kernelWithTensorListOutput), &kernelWithTensorListOutput>(), dispatchKey(TensorType1()));
+ .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", kernel<decltype(kernelWithTensorListOutput), &kernelWithTensorListOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
return {input1, input2, input3};
}
-FunctionSchema opWithIntListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input1", IntType::get()),
- Argument("input2", IntType::get()),
- Argument("input3", IntType::get())}),
- (std::vector<Argument>{Argument("output", ListType::ofInts())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListOutputSchema, kernel<decltype(kernelWithIntListOutput), &kernelWithIntListOutput>(), dispatchKey(TensorType1()));
+ .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", kernel<decltype(kernelWithIntListOutput), &kernelWithIntListOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
);
}
-FunctionSchema opWithMultipleOutputsSchema(
- "_test::multiple_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{Argument("output1"),
- Argument("output2", IntType::get()),
- Argument("output3", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithMultipleOutputsSchema, kernel<decltype(kernelWithMultipleOutputs), &kernelWithMultipleOutputs>(), dispatchKey(TensorType1()));
+ .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[])", kernel<decltype(kernelWithMultipleOutputs), &kernelWithMultipleOutputs>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
ASSERT_TRUE(op.has_value());
return input1;
}
-FunctionSchema opWithTensorInputWithOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByReferenceWithOutput), &kernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByReferenceWithOutput), &kernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<decltype(kernelWithTensorInputByReferenceWithOutput), &kernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<decltype(kernelWithTensorInputByReferenceWithOutput), &kernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByValueWithOutput), &kernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithOutput, kernel<decltype(kernelWithTensorInputByValueWithOutput), &kernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<decltype(kernelWithTensorInputByValueWithOutput), &kernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<decltype(kernelWithTensorInputByValueWithOutput), &kernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
captured_input = input1;
}
-FunctionSchema opWithTensorInputWithoutOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByReferenceWithoutOutput), &kernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByReferenceWithoutOutput), &kernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<decltype(kernelWithTensorInputByReferenceWithoutOutput), &kernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<decltype(kernelWithTensorInputByReferenceWithoutOutput), &kernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByValueWithoutOutput), &kernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithoutOutput, kernel<decltype(kernelWithTensorInputByValueWithoutOutput), &kernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<decltype(kernelWithTensorInputByValueWithoutOutput), &kernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<decltype(kernelWithTensorInputByValueWithoutOutput), &kernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
captured_int_input = input1;
}
-FunctionSchema opWithIntInputWithoutOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithoutOutput, kernel<decltype(kernelWithIntInputWithoutOutput), &kernelWithIntInputWithoutOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_input(Tensor dummy, int input) -> ()", kernel<decltype(kernelWithIntInputWithoutOutput), &kernelWithIntInputWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
return input1 + 1;
}
-FunctionSchema opWithIntInputWithOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithOutput, kernel<decltype(kernelWithIntInputWithOutput), &kernelWithIntInputWithOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_input(Tensor dummy, int input) -> int", kernel<decltype(kernelWithIntInputWithOutput), &kernelWithIntInputWithOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
captured_input_list_size = input1.size();
}
-FunctionSchema opWithIntListInputWithoutOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithoutOutput, kernel<decltype(kernelWithIntListInputWithoutOutput), &kernelWithIntListInputWithoutOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", kernel<decltype(kernelWithIntListInputWithoutOutput), &kernelWithIntListInputWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
return input1.size();
}
-FunctionSchema opWithIntListInputWithOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithOutput, kernel<decltype(kernelWithIntListInputWithOutput), &kernelWithIntListInputWithOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> int", kernel<decltype(kernelWithIntListInputWithOutput), &kernelWithIntListInputWithOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
captured_input_list_size = input1.size();
}
-FunctionSchema opWithTensorListInputWithoutOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithoutOutput, kernel<decltype(kernelWithTensorListInputWithoutOutput), &kernelWithTensorListInputWithoutOutput>(), dispatchKey(TensorType1()));
+ .op("_test::tensor_list_input(Tensor[] input) -> ()", kernel<decltype(kernelWithTensorListInputWithoutOutput), &kernelWithTensorListInputWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
return input1.size();
}
-FunctionSchema opWithTensorListInputWithOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithOutput, kernel<decltype(kernelWithTensorListInputWithOutput), &kernelWithTensorListInputWithOutput>(), dispatchKey(TensorType1()));
+ .op("_test::tensor_list_input(Tensor[] input) -> int", kernel<decltype(kernelWithTensorListInputWithOutput), &kernelWithTensorListInputWithOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch() -> ()", kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg1, int arg2) -> int", kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg1, float arg2) -> int", kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in argument 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(int arg1, int arg2) -> int", kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 0 but inferred 1"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()),
- Argument("ret2", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (int, int)", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 1 but inferred 0"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1")})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified Tensor but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> float", kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> float", kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, int)", kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, float)", kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (int, int)", kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
#include <ATen/core/Tensor.h>
using c10::RegisterOperators;
-using c10::FunctionSchema;
using c10::OperatorKernel;
-using c10::Argument;
-using c10::IntType;
-using c10::FloatType;
-using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
}
};
-FunctionSchema errorOpSchema(
- "_test::error",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
struct IncrementKernel final : OperatorKernel {
int64_t operator()(const Tensor& tensor, int64_t input) {
return input + 1;
}
};
-FunctionSchema opSchema(
- "_test::my_op",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<IncrementKernel>(), dispatchKey(TensorType1()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
- .op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()))
- .op(opSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()))
- .op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType1()))
- .op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()));
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel<IncrementKernel>(), dispatchKey(TensorType1()))
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel<ErrorKernel>(), dispatchKey(TensorType2()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel<ErrorKernel>(), dispatchKey(TensorType1()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel<ErrorKernel>(), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
- auto registrar1 = RegisterOperators().op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()));
- auto registrar2 = RegisterOperators().op(opSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()));
- auto registrar3 = RegisterOperators().op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType1()));
- auto registrar4 = RegisterOperators().op(errorOpSchema, kernel<ErrorKernel>(), dispatchKey(TensorType2()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<IncrementKernel>(), dispatchKey(TensorType1()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<ErrorKernel>(), dispatchKey(TensorType2()));
+ auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel<ErrorKernel>(), dispatchKey(TensorType1()));
+ auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel<ErrorKernel>(), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
- auto registrar1 = RegisterOperators().op(opSchema, kernel<IncrementKernel>(), dispatchKey(TensorType1()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<IncrementKernel>(), dispatchKey(TensorType1()));
{
- auto registrar2 = RegisterOperators().op(opSchema, kernel<DecrementKernel>(), dispatchKey(TensorType2()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel<DecrementKernel>(), dispatchKey(TensorType2()));
// assert that schema and cpu kernel are present
expectCallsIncrement(TensorType1());
}
};
-FunctionSchema opWithoutOutputSchema(
- "_test::no_return",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithoutOutputSchema, kernel<KernelWithoutOutput>(), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", kernel<KernelWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithZeroOutputsSchema(
- "_test::zero_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, kernel<KernelWithZeroOutputs>(), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", kernel<KernelWithZeroOutputs>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
ASSERT_TRUE(op.has_value());
}
};
-
-FunctionSchema opWithIntOutputSchema(
- "_test::int_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("a", IntType::get()),
- Argument("b", IntType::get())}),
- (std::vector<Argument>{Argument("sum", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntOutputSchema, kernel<KernelWithIntOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_output(Tensor dummy, int a, int b) -> int", kernel<KernelWithIntOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithTensorOutput(
- "_test::returning_tensor",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorOutput, kernel<KernelWithTensorOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorOutput, kernel<KernelWithTensorOutput>(), dispatchKey(TensorType2()));
+ .op("_test::returning_tensor(Tensor input) -> Tensor", kernel<KernelWithTensorOutput>(), dispatchKey(TensorType1()))
+ .op("_test::returning_tensor(Tensor input) -> Tensor", kernel<KernelWithTensorOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithTensorListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("input1"),
- Argument("input2"),
- Argument("input3")}),
- (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListOutputSchema, kernel<KernelWithTensorListOutput>(), dispatchKey(TensorType1()));
+ .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", kernel<KernelWithTensorListOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithIntListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input1", IntType::get()),
- Argument("input2", IntType::get()),
- Argument("input3", IntType::get())}),
- (std::vector<Argument>{Argument("output", ListType::ofInts())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListOutputSchema, kernel<KernelWithIntListOutput>(), dispatchKey(TensorType1()));
+ .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", kernel<KernelWithIntListOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithMultipleOutputsSchema(
- "_test::multiple_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{Argument("output1"),
- Argument("output2", IntType::get()),
- Argument("output3", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithMultipleOutputsSchema, kernel<KernelWithMultipleOutputs>(), dispatchKey(TensorType1()));
+ .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[])", kernel<KernelWithMultipleOutputs>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithTensorInputWithOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<KernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<KernelWithTensorInputByReferenceWithOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithOutput, kernel<KernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<KernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> Tensor", kernel<KernelWithTensorInputByValueWithOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithTensorInputWithoutOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<KernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<KernelWithTensorInputByReferenceWithoutOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType1()))
- .op(opWithTensorInputWithoutOutput, kernel<KernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType2()));
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<KernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType1()))
+ .op("_test::tensor_input(Tensor input) -> ()", kernel<KernelWithTensorInputByValueWithoutOutput>(), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithIntInputWithoutOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithoutOutput, kernel<KernelWithIntInputWithoutOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_input(Tensor dummy, int input) -> ()", kernel<KernelWithIntInputWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithIntInputWithOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithOutput, kernel<KernelWithIntInputWithOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_input(Tensor dummy, int input) -> int", kernel<KernelWithIntInputWithOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithIntListInputWithoutOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithoutOutput, kernel<KernelWithIntListInputWithoutOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", kernel<KernelWithIntListInputWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithIntListInputWithOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithOutput, kernel<KernelWithIntListInputWithOutput>(), dispatchKey(TensorType1()));
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> int", kernel<KernelWithIntListInputWithOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithTensorListInputWithoutOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithoutOutput, kernel<KernelWithTensorListInputWithoutOutput>(), dispatchKey(TensorType1()));
+ .op("_test::tensor_list_input(Tensor[] input) -> ()", kernel<KernelWithTensorListInputWithoutOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
}
};
-FunctionSchema opWithTensorListInputWithOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithOutput, kernel<KernelWithTensorListInputWithOutput>(), dispatchKey(TensorType1()));
+ .op("_test::tensor_list_input(Tensor[] input) -> int", kernel<KernelWithTensorListInputWithOutput>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
ASSERT_TRUE(op.has_value());
int64_t counter;
};
-FunctionSchema opWithCacheSchema(
- "_test::cache_op",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithCache_thenCacheIsKeptCorrectly) {
auto registrar = RegisterOperators()
- .op(opWithCacheSchema, kernel<KernelWithCache>(), dispatchKey(TensorType1()));
+ .op("_test::cache_op(Tensor input) -> int", kernel<KernelWithCache>(), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::cache_op", "");
ASSERT_TRUE(op.has_value());
int64_t offset_;
};
-FunctionSchema opWithConstructorArgsSchema(
- "_test::offset_op",
- "",
- (std::vector<Argument>{Argument("tensor"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithConstructorArg_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithConstructorArgsSchema, kernel<KernelWithConstructorArg>(2), dispatchKey(TensorType1()))
- .op(opWithConstructorArgsSchema, kernel<KernelWithConstructorArg>(4), dispatchKey(TensorType2()));
+ .op("_test::offset_op(Tensor tensor, int input) -> int", kernel<KernelWithConstructorArg>(2), dispatchKey(TensorType1()))
+ .op("_test::offset_op(Tensor tensor, int input) -> int", kernel<KernelWithConstructorArg>(4), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::offset_op", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstructorArgs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithConstructorArgsSchema, kernel<KernelWithMultipleConstructorArgs>(2, 3), dispatchKey(TensorType1()))
- .op(opWithConstructorArgsSchema, kernel<KernelWithMultipleConstructorArgs>(4, 5), dispatchKey(TensorType2()));
+ .op("_test::offset_op(Tensor tensor, int input) -> int", kernel<KernelWithMultipleConstructorArgs>(2, 3), dispatchKey(TensorType1()))
+ .op("_test::offset_op(Tensor tensor, int input) -> int", kernel<KernelWithMultipleConstructorArgs>(4, 5), dispatchKey(TensorType2()));
auto op = c10::Dispatcher::singleton().findSchema("_test::offset_op", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch() -> ()", kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg1, int arg2) -> int", kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg1, float arg2) -> int", kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
}, "Type mismatch in argument 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(int arg1, int arg2) -> int", kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
}, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 0 but inferred 1"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()),
- Argument("ret2", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (int, int)", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 1 but inferred 0"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1")})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified Tensor but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> float", kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> float", kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
- ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, int)", kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, float)", kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (int, int)", kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
using c10::RegisterOperators;
-using c10::FunctionSchema;
-using c10::Argument;
-using c10::IntType;
-using c10::FloatType;
-using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);
-FunctionSchema errorOpSchema(
- "_test::error",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
-FunctionSchema opSchema(
- "_test::my_op",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
}
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
return input + 1;
});
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredInConstructor_thenCanBeCalled) {
- auto registrar = RegisterOperators(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ auto registrar = RegisterOperators("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
return input + 1;
});
expectCallsIncrement(TensorType1());
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
- .op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ .op("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
return input + 1;
})
- .op(errorOpSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ .op("_test::error(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
EXPECT_TRUE(false); // this kernel should never be called
return 0;
});
}
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
- auto registrar1 = RegisterOperators().op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
return input + 1;
});
- auto registrar2 = RegisterOperators().op(errorOpSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ auto registrar2 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
EXPECT_TRUE(false); // this kernel should never be called
return 0;
});
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
- auto registrar = RegisterOperators().op(opSchema, [] (const Tensor& tensor, int64_t input) -> int64_t {
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t {
return input + 1;
});
bool was_called = false;
-FunctionSchema opWithoutOutputSchema(
- "_test::no_return",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithoutOutputSchema, [] (const Tensor&) -> void {
+ auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", [] (const Tensor&) -> void {
was_called = true;
});
EXPECT_EQ(0, result.size());
}
-FunctionSchema opWithZeroOutputsSchema(
- "_test::zero_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithZeroOutputsSchema, [] (const Tensor&) -> std::tuple<> {
+ auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", [] (const Tensor&) -> std::tuple<> {
was_called = true;
return std::make_tuple();
});
EXPECT_EQ(0, result.size());
}
-FunctionSchema opWithIntOutputSchema(
- "_test::int_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("a", IntType::get()),
- Argument("b", IntType::get())}),
- (std::vector<Argument>{Argument("sum", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntOutputSchema, [] (Tensor, int64_t a, int64_t b) -> int64_t {
+ .op("_test::int_output(Tensor dummy, int a, int b) -> int", [] (Tensor, int64_t a, int64_t b) -> int64_t {
return a + b;
});
EXPECT_EQ(9, result[0].toInt());
}
-FunctionSchema opWithTensorOutput(
- "_test::returning_tensor",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorOutput, [] (const Tensor& input) -> Tensor {
+ .op("_test::returning_tensor(Tensor input) -> Tensor", [] (const Tensor& input) -> Tensor {
return input;
});
EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}
-FunctionSchema opWithTensorListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("input1"),
- Argument("input2"),
- Argument("input3")}),
- (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListOutputSchema, [] (const Tensor& input1, const Tensor& input2, const Tensor& input3) -> std::vector<Tensor> {
+ .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", [] (const Tensor& input1, const Tensor& input2, const Tensor& input3) -> std::vector<Tensor> {
return {input1, input2, input3};
});
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
}
-FunctionSchema opWithIntListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input1", IntType::get()),
- Argument("input2", IntType::get()),
- Argument("input3", IntType::get())}),
- (std::vector<Argument>{Argument("output", ListType::ofInts())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListOutputSchema, [](const Tensor&, int64_t input1, int64_t input2, int64_t input3) -> std::vector<int64_t> {
+ .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", [](const Tensor&, int64_t input1, int64_t input2, int64_t input3) -> std::vector<int64_t> {
return {input1, input2, input3};
});
EXPECT_EQ(6, result[0].toIntListRef()[2]);
}
-FunctionSchema opWithMultipleOutputsSchema(
- "_test::multiple_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{Argument("output1"),
- Argument("output2", IntType::get()),
- Argument("output3", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithMultipleOutputsSchema, [] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>> {
+ .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[])", [] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>> {
return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
);
EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
}
-FunctionSchema opWithTensorInputWithOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, [] (const Tensor& input1) -> Tensor {
+ .op("_test::tensor_input(Tensor input) -> Tensor", [] (const Tensor& input1) -> Tensor {
return input1;
});
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput, [](Tensor input1) -> Tensor {
+ .op("_test::tensor_input(Tensor input) -> Tensor", [](Tensor input1) -> Tensor {
return input1;
});
Tensor captured_input;
-FunctionSchema opWithTensorInputWithoutOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, [] (const Tensor& input1) -> void {
+ .op("_test::tensor_input(Tensor input) -> ()", [] (const Tensor& input1) -> void {
captured_input = input1;
});
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput, [] (Tensor input1) -> void {
+ .op("_test::tensor_input(Tensor input) -> ()", [] (Tensor input1) -> void {
captured_input = input1;
});
int64_t captured_int_input = 0;
-FunctionSchema opWithIntInputWithoutOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithoutOutput, [](Tensor, int64_t input1) -> void {
+ .op("_test::int_input(Tensor dummy, int input) -> ()", [](Tensor, int64_t input1) -> void {
captured_int_input = input1;
});
EXPECT_EQ(3, captured_int_input);
}
-FunctionSchema opWithIntInputWithOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithOutput, [] (Tensor, int64_t input1) -> int64_t {
+ .op("_test::int_input(Tensor dummy, int input) -> int", [] (Tensor, int64_t input1) -> int64_t {
return input1 + 1;
});
int64_t captured_input_list_size = 0;
-FunctionSchema opWithIntListInputWithoutOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithoutOutput, [] (Tensor, ArrayRef<int64_t> input1) -> void {
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", [] (Tensor, ArrayRef<int64_t> input1) -> void {
captured_input_list_size = input1.size();
});
EXPECT_EQ(3, captured_input_list_size);
}
-FunctionSchema opWithIntListInputWithOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithOutput, [](Tensor, ArrayRef<int64_t> input1) -> int64_t {
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> int", [](Tensor, ArrayRef<int64_t> input1) -> int64_t {
return input1.size();
});
EXPECT_EQ(3, outputs[0].toInt());
}
-FunctionSchema opWithTensorListInputWithoutOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithoutOutput, [] (ArrayRef<Tensor> input1) -> void {
+ .op("_test::tensor_list_input(Tensor[] input) -> ()", [] (ArrayRef<Tensor> input1) -> void {
captured_input_list_size = input1.size();
});
EXPECT_EQ(2, captured_input_list_size);
}
-FunctionSchema opWithTensorListInputWithOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithOutput, [] (ArrayRef<Tensor> input1) -> int64_t {
+ .op("_test::tensor_list_input(Tensor[] input) -> int", [] (ArrayRef<Tensor> input1) -> int64_t {
return input1.size();
});
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> int", [] (Tensor) -> int64_t {return 0;});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", [] (Tensor) -> int64_t {return 0;});
}, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {});
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", [] (Tensor, Tensor) -> void {});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{}),
- (std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {});
+ .op("_test::mismatch() -> ()", [] (Tensor, Tensor) -> void {});
}, "The number of arguments is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {});
+ .op("_test::mismatch(Tensor arg) -> ()", [] (Tensor, Tensor) -> void {});
}, "The number of arguments is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
- (std::vector<Argument>{})
- ), [] (Tensor, Tensor) -> void {});
+ .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", [] (Tensor, Tensor) -> void {});
}, "The number of arguments is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor, int64_t) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg1, int arg2) -> int", [] (Tensor, int64_t) -> int64_t {return 0;});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor, int64_t) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg1, float arg2) -> int", [] (Tensor, int64_t) -> int64_t {return 0;});
}, "Type mismatch in argument 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor, int64_t) -> int64_t {return 0;});
+ .op("_test::mismatch(int arg1, int arg2) -> int", [] (Tensor, int64_t) -> int64_t {return 0;});
}, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> int", [] (Tensor) -> int64_t {return 0;});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> ()", [] (Tensor) -> int64_t {return 0;});
}, "The number of returns is different. Specified 0 but inferred 1"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()),
- Argument("ret2", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> (int, int)", [] (Tensor) -> int64_t {return 0;});
}, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), [] (Tensor) -> void {});
+ .op("_test::mismatch(Tensor arg) -> ()", [] (Tensor) -> void {});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), [] (Tensor) -> void {});
+ .op("_test::mismatch(Tensor arg) -> Tensor", [] (Tensor) -> void {});
}, "The number of returns is different. Specified 1 but inferred 0"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), [] (Tensor) -> void {});
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", [] (Tensor) -> void {});
}, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ .op("_test::mismatch(Tensor arg) -> ()", [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
}, "The number of returns is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1")})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ .op("_test::mismatch(Tensor arg) -> Tensor", [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
}, "The number of returns is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", [] (Tensor) -> std::tuple<Tensor, Tensor> {return {};});
}, "The number of returns is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> int", [] (Tensor) -> int64_t {return 0;});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> Tensor", [] (Tensor) -> int64_t {return 0;});
}, "Type mismatch in return 1: specified Tensor but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), [] (Tensor) -> int64_t {return 0;});
+ .op("_test::mismatch(Tensor arg) -> float", [] (Tensor) -> int64_t {return 0;});
}, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), [] (Tensor) -> Tensor {return {};});
+ .op("_test::mismatch(Tensor arg) -> Tensor", [] (Tensor) -> Tensor {return {};});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), [] (Tensor) -> Tensor {return {};});
+ .op("_test::mismatch(Tensor arg) -> float", [] (Tensor) -> Tensor {return {};});
}, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
- ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+ .op("_test::mismatch(Tensor arg) -> (Tensor, int)", [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+ .op("_test::mismatch(Tensor arg) -> (Tensor, float)", [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
}, "Type mismatch in return 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
+ .op("_test::mismatch(Tensor arg) -> (int, int)", [] (Tensor) -> std::tuple<Tensor, int64_t> {return {};});
}, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
#include <ATen/core/Tensor.h>
using c10::RegisterOperators;
-using c10::FunctionSchema;
-using c10::Argument;
-using c10::IntType;
-using c10::FloatType;
-using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);
-FunctionSchema errorOpSchema(
- "_test::error",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
-FunctionSchema opSchema(
- "_test::my_op",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenOutOfLineKernel_whenRegistered_thenCanBeCalled) {
auto my_kernel = [] (Tensor, int64_t i) {return i+1;};
- auto registrar = RegisterOperators().op(opSchema, kernel(my_kernel), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel(my_kernel), dispatchKey(TensorType1()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
- .op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
- .op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()))
- .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()))
- .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
- auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
- auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
- auto registrar3 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()));
- auto registrar4 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
+ auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()));
+ auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
- auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
{
- auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i-1;}), dispatchKey(TensorType2()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel([] (Tensor, int64_t i) {return i-1;}), dispatchKey(TensorType2()));
// assert that schema and cpu kernel are present
expectCallsIncrement(TensorType1());
bool was_called = false;
-FunctionSchema opWithoutOutputSchema(
- "_test::no_return",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithoutOutputSchema,
+ auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()",
kernel([] (const Tensor&) -> void {was_called = true;}),
dispatchKey(TensorType1()));
EXPECT_EQ(0, result.size());
}
-FunctionSchema opWithZeroOutputsSchema(
- "_test::zero_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opWithZeroOutputsSchema,
+ auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()",
kernel([] (const Tensor&) -> std::tuple<> {was_called = true; return {};}),
dispatchKey(TensorType1()));
EXPECT_EQ(0, result.size());
}
-FunctionSchema opWithIntOutputSchema(
- "_test::int_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("a", IntType::get()),
- Argument("b", IntType::get())}),
- (std::vector<Argument>{Argument("sum", IntType::get())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntOutputSchema,
+ .op("_test::int_output(Tensor dummy, int a, int b) -> int",
kernel([] (Tensor, int64_t a, int64_t b) {return a+b;}),
dispatchKey(TensorType1()));
EXPECT_EQ(9, result[0].toInt());
}
-FunctionSchema opWithTensorOutput(
- "_test::returning_tensor",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorOutput,
+ .op("_test::returning_tensor(Tensor input) -> Tensor",
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType1()))
- .op(opWithTensorOutput,
+ .op("_test::returning_tensor(Tensor input) -> Tensor",
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType2()));
EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
}
-FunctionSchema opWithTensorListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("input1"),
- Argument("input2"),
- Argument("input3")}),
- (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListOutputSchema,
+ .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]",
kernel([] (const Tensor& a, const Tensor& b, const Tensor& c) -> std::vector<Tensor> {return {a, b, c};}),
dispatchKey(TensorType1()));
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
}
-FunctionSchema opWithIntListOutputSchema(
- "_test::list_output",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input1", IntType::get()),
- Argument("input2", IntType::get()),
- Argument("input3", IntType::get())}),
- (std::vector<Argument>{Argument("output", ListType::ofInts())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListOutputSchema,
+ .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]",
kernel([] (const Tensor&, int64_t a, int64_t b, int64_t c) -> std::vector<int64_t> {return {a,b,c};}),
dispatchKey(TensorType1()));
EXPECT_EQ(6, result[0].toIntListRef()[2]);
}
-FunctionSchema opWithMultipleOutputsSchema(
- "_test::multiple_outputs",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{Argument("output1"),
- Argument("output2", IntType::get()),
- Argument("output3", ListType::ofTensors())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithMultipleOutputsSchema,
+ .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[])",
kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>> {
return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
}
-FunctionSchema opWithTensorInputWithOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{Argument("output")}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput,
+ .op("_test::tensor_input(Tensor input) -> Tensor",
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType1()))
- .op(opWithTensorInputWithOutput,
+ .op("_test::tensor_input(Tensor input) -> Tensor",
kernel([] (const Tensor& a) {return a;}),
dispatchKey(TensorType2()));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithOutput,
+ .op("_test::tensor_input(Tensor input) -> Tensor",
kernel([] (Tensor a) {return a;}),
dispatchKey(TensorType1()))
- .op(opWithTensorInputWithOutput,
+ .op("_test::tensor_input(Tensor input) -> Tensor",
kernel([] (Tensor a) {return a;}),
dispatchKey(TensorType2()));
Tensor captured_input;
-FunctionSchema opWithTensorInputWithoutOutput(
- "_test::tensor_input",
- "",
- (std::vector<Argument>{Argument("input")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput,
+ .op("_test::tensor_input(Tensor input) -> ()",
kernel([] (const Tensor& a) -> void {captured_input = a;}),
dispatchKey(TensorType1()))
- .op(opWithTensorInputWithoutOutput,
+ .op("_test::tensor_input(Tensor input) -> ()",
kernel([] (const Tensor& a) -> void {captured_input = a;}),
dispatchKey(TensorType2()));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorInputWithoutOutput,
+ .op("_test::tensor_input(Tensor input) -> ()",
kernel([] (Tensor a) -> void {captured_input = a;}),
dispatchKey(TensorType1()))
- .op(opWithTensorInputWithoutOutput,
+ .op("_test::tensor_input(Tensor input) -> ()",
kernel([] (Tensor a) -> void {captured_input = a;}),
dispatchKey(TensorType2()));
int64_t captured_int_input = 0;
-FunctionSchema opWithIntInputWithoutOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithoutOutput,
+ .op("_test::int_input(Tensor dummy, int input) -> ()",
kernel([] (Tensor, int64_t a) -> void {captured_int_input = a;}),
dispatchKey(TensorType1()));
EXPECT_EQ(3, captured_int_input);
}
-FunctionSchema opWithIntInputWithOutput(
- "_test::int_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntInputWithOutput,
+ .op("_test::int_input(Tensor dummy, int input) -> int",
kernel([] (Tensor, int64_t a) {return a + 1;}),
dispatchKey(TensorType1()));
int64_t captured_input_list_size = 0;
-FunctionSchema opWithIntListInputWithoutOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithoutOutput,
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> ()",
kernel([] (Tensor, ArrayRef<int64_t> a) {captured_input_list_size = a.size();}),
dispatchKey(TensorType1()));
EXPECT_EQ(3, captured_input_list_size);
}
-FunctionSchema opWithIntListInputWithOutput(
- "_test::int_list_input",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", ListType::ofInts())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithIntListInputWithOutput,
+ .op("_test::int_list_input(Tensor dummy, int[] input) -> int",
kernel([] (Tensor, ArrayRef<int64_t> a) -> int64_t {return a.size();}),
dispatchKey(TensorType1()));
EXPECT_EQ(3, outputs[0].toInt());
}
-FunctionSchema opWithTensorListInputWithoutOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithoutOutput,
+ .op("_test::tensor_list_input(Tensor[] input) -> ()",
kernel([] (ArrayRef<Tensor> a) -> void {captured_input_list_size = a.size();}),
dispatchKey(TensorType1()));
EXPECT_EQ(2, captured_input_list_size);
}
-FunctionSchema opWithTensorListInputWithOutput(
- "_test::tensor_list_input",
- "",
- (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
auto registrar = RegisterOperators()
- .op(opWithTensorListInputWithOutput,
+ .op("_test::tensor_list_input(Tensor[] input) -> int",
kernel([] (ArrayRef<Tensor> a) -> int64_t {return a.size();}),
dispatchKey(TensorType1()));
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch() -> ()", kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
}, "The number of arguments is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg1, int arg2) -> int", kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg1, float arg2) -> int", kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in argument 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(int arg1, int arg2) -> int", kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in argument 1: specified int but inferred Tensor"
);
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 0 but inferred 1"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()),
- Argument("ret2", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (int, int)", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 2 but inferred 1"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 1 but inferred 0"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret"), Argument("ret2")})
- ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 2 but inferred 0"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> ()", kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 0 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1")})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 1 but inferred 2"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
- ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
}, "The number of returns is different. Specified 3 but inferred 2"
);
}
TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", IntType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> int", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified Tensor but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> float", kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified float but inferred int"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret")})
- ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> Tensor", kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret", FloatType::get())})
- ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> float", kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified float but inferred Tensor"
);
// assert this does not fail because it matches
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
- ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, int)", kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
- ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (Tensor, float)", kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in return 2: specified float but inferred int"
);
expectThrows<c10::Error>([] {
RegisterOperators()
- .op(FunctionSchema(
- "_test::mismatch",
- "",
- (std::vector<Argument>{Argument("arg")}),
- (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
- ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+ .op("_test::mismatch(Tensor arg) -> (int, int)", kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
}, "Type mismatch in return 1: specified int but inferred Tensor"
);
}
#include <ATen/core/Tensor.h>
using c10::RegisterOperators;
-using c10::FunctionSchema;
-using c10::Argument;
-using c10::IntType;
using c10::kernel;
using c10::dispatchKey;
using c10::TensorTypeId;
EXPECT_TRUE(false); // this kernel should never be called
}
-FunctionSchema errorOpSchema(
- "_test::error",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
void incrementKernel(Stack* stack, KernelCache* cache) {
int input = torch::jit::pop(*stack).toInt();
torch::jit::pop(*stack); // pop the dummy tensor
torch::jit::push(*stack, input - 1);
}
-FunctionSchema opSchema(
- "_test::my_op",
- "",
- (std::vector<Argument>{Argument("dummy"),
- Argument("input", IntType::get())}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
void expectCallsIncrement(TensorTypeId type_id) {
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
}
TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
- auto registrar = RegisterOperators().op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
- .op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()))
- .op(opSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()))
- .op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType1()))
- .op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()))
+ .op("_test::my_op(Tensor dummy, int input) -> int", kernel(&errorKernel, &noCache), dispatchKey(TensorType2()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel(&errorKernel, &noCache), dispatchKey(TensorType1()))
+ .op("_test::error(Tensor dummy, int input) -> int", kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
- auto registrar1 = RegisterOperators().op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
- auto registrar2 = RegisterOperators().op(opSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
- auto registrar3 = RegisterOperators().op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType1()));
- auto registrar4 = RegisterOperators().op(errorOpSchema, kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
+ auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel(&errorKernel, &noCache), dispatchKey(TensorType1()));
+ auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", kernel(&errorKernel, &noCache), dispatchKey(TensorType2()));
expectCallsIncrement(TensorType1());
}
TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
{
- auto registrar1 = RegisterOperators().op(opSchema, kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
+ auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel(&incrementKernel, &noCache), dispatchKey(TensorType1()));
{
- auto registrar2 = RegisterOperators().op(opSchema, kernel(&decrementKernel, &noCache), dispatchKey(TensorType2()));
+ auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", kernel(&decrementKernel, &noCache), dispatchKey(TensorType2()));
// assert that schema and cpu kernel are present
expectCallsIncrement(TensorType1());
torch::jit::push(*stack, static_cast<Cache*>(cache)->last_value++);
}
-FunctionSchema incrementSequenceOpSchema(
- "_test::increment_sequence",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{Argument("output", IntType::get())}));
-
-
TEST(OperatorRegistrationTest_StackBasedKernel, givenKernelWithCache_whenCalled_thenCacheIsHandledCorrectly) {
- auto registrar = RegisterOperators().op(incrementSequenceOpSchema, kernel(&increment_sequence_kernel, &make_cache), dispatchKey(TensorType1()));
+ auto registrar = RegisterOperators().op("_test::increment_sequence(Tensor dummy) -> int", kernel(&increment_sequence_kernel, &make_cache), dispatchKey(TensorType1()));
auto op = c10::Dispatcher::singleton().findSchema("_test::increment_sequence", "");
ASSERT_TRUE(op.has_value());
#include <ATen/core/op_registration/op_registration.h>
+#include <torch/csrc/jit/script/function_schema_parser.h>
namespace c10 {
bool owns_registration_;
};
+void RegisterOperators::registerOp_(const std::string& schemaStr, detail::KernelRegistrationConfig&& config) {
+ registerOp_(torch::jit::parseSchema(schemaStr), std::move(config));
+}
+
void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
AT_CHECK(!config.dispatch_key.has_value() || config.kernel_func != nullptr,
"Tried to register an operator with a dispatch key but without a kernel. "
* > c10::dispatchKey(CPUTensorId()));
*/
template<class... ConfigParameters>
+ RegisterOperators op(const std::string& schema, ConfigParameters&&... configParameters) && {
+ static_assert(guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value,
+ "Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
+ " and c10::dispatchKey(...) for defining a dispatch key. Please see the documentation for registering c10 operators.");
+
+ op_(schema, std::forward<ConfigParameters>(configParameters)...);
+ return std::move(*this);
+ }
+
+ // This FunctionSchema based variant is only meant to be used for internal
+ // purposes when we already have a pre-parsed FunctionSchema.
+ // This is for example used for exposing legacy caffe2 operators to c10.
+ template<class... ConfigParameters>
RegisterOperators op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
static_assert(guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value,
"Invalid argument passed to op(). Examples for valid arguments are c10::kernel(...) for defining a kernel "
C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
"Please use RegisterOperators().op(...) instead.")
// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
- explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, FunctionSchema> schema, FuncType* func)
+ explicit RegisterOperators(guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, const std::string&> schema, FuncType* func)
: RegisterOperators() {
- legacyAPIOp_(std::move(schema), func);
+ legacyAPIOp_(schema, func);
}
template<class FuncType>
C10_DEPRECATED_MESSAGE("Registering kernels via passing arguments to RegisterOperators(...) is deprecated. " \
"Please use RegisterOperators().op(...) instead.")
// enable_if: only enable it if FuncType is actually a functor
- explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, FunctionSchema> schema, FuncType&& func)
+ explicit RegisterOperators(guts::enable_if_t<guts::is_functor<FuncType>::value, const std::string&> schema, FuncType&& func)
: RegisterOperators() {
- legacyAPIOp_(std::move(schema), std::forward<FuncType>(func));
+ legacyAPIOp_(schema, std::forward<FuncType>(func));
}
/**
"Please use the new c10::kernel() based API instead.")
// enable_if: only enable it if FuncType is actually a function, but not a stack based KernelFunction.
guts::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction>::value, RegisterOperators>
- op(FunctionSchema schema, FuncType* func, OtherArgs...) && {
+ op(const std::string& schema, FuncType* func, OtherArgs...) && {
// We intentionally don't extend this deprecated API to support dispatch keys
// and the like to push people towards using the new API.
static_assert(sizeof...(OtherArgs) == 0, "The deprecated function pointer based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
- legacyAPIOp_(std::move(schema), func);
+ legacyAPIOp_(schema, func);
return std::move(*this);
}
"Please use the new c10::kernel() based API instead.")
// enable_if: only enable it if FuncType is actually a functor
guts::enable_if_t<guts::is_functor<FuncType>::value, RegisterOperators>
- op(FunctionSchema schema, FuncType&& func, OtherArgs...) && {
+ op(const std::string& schema, FuncType&& func, OtherArgs...) && {
// We intentionally don't extend this deprecated API to support dispatch keys
// and the like to push people towards using the new API.
static_assert(sizeof...(OtherArgs) == 0, "The deprecated lambda based API to register kernels doesn't allow additional arguments for dispatch keys or other things. Please use the new c10::kernel() based API instead.");
static_assert(!std::is_base_of<OperatorKernel, FuncType>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new c10::kernel() based API instead.");
- legacyAPIOp_(std::move(schema), std::forward<FuncType>(func));
+ legacyAPIOp_(schema, std::forward<FuncType>(func));
return std::move(*this);
}
void op_(FunctionSchema&& schema, ConfigParameters&&... configParameters) {
registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
}
+ template<class... ConfigParameters>
+ void op_(const std::string& schema, ConfigParameters&&... configParameters) {
+ registerOp_(schema, detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
+ }
template<class FuncType>
- void legacyAPIOp_(FunctionSchema&& schema, FuncType&& func) {
- op_(std::move(schema), kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
+ void legacyAPIOp_(const std::string& schema, FuncType&& func) {
+ op_(schema, kernel<detail::WrapRuntimeKernelFunctor<guts::decay_t<FuncType>>>(std::forward<FuncType>(func)));
}
void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
+ void registerOp_(const std::string& schema, detail::KernelRegistrationConfig&& config);
class OperatorRegistrar;
using c10::RegisterOperators;
using c10::OperatorKernel;
-using c10::FunctionSchema;
-using c10::Argument;
using c10::kernel;
using c10::dispatchKey;
using c10::Dispatcher;
private:
bool* called_;
};
-
-FunctionSchema dummySchema(
- "_test::dummy",
- "",
- (std::vector<Argument>{Argument("dummy")}),
- (std::vector<Argument>{}));
-
TEST(OperatorRegistrationTest, givenOpWithoutFallbackKernel_whenCallingOpWithWrongDispatchKey_thenFails) {
- auto registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
}
TEST(OperatorRegistrationTest, givenOpWithFallbackKernelOutOfScope_whenCallingOpWithWrongDispatchKey_thenFails) {
- auto registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
{
- auto inner_registrar = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>());
+ auto inner_registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>());
// this registered a fallback kernel, but now that registration goes out of scope and deregisters it
}
TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernel_whenCallingOp_thenCallsFallbackKernel) {
bool called = false;
- auto registrar = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
TEST(OperatorRegistrationTest, givenOpWithOnlyFallbackKernelAndOtherKernelOutOfScope_whenCallingOp_thenCallsFallbackKernel) {
bool called = false;
bool other_called = false;
- auto registrar = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called)); // note: no dispatch key means this is the fallback kernel
{
- auto inner_registrar = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&other_called), dispatchKey(TensorType2()));
+ auto inner_registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&other_called), dispatchKey(TensorType2()));
}
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
- .op(dummySchema, kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
- .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
- .op(dummySchema, kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
- .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_fallback)) // note: no dispatch key means this is the fallback kernel
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
- .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()))
- .op(dummySchema, kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()))
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
bool called_kernel = false;
bool called_fallback = false;
auto registrar = c10::RegisterOperators()
- .op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()))
- .op(dummySchema, kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()))
+ .op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_fallback)); // note: no dispatch key means this is the fallback kernel
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value());
}
TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegistering_thenOnlyRegistersSchema) {
- auto registrar = c10::RegisterOperators().op(dummySchema);
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value()); // assert schema is registered
TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRunningOutOfScope_thenSchemaIsGone) {
{
- auto registrar = c10::RegisterOperators().op(dummySchema);
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
}
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
}
TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwards_thenCanBeCalled) {
- auto registrar1 = c10::RegisterOperators().op(dummySchema);
+ auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
bool called_kernel = false;
- auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
+ auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<MockKernel>(&called_kernel), dispatchKey(TensorType1()));
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");
ASSERT_TRUE(op.has_value()); // assert schema is registered
}
TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterwardsAndRunsOutOfScope_thenSchemaIsStillThereButCannotBeCalledAnymore) {
- auto registrar1 = c10::RegisterOperators().op(dummySchema);
+ auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()");
{
- auto registrar2 = c10::RegisterOperators().op(dummySchema, kernel<DummyKernel>(), dispatchKey(TensorType1()));
+ auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", kernel<DummyKernel>(), dispatchKey(TensorType1()));
}
auto op = Dispatcher::singleton().findSchema("_test::dummy", "");