From daa77c6e269df04042dafdb5512ac239f90e5f31 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 21 Mar 2019 14:51:38 -0700 Subject: [PATCH] Move schema inference to c10 (#18090) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18090 This schema inference is needed by the c10 operator registration mechanism. Move it to c10. It is going to be used by diffs stacked on top. Reviewed By: ezyang Differential Revision: D14491454 fbshipit-source-id: 0f8ddcdbd91467c8347d315dd443a1ca8b216481 --- aten/src/ATen/core/op_registration/infer_schema.h | 88 +++++++++++++++++++++++ c10/CMakeLists.txt | 2 +- setup.py | 2 +- torch/csrc/jit/custom_operator.h | 75 ++----------------- 4 files changed, 95 insertions(+), 72 deletions(-) create mode 100644 aten/src/ATen/core/op_registration/infer_schema.h diff --git a/aten/src/ATen/core/op_registration/infer_schema.h b/aten/src/ATen/core/op_registration/infer_schema.h new file mode 100644 index 0000000..e040503 --- /dev/null +++ b/aten/src/ATen/core/op_registration/infer_schema.h @@ -0,0 +1,88 @@ +#pragma once + +/** + * This file contains functionality to take a C++ function and infer its + * c10::FunctionSchema. + */ + +#include +#include +#include + +namespace c10 { + +namespace detail { +/// Checks the static C++ type `T` for correctness to catch common error cases. +template +void checkStaticTypes() { + // Give nice error messages for some of the common error cases. + // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT + static_assert( + !std::is_integral::value || std::is_same::value, + "INVALID TYPE: Only int64_t is supported as an integral argument type"); + static_assert( + !std::is_same::value, + "INVALID TYPE: float is not supported as an argument type, use double instead"); +} + +template +void checkStaticTypes() { + checkStaticTypes(); + checkStaticTypes(); +} + +template +::std::vector createArgumentVectorFromTypes(guts::index_sequence) { + checkStaticTypes...>(); + // Arguments are named "_" + return {Argument("_" + std::to_string(Is), getTypePtr>())...}; +} + +template +::std::vector createReturns(guts::index_sequence) { + return createArgumentVectorFromTypes(); +} + +/// Unpack a tuple return type into a vector of return types, one per tuple +/// element. +template +::std::vector createReturns(std::tuple* tuple) { + return createReturns(guts::make_index_sequence()); +} + +/// Create a single-element `vector` for simple (non-tuple) return types. +template +::std::vector createReturns(ReturnType*) { + checkStaticTypes>(); + return {Argument("_1", getTypePtr>())}; +} + +/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices +/// into the argument list. +template +::std::vector createArgumentVectorFromTraits(guts::index_sequence indices) { + using ArgumentTypes = typename FunctionTraits::parameter_types; + return createArgumentVectorFromTypes< + c10::guts::typelist::element_t...>(indices); +} + +/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a +/// function. +template +FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) { + using ReturnType = typename FunctionTraits::return_type; + + auto arguments = createArgumentVectorFromTraits( + guts::make_index_sequence()); + auto returns = createReturns(static_cast(nullptr)); + + return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)}; +} +} + +template +FunctionSchema inferFunctionSchema(std::string name, std::string overload_name) { + return detail::createFunctionSchemaFromTraits>(std::move(name), std::move(overload_name)); +} + +} diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 81fab30..4ae588e 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -29,7 +29,7 @@ file(GLOB C10_SRCS *.cpp core/*.cpp core/dispatch/*.cpp - core/opschema/*.cpp + core/op_registration/*.cpp core/impl/*.cpp macros/*.cpp util/*.cpp diff --git a/setup.py b/setup.py index 4ce1993..deed8b5 100644 --- a/setup.py +++ b/setup.py @@ -744,8 +744,8 @@ if __name__ == '__main__': 'include/c10/macros/*.h', 'include/c10/core/*.h', 'include/ATen/core/dispatch/*.h', + 'include/ATen/core/op_registration/*.h', 'include/c10/core/impl/*.h', - 'include/ATen/core/opschema/*.h', 'include/c10/util/*.h', 'include/c10/cuda/*.h', 'include/c10/cuda/impl/*.h', diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h index d877346..22b206a 100644 --- a/torch/csrc/jit/custom_operator.h +++ b/torch/csrc/jit/custom_operator.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -16,73 +17,7 @@ namespace detail { using ::c10::Argument; using ::c10::FunctionSchema; -/// Checks the static C++ type `T` for correctness to catch common error cases. -template -void checkStaticTypes() { - // Give nice error messages for some of the common error cases. - // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT - static_assert( - !std::is_integral::value || std::is_same::value, - "INVALID TYPE: Only int64_t is supported as an integral argument type"); - static_assert( - !std::is_same::value, - "INVALID TYPE: float is not supported as an argument type, use double instead"); -} - -template -void checkStaticTypes() { - checkStaticTypes(); - checkStaticTypes(); -} - -template -::std::vector createArgumentVectorFromTypes(Indices indices) { - checkStaticTypes...>(); - // Arguments are named "_" - return {Argument("_" + std::to_string(Is), getTypePtr>())...}; -} - -template -::std::vector createReturns(Indices indices) { - return createArgumentVectorFromTypes(); -} - -/// Unpack a tuple return type into a vector of return types, one per tuple -/// element. -template -::std::vector createReturns(std::tuple* tuple) { - // Create an index pack so we can call `get` on the tuple next. - return createReturns(typename MakeIndices::indices{}); -} -/// Create a single-element `vector` for simple (non-tuple) return types. -template -::std::vector createReturns(ReturnType*) { - checkStaticTypes>(); - return {Argument("_1", getTypePtr>())}; -} - -/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices -/// into the argument list. -template -::std::vector createArgumentVectorFromTraits(Indices indices) { - using ArgumentTypes = typename FunctionTraits::parameter_types; - return createArgumentVectorFromTypes< - c10::guts::typelist::element_t...>(indices); -} - -/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a -/// function. -template -FunctionSchema createFunctionSchemaFromTraits(const std::string& name) { - using ReturnType = typename FunctionTraits::return_type; - - auto arguments = createArgumentVectorFromTraits( - typename MakeIndices::indices{}); - auto returns = createReturns(static_cast(nullptr)); - - return {name, "", arguments, returns}; -} /// Adds the elements of the `tuple` as input nodes to the traced graph. template @@ -178,8 +113,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) { const auto bracketIndex = schemaOrName.find('('); if (bracketIndex == std::string::npos) { // Infer the full schema and we're good. - return torch::jit::detail::createFunctionSchemaFromTraits( - /*name=*/schemaOrName); + return c10::detail::createFunctionSchemaFromTraits( + /*name=*/schemaOrName, ""); } // If the user provided her own schema, we need to infer it nevertheless and @@ -189,8 +124,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) { auto providedSchema = parseSchema(schemaOrName); const auto inferredSchema = - torch::jit::detail::createFunctionSchemaFromTraits( - providedSchema.name()); + c10::detail::createFunctionSchemaFromTraits( + providedSchema.name(), providedSchema.overload_name()); checkArgumentVector( "argument", inferredSchema.arguments(), -- 2.7.4