From 48e7f989171622db4aeaa44bbd23aa771f0e6972 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Sat, 30 Mar 2019 00:03:44 -0700 Subject: [PATCH] Move stuff to cpp files (#18301) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18301 Move code out of headers and templates into source files and non-templates. Reviewed By: dzhulgakov Differential Revision: D14572347 fbshipit-source-id: 9fd5d62d54000a95e93076cd73f591ba2c5c2653 --- aten/src/ATen/core/function_schema.h | 6 +++ aten/src/ATen/core/op_registration/base.h | 47 +--------------------- .../ATen/core/op_registration/op_registration.cpp | 46 +++++++++++++++++++++ .../ATen/core/op_registration/op_registration.h | 15 +++---- 4 files changed, 60 insertions(+), 54 deletions(-) diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 5f80d12..da1ec34 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -216,4 +216,10 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) return out; } +inline std::string toString(const FunctionSchema& schema) { + std::ostringstream str; + str << schema; + return str.str(); +} + } // namespace c10 diff --git a/aten/src/ATen/core/op_registration/base.h b/aten/src/ATen/core/op_registration/base.h index 60250a5..3fdf6e9 100644 --- a/aten/src/ATen/core/op_registration/base.h +++ b/aten/src/ATen/core/op_registration/base.h @@ -1,8 +1,7 @@ #pragma once /** - * This file sets up the basics for operator registration like the - * c10::RegisterOperators() class. + * This file sets up the basics for operator registration. * * You probably don't want to include this file directly but include * op_registration.h instead since that adds more functionality you'll @@ -15,45 +14,10 @@ namespace c10 { namespace detail { - // OperatorRegistrar in its constructor registers an operator in the dispatch - // table deregisters it in the destructor. The intent is that this class is - // constructed at static initialization time so that operators automatically - // get registered when a dlopen() occurs. - // You shouldn't call this directly; instead, use the RegisterOperators class. - class OperatorRegistrar final { - public: - explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator) - : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { - Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator)); - } - - OperatorRegistrar(OperatorRegistrar&& rhs) noexcept - : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) { - rhs.owns_registration_ = false; - } - - // not needed and would break RAII if defaulted. - OperatorRegistrar& operator=(OperatorRegistrar&& rhs) noexcept = delete; - OperatorRegistrar(const OperatorRegistrar& rhs) = delete; - OperatorRegistrar& operator=(const OperatorRegistrar& rhs) = delete; - - ~OperatorRegistrar() { - if (owns_registration_) { - Dispatcher::singleton().deregisterKernel(op_, dispatch_key_); - Dispatcher::singleton().deregisterSchema(op_); - } - } - - private: - const OperatorHandle op_; - const TensorTypeId dispatch_key_; - bool owns_registration_; - }; - // KernelRegistrationConfig accumulates all information from the config // parameters passed to a RegisterOperators::op() call into one object. struct KernelRegistrationConfig final { - TensorTypeId dispatch_key; + c10::optional dispatch_key = c10::nullopt; KernelFunction* kernel_func = nullptr; KernelCacheCreatorFunction cache_creator_func = nullptr; std::unique_ptr inferred_function_schema = nullptr; @@ -87,13 +51,6 @@ namespace detail { // apply all configParameters (void)std::initializer_list{(std::forward(configParameters).apply(&config), 0)...}; - // TODO Allow this for just registering the schema? - AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel"); - - // if kernel_func is set, so must be cache_creator_func, - // the API shouldn't allow anything else. - AT_ASSERT(static_cast(config.cache_creator_func)); - return config; } } diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index 5294dc1..cb43c55 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -7,4 +7,50 @@ RegisterOperators::~RegisterOperators() = default; RegisterOperators::RegisterOperators(RegisterOperators&&) noexcept = default; RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) = default; +// OperatorRegistrar in its constructor registers an operator in the dispatch +// table deregisters it in the destructor. +class RegisterOperators::OperatorRegistrar final { +public: + explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator) + : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) { + Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator)); + } + + OperatorRegistrar(OperatorRegistrar&& rhs) noexcept + : op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) { + rhs.owns_registration_ = false; + } + + // not needed and would break RAII if defaulted. + OperatorRegistrar& operator=(OperatorRegistrar&& rhs) noexcept = delete; + OperatorRegistrar(const OperatorRegistrar& rhs) = delete; + OperatorRegistrar& operator=(const OperatorRegistrar& rhs) = delete; + + ~OperatorRegistrar() { + if (owns_registration_) { + Dispatcher::singleton().deregisterKernel(op_, dispatch_key_); + Dispatcher::singleton().deregisterSchema(op_); + } + } + +private: + const OperatorHandle op_; + const TensorTypeId dispatch_key_; + bool owns_registration_; +}; + +void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) { + // TODO Allow this for registering the schema without a kernel? + AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel"); + + // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else. + AT_ASSERT(static_cast(config.cache_creator_func)); + + if (config.inferred_function_schema.get() != nullptr) { + assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); + } + + registrars_.emplace_back(std::move(schema), *config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func)); +} + } diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index ff0ec8c..7861eee 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -65,21 +65,18 @@ public: template guts::enable_if_t>...>::value, RegisterOperators> op(FunctionSchema schema, ConfigParameters&&... configParameters) && { - detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward(configParameters)...); - - if (config.inferred_function_schema.get() != nullptr) { - assertSchemasHaveSameSignature(*config.inferred_function_schema, schema); - } - - registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func)); + registerOp_(std::move(schema), detail::make_registration_config(std::forward(configParameters)...)); return std::move(*this); } - // TODO error if dispatch key is not specified // TODO Add deprecated function and lambda based kernel APIs private: - std::vector registrars_; + void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config); + + class OperatorRegistrar; + + std::vector registrars_; }; } -- 2.7.4