Move stuff to cpp files (#18301)
authorSebastian Messmer <messmer@fb.com>
Sat, 30 Mar 2019 07:03:44 +0000 (00:03 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 07:07:25 +0000 (00:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18301

Move code out of headers and templates into source files and non-templates.

Reviewed By: dzhulgakov

Differential Revision: D14572347

fbshipit-source-id: 9fd5d62d54000a95e93076cd73f591ba2c5c2653

aten/src/ATen/core/function_schema.h
aten/src/ATen/core/op_registration/base.h
aten/src/ATen/core/op_registration/op_registration.cpp
aten/src/ATen/core/op_registration/op_registration.h

index 5f80d12..da1ec34 100644 (file)
@@ -216,4 +216,10 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema)
   return out;
 }
 
+inline std::string toString(const FunctionSchema& schema) {
+  std::ostringstream str;
+  str << schema;
+  return str.str();
+}
+
 } // namespace c10
index 60250a5..3fdf6e9 100644 (file)
@@ -1,8 +1,7 @@
 #pragma once
 
 /**
- * This file sets up the basics for operator registration like the
- * c10::RegisterOperators() class.
+ * This file sets up the basics for operator registration.
  *
  * You probably don't want to include this file directly but include
  * op_registration.h instead since that adds more functionality you'll
@@ -15,45 +14,10 @@ namespace c10 {
 
 namespace detail {
 
-  // OperatorRegistrar in its constructor registers an operator in the dispatch
-  // table deregisters it in the destructor. The intent is that this class is
-  // constructed at static initialization time so that operators automatically
-  // get registered when a dlopen() occurs.
-  // You shouldn't call this directly; instead, use the RegisterOperators class.
-  class OperatorRegistrar final {
-  public:
-    explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
-    : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
-      Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator));
-    }
-
-    OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
-    :  op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) {
-      rhs.owns_registration_ = false;
-    }
-
-    // not needed and would break RAII if defaulted.
-    OperatorRegistrar& operator=(OperatorRegistrar&& rhs) noexcept = delete;
-    OperatorRegistrar(const OperatorRegistrar& rhs) = delete;
-    OperatorRegistrar& operator=(const OperatorRegistrar& rhs) = delete;
-
-    ~OperatorRegistrar() {
-      if (owns_registration_) {
-        Dispatcher::singleton().deregisterKernel(op_, dispatch_key_);
-        Dispatcher::singleton().deregisterSchema(op_);
-      }
-    }
-
-  private:
-    const OperatorHandle op_;
-    const TensorTypeId dispatch_key_;
-    bool owns_registration_;
-  };
-
   // KernelRegistrationConfig accumulates all information from the config
   // parameters passed to a RegisterOperators::op() call into one object.
   struct KernelRegistrationConfig final {
-    TensorTypeId dispatch_key;
+    c10::optional<TensorTypeId> dispatch_key = c10::nullopt;
     KernelFunction* kernel_func = nullptr;
     KernelCacheCreatorFunction cache_creator_func = nullptr;
     std::unique_ptr<FunctionSchema> inferred_function_schema = nullptr;
@@ -87,13 +51,6 @@ namespace detail {
     // apply all configParameters
     (void)std::initializer_list<int>{(std::forward<ConfigParameters>(configParameters).apply(&config), 0)...};
 
-    // TODO Allow this for just registering the schema?
-    AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel");
-
-    // if kernel_func is set, so must be cache_creator_func,
-    // the API shouldn't allow anything else.
-    AT_ASSERT(static_cast<bool>(config.cache_creator_func));
-
     return config;
   }
 }
index 5294dc1..cb43c55 100644 (file)
@@ -7,4 +7,50 @@ RegisterOperators::~RegisterOperators() = default;
 RegisterOperators::RegisterOperators(RegisterOperators&&) noexcept = default;
 RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) = default;
 
+// OperatorRegistrar in its constructor registers an operator in the dispatch
+// table deregisters it in the destructor.
+class RegisterOperators::OperatorRegistrar final {
+public:
+  explicit OperatorRegistrar(FunctionSchema&& schema, TensorTypeId dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator)
+  : op_(Dispatcher::singleton().registerSchema(std::move(schema))), dispatch_key_(std::move(dispatch_key)), owns_registration_(true) {
+    Dispatcher::singleton().registerKernel(op_, dispatch_key_, kernel, std::move(cache_creator));
+  }
+
+  OperatorRegistrar(OperatorRegistrar&& rhs) noexcept
+  :  op_(std::move(rhs.op_)), dispatch_key_(std::move(rhs.dispatch_key_)), owns_registration_(rhs.owns_registration_) {
+    rhs.owns_registration_ = false;
+  }
+
+  // not needed and would break RAII if defaulted.
+  OperatorRegistrar& operator=(OperatorRegistrar&& rhs) noexcept = delete;
+  OperatorRegistrar(const OperatorRegistrar& rhs) = delete;
+  OperatorRegistrar& operator=(const OperatorRegistrar& rhs) = delete;
+
+  ~OperatorRegistrar() {
+    if (owns_registration_) {
+      Dispatcher::singleton().deregisterKernel(op_, dispatch_key_);
+      Dispatcher::singleton().deregisterSchema(op_);
+    }
+  }
+
+private:
+  const OperatorHandle op_;
+  const TensorTypeId dispatch_key_;
+  bool owns_registration_;
+};
+
+void RegisterOperators::registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config) {
+  // TODO Allow this for registering the schema without a kernel?
+  AT_CHECK(config.kernel_func != nullptr, "Cannot register operator without kernel");
+
+  // if kernel_func is set, so must be cache_creator_func, the API shouldn't allow anything else.
+  AT_ASSERT(static_cast<bool>(config.cache_creator_func));
+
+  if (config.inferred_function_schema.get() != nullptr) {
+    assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
+  }
+
+  registrars_.emplace_back(std::move(schema), *config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
+}
+
 }
index ff0ec8c..7861eee 100644 (file)
@@ -65,21 +65,18 @@ public:
   template<class... ConfigParameters>
   guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
   op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
-    detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
-
-    if (config.inferred_function_schema.get() != nullptr) {
-      assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
-    }
-
-    registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
+    registerOp_(std::move(schema), detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...));
     return std::move(*this);
   }
 
-  // TODO error if dispatch key is not specified
   // TODO Add deprecated function and lambda based kernel APIs
 
 private:
-  std::vector<detail::OperatorRegistrar> registrars_;
+  void registerOp_(FunctionSchema&& schema, detail::KernelRegistrationConfig&& config);
+
+  class OperatorRegistrar;
+
+  std::vector<OperatorRegistrar> registrars_;
 };
 
 }