Introduce lambda-based kernel API (#18541)
authorSebastian Messmer <messmer@fb.com>
Sat, 30 Mar 2019 07:03:46 +0000 (00:03 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 07:07:31 +0000 (00:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18541

Allow registering lambdas as c10 kernels.

Reviewed By: dzhulgakov

Differential Revision: D14653005

fbshipit-source-id: f867cc776b1339e83b7a2e1935f5cf924cfba44a

aten/src/ATen/core/op_registration/kernel_functor.h
aten/src/ATen/core/op_registration/kernel_lambda.h [new file with mode: 0644]
aten/src/ATen/core/op_registration/kernel_lambda_test.cpp [new file with mode: 0644]
aten/src/ATen/core/op_registration/op_registration.h
c10/test/util/TypeTraits_test.cpp
c10/util/Metaprogramming.h
c10/util/TypeTraits.h

index 86285a1..427de09 100644 (file)
@@ -138,6 +138,18 @@ namespace detail {
       return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
     }
   };
+
+  template<class KernelFunctor, class... ConstructorParameters>
+  // enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
+  inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
+  detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
+  kernelFunctor(ConstructorParameters&&... constructorParameters) {
+    return {
+      &detail::wrap_kernel_functor<KernelFunctor>::call,
+      detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
+      detail::FunctionSchemaInferer<KernelFunctor>()
+    };
+  }
 }
 
 /**
@@ -181,11 +193,7 @@ template<class KernelFunctor, class... ConstructorParameters>
 inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
 detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
 kernel(ConstructorParameters&&... constructorParameters) {
-  return {
-    &detail::wrap_kernel_functor<KernelFunctor>::call,
-    detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
-    detail::FunctionSchemaInferer<KernelFunctor>()
-  };
+  return detail::kernelFunctor<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...);
 }
 
 }
diff --git a/aten/src/ATen/core/op_registration/kernel_lambda.h b/aten/src/ATen/core/op_registration/kernel_lambda.h
new file mode 100644 (file)
index 0000000..a5b8244
--- /dev/null
@@ -0,0 +1,60 @@
+#pragma once
+
+#include <ATen/core/op_registration/kernel_functor.h>
+#include <c10/util/TypeTraits.h>
+
+namespace c10 {
+
+namespace detail {
+  // WrapRuntimeKernelFunctor: Wraps any runtime functor into a functor that
+  // inherits from c10::OperatorKernel, so it can be used as a c10 kernel.
+  // This can, for example, be used for lamdas, functors or even function pointers.
+  // In the case of function pointers, since it is a runtime function pointer,
+  // there is an overhead for calling it whenever the kernel is invoked.
+  template<class FuncType, class ReturnType, class ParameterList> class WrapRuntimeKernelFunctor_ {};
+  template<class FuncType, class ReturnType, class... Parameters>
+  class WrapRuntimeKernelFunctor_<FuncType, ReturnType, guts::typelist::typelist<Parameters...>> final : public c10::OperatorKernel {
+  public:
+    template<class FuncType_>
+    explicit WrapRuntimeKernelFunctor_(FuncType_&& kernel_func)
+    : kernel_func_(std::forward<FuncType_>(kernel_func)) {}
+
+    auto operator()(Parameters&&... args) -> decltype(std::declval<FuncType>()(std::forward<Parameters>(args)...)) {
+      return kernel_func_(std::forward<Parameters>(args)...);
+    }
+
+  private:
+    FuncType kernel_func_;
+  };
+  template<class FuncType>
+  using WrapRuntimeKernelFunctor = WrapRuntimeKernelFunctor_<
+      FuncType,
+      typename guts::infer_function_traits_t<FuncType>::return_type,
+      typename guts::infer_function_traits_t<FuncType>::parameter_types
+  >;
+}
+
+/**
+ * Use this to register an operator whose kernel is implemented as a stateless lambda.
+ *
+ * Example:
+ *
+ * > static auto registry = c10::RegisterOperators()
+ * >     .op("my_op",
+ * >         c10::kernel([] (Tensor a) -> Tensor{...}),
+ * >         c10::dispatchKey(CPUTensorId()));
+ */
+template<class Lambda>
+inline constexpr auto kernel(Lambda&& functor) ->
+guts::enable_if_t<guts::is_stateless_lambda<guts::decay_t<Lambda>>::value,
+decltype(detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor)))> {
+  // We don't support stateful lambdas (i.e. lambdas with a capture), because their
+  // behavior would be nonobvious. A functor kernel with cache gets a new instance of
+  // its cache each time the kernel is looked up from the dispatch table.
+  // A lambda with a capture would be global and share its capture between all kernel lookups.
+  // So, instead of making users having to think about it (including the thread-safety
+  // issues this causes), let's just forbid stateful lambdas alltogether.
+  return detail::kernelFunctor<detail::WrapRuntimeKernelFunctor<guts::decay_t<Lambda>>>(std::forward<Lambda>(functor));
+}
+
+}
diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp
new file mode 100644 (file)
index 0000000..d9defec
--- /dev/null
@@ -0,0 +1,803 @@
+#include <gtest/gtest.h>
+#include <ATen/core/op_registration/test_helpers.h>
+
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/core/Tensor.h>
+
+using c10::RegisterOperators;
+using c10::FunctionSchema;
+using c10::Argument;
+using c10::IntType;
+using c10::FloatType;
+using c10::ListType;
+using c10::kernel;
+using c10::dispatchKey;
+using c10::TensorTypeId;
+using c10::KernelCache;
+using c10::Stack;
+using c10::guts::make_unique;
+using c10::ivalue::TensorList;
+using c10::ivalue::IntList;
+using c10::intrusive_ptr;
+using c10::ArrayRef;
+using std::unique_ptr;
+using at::Tensor;
+
+namespace {
+
+C10_DECLARE_TENSOR_TYPE(TensorType1);
+C10_DEFINE_TENSOR_TYPE(TensorType1);
+C10_DECLARE_TENSOR_TYPE(TensorType2);
+C10_DEFINE_TENSOR_TYPE(TensorType2);
+
+FunctionSchema errorOpSchema(
+    "_test::error",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+FunctionSchema opSchema(
+    "_test::my_op",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+void expectCallsIncrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(6, result[0].toInt());
+}
+
+void expectCallsDecrement(TensorTypeId type_id) {
+  // assert that schema and cpu kernel are present
+  auto op = c10::Dispatcher::singleton().findSchema("_test::my_op", "");
+  ASSERT_TRUE(op.has_value());
+  auto result = callOp(*op, dummyTensor(type_id), 5);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(4, result[0].toInt());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
+  auto registrar = RegisterOperators()
+      .op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()))
+      .op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()))
+      .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()))
+      .op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
+  auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
+  auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
+  auto registrar3 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType1()));
+  auto registrar4 = RegisterOperators().op(errorOpSchema, kernel([] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}), dispatchKey(TensorType2()));
+  expectCallsIncrement(TensorType1());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
+  {
+    auto registrar1 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i+1;}), dispatchKey(TensorType1()));
+    {
+      auto registrar2 = RegisterOperators().op(opSchema, kernel([] (Tensor, int64_t i) {return i-1;}), dispatchKey(TensorType2()));
+
+      // assert that schema and cpu kernel are present
+      expectCallsIncrement(TensorType1());
+      expectCallsDecrement(TensorType2());
+    }
+
+    // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
+    expectCallsIncrement(TensorType1());
+    expectDoesntFindKernel("_test::my_op", TensorType2());
+  }
+
+  // now both registrars are destructed. Assert that the whole schema is gone
+  expectDoesntFindOperator("_test::my_op");
+}
+
+bool was_called = false;
+
+FunctionSchema opWithoutOutputSchema(
+    "_test::no_return",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opWithoutOutputSchema,
+    kernel([] (const Tensor&) -> void {was_called = true;}),
+    dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::no_return", "");
+  ASSERT_TRUE(op.has_value());
+  was_called = false;
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(was_called);
+  EXPECT_EQ(0, result.size());
+}
+
+FunctionSchema opWithZeroOutputsSchema(
+    "_test::zero_outputs",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op(opWithZeroOutputsSchema,
+    kernel([] (const Tensor&) -> std::tuple<> {was_called = true; return {};}),
+    dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::zero_outputs", "");
+  ASSERT_TRUE(op.has_value());
+  was_called = false;
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_TRUE(was_called);
+  EXPECT_EQ(0, result.size());
+}
+
+FunctionSchema opWithIntOutputSchema(
+    "_test::int_output",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("a", IntType::get()),
+                           Argument("b", IntType::get())}),
+    (std::vector<Argument>{Argument("sum", IntType::get())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntOutputSchema,
+        kernel([] (Tensor, int64_t a, int64_t b) {return a+b;}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), 3, 6);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(9, result[0].toInt());
+}
+
+FunctionSchema opWithTensorOutput(
+    "_test::returning_tensor",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output")}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorOutput,
+        kernel([] (const Tensor& a) {return a;}),
+        dispatchKey(TensorType1()))
+      .op(opWithTensorOutput,
+        kernel([] (const Tensor& a) {return a;}),
+        dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::returning_tensor", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+FunctionSchema opWithTensorListOutputSchema(
+    "_test::list_output",
+    "",
+    (std::vector<Argument>{Argument("input1"),
+                           Argument("input2"),
+                           Argument("input3")}),
+    (std::vector<Argument>{Argument("output", ListType::ofTensors())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListOutputSchema,
+        kernel([] (const Tensor& a, const Tensor& b, const Tensor& c) -> std::vector<Tensor> {return {a, b, c};}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), dummyTensor(TensorType2()), dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(3, result[0].toTensorListRef().size());
+  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[0].type_id());
+  EXPECT_EQ(TensorType2(), result[0].toTensorListRef()[1].type_id());
+  EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
+}
+
+FunctionSchema opWithIntListOutputSchema(
+    "_test::list_output",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input1", IntType::get()),
+                           Argument("input2", IntType::get()),
+                           Argument("input3", IntType::get())}),
+    (std::vector<Argument>{Argument("output", ListType::ofInts())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListOutputSchema,
+        kernel([] (const Tensor&, int64_t a, int64_t b, int64_t c) -> std::vector<int64_t> {return {a,b,c};}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::list_output", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()), 2, 4, 6);
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(3, result[0].toIntListRef().size());
+  EXPECT_EQ(2, result[0].toIntListRef()[0]);
+  EXPECT_EQ(4, result[0].toIntListRef()[1]);
+  EXPECT_EQ(6, result[0].toIntListRef()[2]);
+}
+
+FunctionSchema opWithMultipleOutputsSchema(
+    "_test::multiple_outputs",
+    "",
+    (std::vector<Argument>{Argument("dummy")}),
+    (std::vector<Argument>{Argument("output1"),
+                           Argument("output2", IntType::get()),
+                           Argument("output3", ListType::ofTensors())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+     .op(opWithMultipleOutputsSchema,
+       kernel([] (Tensor) -> std::tuple<Tensor, int64_t, std::vector<Tensor>> {
+         return std::tuple<Tensor, int64_t, std::vector<Tensor>>(
+           dummyTensor(TensorType2()), 5, {dummyTensor(TensorType1()), dummyTensor(TensorType2())}
+         );
+       }),
+       dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::multiple_outputs", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(3, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+  EXPECT_EQ(5, result[1].toInt());
+  EXPECT_EQ(2, result[2].toTensorListRef().size());
+  EXPECT_EQ(TensorType1(), result[2].toTensorListRef()[0].type_id());
+  EXPECT_EQ(TensorType2(), result[2].toTensorListRef()[1].type_id());
+}
+
+FunctionSchema opWithTensorInputWithOutput(
+    "_test::tensor_input",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{Argument("output")}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithOutput,
+        kernel([] (const Tensor& a) {return a;}),
+        dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithOutput,
+        kernel([] (const Tensor& a) {return a;}),
+        dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithOutput,
+        kernel([] (Tensor a) {return a;}),
+        dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithOutput,
+        kernel([] (Tensor a) {return a;}),
+        dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto result = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType1(), result[0].toTensor().type_id());
+
+  result = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(1, result.size());
+  EXPECT_EQ(TensorType2(), result[0].toTensor().type_id());
+}
+
+Tensor captured_input;
+
+FunctionSchema opWithTensorInputWithoutOutput(
+    "_test::tensor_input",
+    "",
+    (std::vector<Argument>{Argument("input")}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithoutOutput,
+        kernel([] (const Tensor& a) -> void {captured_input = a;}),
+        dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithoutOutput,
+        kernel([] (const Tensor& a) -> void {captured_input = a;}),
+        dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType1(), captured_input.type_id());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType2(), captured_input.type_id());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorInputWithoutOutput,
+        kernel([] (Tensor a) -> void {captured_input = a;}),
+        dispatchKey(TensorType1()))
+      .op(opWithTensorInputWithoutOutput,
+        kernel([] (Tensor a) -> void {captured_input = a;}),
+        dispatchKey(TensorType2()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType1(), captured_input.type_id());
+
+  outputs = callOp(*op, dummyTensor(TensorType2()));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(TensorType2(), captured_input.type_id());
+}
+
+int64_t captured_int_input = 0;
+
+FunctionSchema opWithIntInputWithoutOutput(
+    "_test::int_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntInputWithoutOutput,
+        kernel([] (Tensor, int64_t a) -> void {captured_int_input = a;}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_int_input = 0;
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(3, captured_int_input);
+}
+
+FunctionSchema opWithIntInputWithOutput(
+    "_test::int_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", IntType::get())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntInputWithOutput,
+        kernel([] (Tensor, int64_t a) {return a + 1;}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), 3);
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(4, outputs[0].toInt());
+}
+
+int64_t captured_input_list_size = 0;
+
+FunctionSchema opWithIntListInputWithoutOutput(
+    "_test::int_list_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", ListType::ofInts())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListInputWithoutOutput,
+        kernel([] (Tensor, ArrayRef<int64_t> a) {captured_input_list_size = a.size();}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_input_list_size = 0;
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(3, captured_input_list_size);
+}
+
+FunctionSchema opWithIntListInputWithOutput(
+    "_test::int_list_input",
+    "",
+    (std::vector<Argument>{Argument("dummy"),
+                           Argument("input", ListType::ofInts())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithIntListInputWithOutput,
+        kernel([] (Tensor, ArrayRef<int64_t> a) -> int64_t {return a.size();}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::int_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, dummyTensor(TensorType1()), IntList::create({2, 4, 6}));
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(3, outputs[0].toInt());
+}
+
+FunctionSchema opWithTensorListInputWithoutOutput(
+    "_test::tensor_list_input",
+    "",
+    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
+    (std::vector<Argument>{}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListInputWithoutOutput,
+        kernel([] (ArrayRef<Tensor> a) -> void {captured_input_list_size = a.size();}),
+        dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  captured_input_list_size = 0;
+  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
+  EXPECT_EQ(0, outputs.size());
+  EXPECT_EQ(2, captured_input_list_size);
+}
+
+FunctionSchema opWithTensorListInputWithOutput(
+    "_test::tensor_list_input",
+    "",
+    (std::vector<Argument>{Argument("input", ListType::ofTensors())}),
+    (std::vector<Argument>{Argument("output", IntType::get())}));
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) {
+  auto registrar = RegisterOperators()
+      .op(opWithTensorListInputWithOutput,
+         kernel([] (ArrayRef<Tensor> a) -> int64_t {return a.size();}),
+         dispatchKey(TensorType1()));
+
+  auto op = c10::Dispatcher::singleton().findSchema("_test::tensor_list_input", "");
+  ASSERT_TRUE(op.has_value());
+
+  auto outputs = callOp(*op, TensorList::create({dummyTensor(TensorType1()), dummyTensor(TensorType1())}));
+  EXPECT_EQ(1, outputs.size());
+  EXPECT_EQ(2, outputs[0].toInt());
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg"), Argument("arg2")}),
+          (std::vector<Argument>{})
+      ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{}),
+            (std::vector<Argument>{})
+        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
+            (std::vector<Argument>{})
+        ), kernel([] (Tensor, Tensor) -> void {}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
+            (std::vector<Argument>{Argument("ret", IntType::get())})
+        ), kernel([] (Tensor, int64_t) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1", IntType::get()),
+                                   Argument("ret2", IntType::get())})
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{})
+      ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret")})
+        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret"), Argument("ret2")})
+        ), kernel([] (Tensor) -> void {}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret1"), Argument("ret2")})
+      ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{})
+        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1")})
+        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
+        ), kernel([] (Tensor) -> std::tuple<Tensor, Tensor> {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret", IntType::get())})
+      ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret")})
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret", FloatType::get())})
+        ), kernel([] (Tensor) -> int64_t {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret")})
+      ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret", FloatType::get())})
+        ), kernel([] (Tensor) -> Tensor {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  // assert this does not fail because it matches
+  RegisterOperators()
+      .op(FunctionSchema(
+          "_test::mismatch",
+          "",
+          (std::vector<Argument>{Argument("arg")}),
+          (std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
+      ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1()));
+
+  // and now a set of mismatching schemas
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
+        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+
+  EXPECT_THROW(
+    RegisterOperators()
+        .op(FunctionSchema(
+            "_test::mismatch",
+            "",
+            (std::vector<Argument>{Argument("arg")}),
+            (std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
+        ), kernel([] (Tensor) -> std::tuple<Tensor, int64_t> {return {};}), dispatchKey(TensorType1())),
+    c10::Error
+  );
+}
+
+}
index 7861eee..29b4895 100644 (file)
@@ -10,6 +10,7 @@
 #include <ATen/core/op_registration/kernel_stackbased.h>
 #include <ATen/core/op_registration/kernel_functor.h>
 #include <ATen/core/op_registration/kernel_function.h>
+#include <ATen/core/op_registration/kernel_lambda.h>
 #include <ATen/core/op_registration/infer_schema.h>
 
 namespace c10 {
index e8c2bdb..c93c81e 100644 (file)
@@ -112,3 +112,44 @@ namespace test_is_type_condition {
     static_assert(!is_type_condition<NotATypeCondition>::value, "");
 }
 }
+
+namespace test_lambda_is_stateless {
+  template<class Result, class... Args>
+  struct MyStatelessFunctor final {
+    Result operator()(Args...) {}
+  };
+
+  template<class Result, class... Args>
+  struct MyStatelessConstFunctor final {
+    Result operator()(Args...) const {}
+  };
+
+  void func() {
+    auto stateless_lambda = [] (int a) {return a;};
+    static_assert(is_stateless_lambda<decltype(stateless_lambda)>::value, "");
+
+    int b = 4;
+    auto stateful_lambda_1 = [&] (int a) {return a + b;};
+    static_assert(!is_stateless_lambda<decltype(stateful_lambda_1)>::value, "");
+
+    auto stateful_lambda_2 = [=] (int a) {return a + b;};
+    static_assert(!is_stateless_lambda<decltype(stateful_lambda_2)>::value, "");
+
+    auto stateful_lambda_3 = [b] (int a) {return a + b;};
+    static_assert(!is_stateless_lambda<decltype(stateful_lambda_3)>::value, "");
+
+    static_assert(!is_stateless_lambda<MyStatelessFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
+    static_assert(!is_stateless_lambda<MyStatelessFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
+    static_assert(!is_stateless_lambda<MyStatelessConstFunctor<int, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
+    static_assert(!is_stateless_lambda<MyStatelessConstFunctor<void, int>>::value, "even if stateless, a functor is not a lambda, so it's false");
+
+    class Dummy final {};
+    static_assert(!is_stateless_lambda<Dummy>::value, "A non-functor type is also not a lambda");
+
+    static_assert(!is_stateless_lambda<int>::value, "An int is not a lambda");
+
+    using Func = int(int);
+    static_assert(!is_stateless_lambda<Func>::value, "A function is not a lambda");
+    static_assert(!is_stateless_lambda<Func*>::value, "A function pointer is not a lambda");
+  }
+}
index 60e4553..645b6cf 100644 (file)
@@ -7,24 +7,6 @@
 #include <c10/util/Array.h>
 
 namespace c10 { namespace guts {
-namespace detail {
-/**
- * strip_class: helper to remove the class type from pointers to `operator()`.
- */
-
-template <typename T>
-struct strip_class {};
-template <typename Class, typename Result, typename... Args>
-struct strip_class<Result (Class::*)(Args...)> {
-  using type = Result(Args...);
-};
-template <typename Class, typename Result, typename... Args>
-struct strip_class<Result (Class::*)(Args...) const> {
-  using type = Result(Args...);
-};
-template <typename T>
-using strip_class_t = typename strip_class<T>::type;
-} // namespace detail
 
 /**
  * Access information about result type or arguments from a function type.
@@ -44,16 +26,6 @@ struct function_traits<Result (Args...)> {
 };
 
 /**
- * Evaluates to true_type, iff the given class is a Functor
- * (i.e. has a call operator with some set of arguments)
- */
-
-template<class Functor, class Enable = void>
-struct is_functor : std::false_type {};
-template<class Functor>
-struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};
-
-/**
  * infer_function_traits: creates a `function_traits` type for a simple
  * function (pointer) or functor (lambda/struct). Currently does not support
  * class methods.
index b4e04ea..6eee698 100644 (file)
@@ -49,6 +49,64 @@ template <template <class...> class Template, class... Args>
 struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};
 template<template<class...> class Template, class T> using is_instantiation_of_t = typename is_instantiation_of<Template, T>::type;
 
+namespace detail {
+/**
+ * strip_class: helper to remove the class type from pointers to `operator()`.
+ */
+
+template <typename T>
+struct strip_class {};
+template <typename Class, typename Result, typename... Args>
+struct strip_class<Result (Class::*)(Args...)> {
+  using type = Result(Args...);
+};
+template <typename Class, typename Result, typename... Args>
+struct strip_class<Result (Class::*)(Args...) const> {
+  using type = Result(Args...);
+};
+template <typename T>
+using strip_class_t = typename strip_class<T>::type;
+} // namespace detail
+
+/**
+ * Evaluates to true_type, iff the given class is a Functor
+ * (i.e. has a call operator with some set of arguments)
+ */
+
+template<class Functor, class Enable = void>
+struct is_functor : std::false_type {};
+template<class Functor>
+struct is_functor<Functor, guts::enable_if_t<is_function_type<detail::strip_class_t<decltype(&Functor::operator())>>::value>> : std::true_type {};
+
+
+/**
+ * lambda_is_stateless<T> is true iff the lambda type T is stateless
+ * (i.e. does not have a closure).
+ * Example:
+ *  auto stateless_lambda = [] (int a) {return a;};
+ *  lambda_is_stateless<decltype(stateless_lambda)> // true
+ *  auto stateful_lambda = [&] (int a) {return a;};
+ *  lambda_is_stateless<decltype(stateful_lambda)> // false
+ */
+namespace detail {
+template<class LambdaType, class FuncType> struct is_stateless_lambda__ final {
+    static_assert(!std::is_same<LambdaType, LambdaType>::value, "Base case shouldn't be hit");
+};
+// implementation idea: According to the C++ standard, stateless lambdas are convertible to function pointers
+template<class LambdaType, class C, class Result, class... Args>
+struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const> : std::is_convertible<LambdaType, Result(*)(Args...)> {};
+template<class LambdaType, class C, class Result, class... Args>
+struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)> : std::is_convertible<LambdaType, Result(*)(Args...)> {};
+
+// case where LambdaType is not even a functor
+template<class LambdaType, class Enable = void> struct is_stateless_lambda_ final : std::false_type {};
+// case where LambdaType is a functor
+template<class LambdaType> struct is_stateless_lambda_<LambdaType, guts::enable_if_t<is_functor<LambdaType>::value>>
+: is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
+}
+template<class T>
+using is_stateless_lambda = detail::is_stateless_lambda_<guts::decay_t<T>>;
+
 
 
 /**