Allow registering same operator schema multiple times (#18038)
authorSebastian Messmer <messmer@fb.com>
Thu, 21 Mar 2019 21:51:38 +0000 (14:51 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Mar 2019 21:57:28 +0000 (14:57 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18038

Now that we have named overloads, we can allow registering the same function schema multiple times and just check it's identical.

This is going to be used in custom op registration since they register the schema every time a kernel is registered.

Reviewed By: dzhulgakov

Differential Revision: D14467494

fbshipit-source-id: 2c26cf72a64b65f120afe05e989302ec42597515

aten/src/ATen/core/alias_info.h
aten/src/ATen/core/dispatch/Dispatcher.cpp
aten/src/ATen/core/dispatch/Dispatcher.h
aten/src/ATen/core/function_schema.h
aten/src/ATen/core/ivalue.h
caffe2/operators/experimental/c10/schemas/filler.cc

index 61f482e..36704e4 100644 (file)
@@ -80,6 +80,13 @@ class AliasInfo {
   bool isWrite_ = false;
 };
 
+inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
+  return lhs.isWrite() == rhs.isWrite()
+      && lhs.beforeSets() == rhs.beforeSets()
+      && lhs.afterSets() == rhs.afterSets()
+      && lhs.containedTypes() == rhs.containedTypes();
+}
+
 // DEBUG ONLY; this does not match the way things are represented in the schema
 inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
   out << "(";
index da0df1a..85a148f 100644 (file)
@@ -1,4 +1,5 @@
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <sstream>
 
 namespace c10 {
 
@@ -39,15 +40,44 @@ C10_EXPORT Dispatcher& Dispatcher::singleton() {
   return _singleton;
 }
 
+c10::optional<OperatorHandle> Dispatcher::findSchema(const char* operator_name, const char* overload_name) {
+  const auto found = std::find_if(operators_.begin(), operators_.end(), [&] (const OperatorDef& opDef) {
+    return opDef.schema.name() == operator_name && opDef.schema.overload_name() == overload_name;
+  });
+
+  if (found == operators_.end()) {
+    return c10::nullopt;
+  }
+
+  return OperatorHandle(found);
+}
+
+OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema) {
+  const auto found = findSchema(schema.name().c_str(), schema.overload_name().c_str());
+  if (found != c10::nullopt) {
+    if (found->schema() != schema) {
+      std::ostringstream str;
+      str << schema << " vs " << found->schema();
+      AT_ERROR("Tried to register multiple operators with the same name and the same overload name but different schemas: ", str.str());
+    }
+    return *found;
+  }
+
+  operators_.emplace_back(std::move(schema));
+  return OperatorHandle(--operators_.end());
+}
+
 OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
   // we need a lock to avoid concurrent writes
   std::lock_guard<std::mutex> lock(mutex_);
 
-  operators_.emplace_back(std::move(schema));
-  auto op = OperatorHandle(--operators_.end());
+  auto op = findOrRegisterSchema_(std::move(schema));
 
-  // note: call listeners *after* operator is added, i.e. dispatcher is already valid for new op
-  listeners_->callOnOperatorRegistered(op);
+  ++op.operatorDefIterator_->refcount;
+  if (1 == op.operatorDefIterator_->refcount) {
+    // note: call listeners *after* operator is added, i.e. dispatcher is already valid for new op
+    listeners_->callOnOperatorRegistered(op);
+  }
 
   return op;
 }
@@ -56,14 +86,21 @@ void Dispatcher::deregisterSchema(const OperatorHandle& op) {
   // we need a lock to avoid concurrent writes
   std::lock_guard<std::mutex> lock(mutex_);
 
-  if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
-    AT_ERROR("Tried to deregister op schema that still has kernels registered");
-  }
+  // reduce refcount and actually deregister if no references left
+  AT_ASSERT(op.operatorDefIterator_->refcount > 0);
+  --op.operatorDefIterator_->refcount;
+  if (0 == op.operatorDefIterator_->refcount) {
+    if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
+      std::ostringstream str;
+      str << op.schema();
+      AT_ERROR("Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", str.str());
+    }
 
-  // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
-  listeners_->callOnOperatorDeregistered(op);
+    // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
+    listeners_->callOnOperatorDeregistered(op);
 
-  operators_.erase(op.operatorDefIterator_);
+    operators_.erase(op.operatorDefIterator_);
+  }
 }
 
 void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
index 6f8d13b..35b32a3 100644 (file)
@@ -72,10 +72,12 @@ private:
   struct OperatorDef final {
     explicit OperatorDef(FunctionSchema schema_)
     : dispatchTable(schema_)
-    , schema(std::move(schema_)) {}
+    , schema(std::move(schema_))
+    , refcount(0) {}
 
     DispatchTable dispatchTable;
     FunctionSchema schema;
+    size_t refcount;
   };
   friend class OperatorHandle;
 
@@ -91,16 +93,37 @@ public:
   /**
    * Register a new operator schema. The handle returned can be used to register
    * kernels to this operator or to call it.
+   *
+   * If a schema with the same operator name and overload name already exists,
+   * this function will check that both schemas are exactly identical and then
+   * return the existing schema.
+   *
+   * Each call to registerSchema() should have a corresponding call to
+   * deregisterSchema(), even if multiple calls register (or deregister)
+   * schemas with the same operator name and overload name.
    */
   OperatorHandle registerSchema(FunctionSchema schema);
 
   /**
    * Remove an operator from the dispatcher. Make sure you removed
-   * all kernels for this operatorbefore calling this.
+   * all kernels for this operator before calling this.
+   *
+   * If a schema was registered multiple times (see above how registerSchema()
+   * handles registering schemas that already exist), it must be deregistered
+   * the exact same number of times before it is actually deregistered.
+   * That is, each call to registerSchema() should have a corresponding call
+   * to deregisterSchema().
    */
   void deregisterSchema(const OperatorHandle& op);
 
   /**
+   * Looks for an operator schema with the given name and overload name
+   * and returns it if it is registered.
+   * Returns nullopt otherwise.
+   */
+  c10::optional<OperatorHandle> findSchema(const char* operator_name, const char* overload_name);
+
+  /**
    * Register an operator to the dispatch table for an operator.
    */
   void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func);
@@ -126,6 +149,8 @@ public:
 private:
   Dispatcher();
 
+  OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
+
   std::list<OperatorDef> operators_;
   std::unique_ptr<detail::RegistrationListenerList> listeners_;
   std::mutex mutex_;
index 1c08d54..5f80d12 100644 (file)
@@ -64,6 +64,25 @@ private:
   c10::optional<AliasInfo> alias_info_;
 };
 
+namespace detail {
+inline bool defaultValueEquals_(const c10::optional<IValue>& lhs, const c10::optional<IValue>& rhs) {
+  if (lhs.has_value()) {
+    return rhs.has_value() && shallowEquals(*lhs, *rhs);
+  } else {
+    return !rhs.has_value();
+  }
+}
+}
+
+inline bool operator==(const Argument& lhs, const Argument& rhs) {
+  return lhs.name() == rhs.name()
+          && lhs.type() == rhs.type()
+          && lhs.N() == rhs.N()
+          && detail::defaultValueEquals_(lhs.default_value(), rhs.default_value())
+          && lhs.kwarg_only() == rhs.kwarg_only()
+          && lhs.alias_info() == rhs.alias_info();
+}
+
 struct FunctionSchema {
   FunctionSchema(
       std::string name,
@@ -142,6 +161,19 @@ public:
   }
 };
 
+inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+  return lhs.name() == rhs.name()
+      && lhs.overload_name() == rhs.overload_name()
+      && lhs.arguments() == rhs.arguments()
+      && lhs.returns() == rhs.returns()
+      && lhs.is_vararg() == rhs.is_vararg()
+      && lhs.is_varret() == rhs.is_varret();
+}
+
+inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+  return !(lhs == rhs);
+}
+
 // for debugging, make sure we can describe the call site
 inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
   return out << arg.type()->str() << " " << arg.name() << (arg.default_value() ? "=<default>" : "");
index 3246512..487107e 100644 (file)
@@ -962,6 +962,23 @@ inline bool IValue::isSameIdentity(IValue& rhs) {
         && this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
   }
 }
+
+inline bool shallowEquals(const IValue& lhs, const IValue& rhs) {
+  if (lhs.isNone()) {
+    return rhs.isNone();
+  } else if (lhs.isInt()) {
+    return rhs.isInt() && lhs.toInt() == rhs.toInt();
+  } else if (lhs.isString()) {
+    return rhs.isString() && lhs.toStringRef() == rhs.toStringRef();
+  } else if (lhs.isDouble()) {
+    return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
+  } else if (lhs.isBool()) {
+    return rhs.isBool() && lhs.toBool() == rhs.toBool();
+  } else {
+    AT_ERROR("shallowEquals(IValue, IValue) not implemented for type ", lhs.tagKind());
+  }
+}
+
 } // namespace c10
 
 inline size_t at::ivalue::DictHash::operator()(
@@ -980,7 +997,9 @@ inline size_t at::ivalue::DictHash::operator()(
 inline bool at::ivalue::DictEqualTo::operator()(
     const c10::IValue& lhs,
     const c10::IValue& rhs) const {
-  if (lhs.isInt()) {
+  if (lhs.isNone()) {
+    return rhs.isNone();
+  } else if (lhs.isInt()) {
     return lhs.toInt() == rhs.toInt();
   } else if (lhs.isString()) {
     return lhs.toStringRef() == rhs.toStringRef();
index 8f38dd5..f4fdbbb 100644 (file)
@@ -28,7 +28,7 @@ C10_DEFINE_OP_SCHEMA(
 C10_DEFINE_OP_SCHEMA(
     UniformFill,
     FunctionSchema(
-        "_c10_experimental::ConstantFill",
+        "_c10_experimental::UniformFill",
         "",
         (std::vector<c10::Argument>{
             c10::Argument("inputs", ListType::ofTensors()),
@@ -42,7 +42,7 @@ C10_DEFINE_OP_SCHEMA(
 C10_DEFINE_OP_SCHEMA(
     GivenTensorFill,
     FunctionSchema(
-        "_c10_experimental::ConstantFill",
+        "_c10_experimental::GivenTensorFill",
         "",
         (std::vector<c10::Argument>{
             c10::Argument("inputs", ListType::ofTensors()),
@@ -56,7 +56,7 @@ C10_DEFINE_OP_SCHEMA(
 C10_DEFINE_OP_SCHEMA(
     GivenTensorIntFill,
     FunctionSchema(
-        "_c10_experimental::ConstantFill",
+        "_c10_experimental::GivenTensorIntFill",
         "",
         (std::vector<c10::Argument>{
             c10::Argument("inputs", ListType::ofTensors()),
@@ -70,7 +70,7 @@ C10_DEFINE_OP_SCHEMA(
 C10_DEFINE_OP_SCHEMA(
     GivenTensorInt64Fill,
     FunctionSchema(
-        "_c10_experimental::ConstantFill",
+        "_c10_experimental::GivenTensorInt64Fill",
         "",
         (std::vector<c10::Argument>{
             c10::Argument("inputs", ListType::ofTensors()),