From 1877087df25999a5c65c985c718836a0d01b29ae Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 21 Mar 2019 14:51:38 -0700 Subject: [PATCH] Allow registering same operator schema multiple times (#18038) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18038 Now that we have named overloads, we can allow registering the same function schema multiple times and just check it's identical. This is going to be used in custom op registration since they register the schema every time a kernel is registered. Reviewed By: dzhulgakov Differential Revision: D14467494 fbshipit-source-id: 2c26cf72a64b65f120afe05e989302ec42597515 --- aten/src/ATen/core/alias_info.h | 7 +++ aten/src/ATen/core/dispatch/Dispatcher.cpp | 57 ++++++++++++++++++---- aten/src/ATen/core/dispatch/Dispatcher.h | 29 ++++++++++- aten/src/ATen/core/function_schema.h | 32 ++++++++++++ aten/src/ATen/core/ivalue.h | 21 +++++++- .../operators/experimental/c10/schemas/filler.cc | 8 +-- 6 files changed, 137 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/core/alias_info.h b/aten/src/ATen/core/alias_info.h index 61f482e..36704e4 100644 --- a/aten/src/ATen/core/alias_info.h +++ b/aten/src/ATen/core/alias_info.h @@ -80,6 +80,13 @@ class AliasInfo { bool isWrite_ = false; }; +inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { + return lhs.isWrite() == rhs.isWrite() + && lhs.beforeSets() == rhs.beforeSets() + && lhs.afterSets() == rhs.afterSets() + && lhs.containedTypes() == rhs.containedTypes(); +} + // DEBUG ONLY; this does not match the way things are represented in the schema inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { out << "("; diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index da0df1a..85a148f 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -1,4 +1,5 @@ #include +#include namespace c10 { @@ -39,15 +40,44 @@ C10_EXPORT Dispatcher& Dispatcher::singleton() { return _singleton; } +c10::optional Dispatcher::findSchema(const char* operator_name, const char* overload_name) { + const auto found = std::find_if(operators_.begin(), operators_.end(), [&] (const OperatorDef& opDef) { + return opDef.schema.name() == operator_name && opDef.schema.overload_name() == overload_name; + }); + + if (found == operators_.end()) { + return c10::nullopt; + } + + return OperatorHandle(found); +} + +OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema) { + const auto found = findSchema(schema.name().c_str(), schema.overload_name().c_str()); + if (found != c10::nullopt) { + if (found->schema() != schema) { + std::ostringstream str; + str << schema << " vs " << found->schema(); + AT_ERROR("Tried to register multiple operators with the same name and the same overload name but different schemas: ", str.str()); + } + return *found; + } + + operators_.emplace_back(std::move(schema)); + return OperatorHandle(--operators_.end()); +} + OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) { // we need a lock to avoid concurrent writes std::lock_guard lock(mutex_); - operators_.emplace_back(std::move(schema)); - auto op = OperatorHandle(--operators_.end()); + auto op = findOrRegisterSchema_(std::move(schema)); - // note: call listeners *after* operator is added, i.e. dispatcher is already valid for new op - listeners_->callOnOperatorRegistered(op); + ++op.operatorDefIterator_->refcount; + if (1 == op.operatorDefIterator_->refcount) { + // note: call listeners *after* operator is added, i.e. dispatcher is already valid for new op + listeners_->callOnOperatorRegistered(op); + } return op; } @@ -56,14 +86,21 @@ void Dispatcher::deregisterSchema(const OperatorHandle& op) { // we need a lock to avoid concurrent writes std::lock_guard lock(mutex_); - if (!op.operatorDefIterator_->dispatchTable.isEmpty()) { - AT_ERROR("Tried to deregister op schema that still has kernels registered"); - } + // reduce refcount and actually deregister if no references left + AT_ASSERT(op.operatorDefIterator_->refcount > 0); + --op.operatorDefIterator_->refcount; + if (0 == op.operatorDefIterator_->refcount) { + if (!op.operatorDefIterator_->dispatchTable.isEmpty()) { + std::ostringstream str; + str << op.schema(); + AT_ERROR("Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", str.str()); + } - // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op - listeners_->callOnOperatorDeregistered(op); + // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op + listeners_->callOnOperatorDeregistered(op); - operators_.erase(op.operatorDefIterator_); + operators_.erase(op.operatorDefIterator_); + } } void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 6f8d13b..35b32a3 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -72,10 +72,12 @@ private: struct OperatorDef final { explicit OperatorDef(FunctionSchema schema_) : dispatchTable(schema_) - , schema(std::move(schema_)) {} + , schema(std::move(schema_)) + , refcount(0) {} DispatchTable dispatchTable; FunctionSchema schema; + size_t refcount; }; friend class OperatorHandle; @@ -91,16 +93,37 @@ public: /** * Register a new operator schema. The handle returned can be used to register * kernels to this operator or to call it. + * + * If a schema with the same operator name and overload name already exists, + * this function will check that both schemas are exactly identical and then + * return the existing schema. + * + * Each call to registerSchema() should have a corresponding call to + * deregisterSchema(), even if multiple calls register (or deregister) + * schemas with the same operator name and overload name. */ OperatorHandle registerSchema(FunctionSchema schema); /** * Remove an operator from the dispatcher. Make sure you removed - * all kernels for this operatorbefore calling this. + * all kernels for this operator before calling this. + * + * If a schema was registered multiple times (see above how registerSchema() + * handles registering schemas that already exist), it must be deregistered + * the exact same number of times before it is actually deregistered. + * That is, each call to registerSchema() should have a corresponding call + * to deregisterSchema(). */ void deregisterSchema(const OperatorHandle& op); /** + * Looks for an operator schema with the given name and overload name + * and returns it if it is registered. + * Returns nullopt otherwise. + */ + c10::optional findSchema(const char* operator_name, const char* overload_name); + + /** * Register an operator to the dispatch table for an operator. */ void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func); @@ -126,6 +149,8 @@ public: private: Dispatcher(); + OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema); + std::list operators_; std::unique_ptr listeners_; std::mutex mutex_; diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 1c08d54..5f80d12 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -64,6 +64,25 @@ private: c10::optional alias_info_; }; +namespace detail { +inline bool defaultValueEquals_(const c10::optional& lhs, const c10::optional& rhs) { + if (lhs.has_value()) { + return rhs.has_value() && shallowEquals(*lhs, *rhs); + } else { + return !rhs.has_value(); + } +} +} + +inline bool operator==(const Argument& lhs, const Argument& rhs) { + return lhs.name() == rhs.name() + && lhs.type() == rhs.type() + && lhs.N() == rhs.N() + && detail::defaultValueEquals_(lhs.default_value(), rhs.default_value()) + && lhs.kwarg_only() == rhs.kwarg_only() + && lhs.alias_info() == rhs.alias_info(); +} + struct FunctionSchema { FunctionSchema( std::string name, @@ -142,6 +161,19 @@ public: } }; +inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) { + return lhs.name() == rhs.name() + && lhs.overload_name() == rhs.overload_name() + && lhs.arguments() == rhs.arguments() + && lhs.returns() == rhs.returns() + && lhs.is_vararg() == rhs.is_vararg() + && lhs.is_varret() == rhs.is_varret(); +} + +inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) { + return !(lhs == rhs); +} + // for debugging, make sure we can describe the call site inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { return out << arg.type()->str() << " " << arg.name() << (arg.default_value() ? "=" : ""); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 3246512..487107e 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -962,6 +962,23 @@ inline bool IValue::isSameIdentity(IValue& rhs) { && this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; } } + +inline bool shallowEquals(const IValue& lhs, const IValue& rhs) { + if (lhs.isNone()) { + return rhs.isNone(); + } else if (lhs.isInt()) { + return rhs.isInt() && lhs.toInt() == rhs.toInt(); + } else if (lhs.isString()) { + return rhs.isString() && lhs.toStringRef() == rhs.toStringRef(); + } else if (lhs.isDouble()) { + return rhs.isDouble() && lhs.toDouble() == rhs.toDouble(); + } else if (lhs.isBool()) { + return rhs.isBool() && lhs.toBool() == rhs.toBool(); + } else { + AT_ERROR("shallowEquals(IValue, IValue) not implemented for type ", lhs.tagKind()); + } +} + } // namespace c10 inline size_t at::ivalue::DictHash::operator()( @@ -980,7 +997,9 @@ inline size_t at::ivalue::DictHash::operator()( inline bool at::ivalue::DictEqualTo::operator()( const c10::IValue& lhs, const c10::IValue& rhs) const { - if (lhs.isInt()) { + if (lhs.isNone()) { + return rhs.isNone(); + } else if (lhs.isInt()) { return lhs.toInt() == rhs.toInt(); } else if (lhs.isString()) { return lhs.toStringRef() == rhs.toStringRef(); diff --git a/caffe2/operators/experimental/c10/schemas/filler.cc b/caffe2/operators/experimental/c10/schemas/filler.cc index 8f38dd5..f4fdbbb 100644 --- a/caffe2/operators/experimental/c10/schemas/filler.cc +++ b/caffe2/operators/experimental/c10/schemas/filler.cc @@ -28,7 +28,7 @@ C10_DEFINE_OP_SCHEMA( C10_DEFINE_OP_SCHEMA( UniformFill, FunctionSchema( - "_c10_experimental::ConstantFill", + "_c10_experimental::UniformFill", "", (std::vector{ c10::Argument("inputs", ListType::ofTensors()), @@ -42,7 +42,7 @@ C10_DEFINE_OP_SCHEMA( C10_DEFINE_OP_SCHEMA( GivenTensorFill, FunctionSchema( - "_c10_experimental::ConstantFill", + "_c10_experimental::GivenTensorFill", "", (std::vector{ c10::Argument("inputs", ListType::ofTensors()), @@ -56,7 +56,7 @@ C10_DEFINE_OP_SCHEMA( C10_DEFINE_OP_SCHEMA( GivenTensorIntFill, FunctionSchema( - "_c10_experimental::ConstantFill", + "_c10_experimental::GivenTensorIntFill", "", (std::vector{ c10::Argument("inputs", ListType::ofTensors()), @@ -70,7 +70,7 @@ C10_DEFINE_OP_SCHEMA( C10_DEFINE_OP_SCHEMA( GivenTensorInt64Fill, FunctionSchema( - "_c10_experimental::ConstantFill", + "_c10_experimental::GivenTensorInt64Fill", "", (std::vector{ c10::Argument("inputs", ListType::ofTensors()), -- 2.7.4