Move schema inference to c10 (#18090)
authorSebastian Messmer <messmer@fb.com>
Thu, 21 Mar 2019 21:51:38 +0000 (14:51 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Mar 2019 21:57:30 +0000 (14:57 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18090

This schema inference is needed by the c10 operator registration mechanism. Move it to c10.
It is going to be used by diffs stacked on top.

Reviewed By: ezyang

Differential Revision: D14491454

fbshipit-source-id: 0f8ddcdbd91467c8347d315dd443a1ca8b216481

aten/src/ATen/core/op_registration/infer_schema.h [new file with mode: 0644]
c10/CMakeLists.txt
setup.py
torch/csrc/jit/custom_operator.h

diff --git a/aten/src/ATen/core/op_registration/infer_schema.h b/aten/src/ATen/core/op_registration/infer_schema.h
new file mode 100644 (file)
index 0000000..e040503
--- /dev/null
@@ -0,0 +1,88 @@
+#pragma once
+
+/**
+ * This file contains functionality to take a C++ function and infer its
+ * c10::FunctionSchema.
+ */
+
+#include <ATen/core/function_schema.h>
+#include <c10/util/C++17.h>
+#include <c10/util/Metaprogramming.h>
+
+namespace c10 {
+
+namespace detail {
+/// Checks the static C++ type `T` for correctness to catch common error cases.
+template <typename T>
+void checkStaticTypes() {
+ // Give nice error messages for some of the common error cases.
+ // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
+ static_assert(
+     !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
+     "INVALID TYPE: Only int64_t is supported as an integral argument type");
+ static_assert(
+     !std::is_same<T, float>::value,
+     "INVALID TYPE: float is not supported as an argument type, use double instead");
+}
+
+template <typename First, typename Second, typename... Rest>
+void checkStaticTypes() {
+ checkStaticTypes<First>();
+ checkStaticTypes<Second, Rest...>();
+}
+
+template <typename... Ts, size_t... Is>
+::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
+  checkStaticTypes<guts::decay_t<Ts>...>();
+  // Arguments are named "_<index>"
+  return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
+}
+
+template <typename... Ts, size_t... Is>
+::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
+  return createArgumentVectorFromTypes<Ts..., Is...>();
+}
+
+/// Unpack a tuple return type into a vector of return types, one per tuple
+/// element.
+template <typename... Ts>
+::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
+  return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
+}
+
+/// Create a single-element `vector` for simple (non-tuple) return types.
+template <typename ReturnType>
+::std::vector<Argument> createReturns(ReturnType*) {
+  checkStaticTypes<guts::decay_t<ReturnType>>();
+  return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
+}
+
+/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
+/// into the argument list.
+template <typename FunctionTraits, size_t... Is>
+::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
+ using ArgumentTypes = typename FunctionTraits::parameter_types;
+ return createArgumentVectorFromTypes<
+     c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
+}
+
+/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
+/// function.
+template <typename FunctionTraits>
+FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
+ using ReturnType = typename FunctionTraits::return_type;
+
+ auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
+     guts::make_index_sequence<FunctionTraits::number_of_parameters>());
+ auto returns = createReturns(static_cast<ReturnType*>(nullptr));
+
+ return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
+}
+}
+
+template<class FuncType>
+FunctionSchema inferFunctionSchema(std::string name, std::string overload_name) {
+  return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
+}
+
+}
index 81fab30..4ae588e 100644 (file)
@@ -29,7 +29,7 @@ file(GLOB C10_SRCS
         *.cpp
         core/*.cpp
         core/dispatch/*.cpp
-        core/opschema/*.cpp
+        core/op_registration/*.cpp
         core/impl/*.cpp
         macros/*.cpp
         util/*.cpp
index 4ce1993..deed8b5 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -744,8 +744,8 @@ if __name__ == '__main__':
                 'include/c10/macros/*.h',
                 'include/c10/core/*.h',
                 'include/ATen/core/dispatch/*.h',
+                'include/ATen/core/op_registration/*.h',
                 'include/c10/core/impl/*.h',
-                'include/ATen/core/opschema/*.h',
                 'include/c10/util/*.h',
                 'include/c10/cuda/*.h',
                 'include/c10/cuda/impl/*.h',
index d877346..22b206a 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <torch/csrc/jit/operator.h>
 #include <ATen/core/stack.h>
+#include <ATen/core/op_registration/infer_schema.h>
 #include <torch/csrc/jit/tracer.h>
 #include <torch/csrc/utils/variadic.h>
 
@@ -16,73 +17,7 @@ namespace detail {
 using ::c10::Argument;
 using ::c10::FunctionSchema;
 
-/// Checks the static C++ type `T` for correctness to catch common error cases.
-template <typename T>
-void checkStaticTypes() {
-  // Give nice error messages for some of the common error cases.
-  // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
-  static_assert(
-      !std::is_integral<T>::value || std::is_same<T, int64_t>::value,
-      "INVALID TYPE: Only int64_t is supported as an integral argument type");
-  static_assert(
-      !std::is_same<T, float>::value,
-      "INVALID TYPE: float is not supported as an argument type, use double instead");
-}
-
-template <typename First, typename Second, typename... Rest>
-void checkStaticTypes() {
-  checkStaticTypes<First>();
-  checkStaticTypes<Second, Rest...>();
-}
-
-template <typename... Ts, size_t... Is>
-::std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
-  checkStaticTypes<decay_t<Ts>...>();
-  // Arguments are named "_<index>"
-  return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
-}
-
-template <typename... Ts, size_t... Is>
-::std::vector<Argument> createReturns(Indices<Is...> indices) {
-  return createArgumentVectorFromTypes<Ts..., Is...>();
-}
-
-/// Unpack a tuple return type into a vector of return types, one per tuple
-/// element.
-template <typename... Ts>
-::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
-  // Create an index pack so we can call `get<Indices>` on the tuple next.
-  return createReturns<Ts...>(typename MakeIndices<sizeof...(Ts)>::indices{});
-}
 
-/// Create a single-element `vector` for simple (non-tuple) return types.
-template <typename ReturnType>
-::std::vector<Argument> createReturns(ReturnType*) {
-  checkStaticTypes<decay_t<ReturnType>>();
-  return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
-}
-
-/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
-/// into the argument list.
-template <typename FunctionTraits, size_t... Is>
-::std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
-  using ArgumentTypes = typename FunctionTraits::parameter_types;
-  return createArgumentVectorFromTypes<
-      c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
-}
-
-/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
-/// function.
-template <typename FunctionTraits>
-FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
-  using ReturnType = typename FunctionTraits::return_type;
-
-  auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
-      typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
-  auto returns = createReturns(static_cast<ReturnType*>(nullptr));
-
-  return {name, "", arguments, returns};
-}
 
 /// Adds the elements of the `tuple` as input nodes to the traced graph.
 template <size_t... Is, typename... Types>
@@ -178,8 +113,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
   const auto bracketIndex = schemaOrName.find('(');
   if (bracketIndex == std::string::npos) {
     // Infer the full schema and we're good.
-    return torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
-        /*name=*/schemaOrName);
+    return c10::detail::createFunctionSchemaFromTraits<Traits>(
+        /*name=*/schemaOrName, "");
   }
 
   // If the user provided her own schema, we need to infer it nevertheless and
@@ -189,8 +124,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
   auto providedSchema = parseSchema(schemaOrName);
 
   const auto inferredSchema =
-      torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
-          providedSchema.name());
+      c10::detail::createFunctionSchemaFromTraits<Traits>(
+          providedSchema.name(), providedSchema.overload_name());
   checkArgumentVector(
       "argument",
       inferredSchema.arguments(),