bool isWrite_ = false;
};
+inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
+ return lhs.isWrite() == rhs.isWrite()
+ && lhs.beforeSets() == rhs.beforeSets()
+ && lhs.afterSets() == rhs.afterSets()
+ && lhs.containedTypes() == rhs.containedTypes();
+}
+
// DEBUG ONLY; this does not match the way things are represented in the schema
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
out << "(";
#include <ATen/core/dispatch/Dispatcher.h>
+#include <sstream>
namespace c10 {
return _singleton;
}
+c10::optional<OperatorHandle> Dispatcher::findSchema(const char* operator_name, const char* overload_name) {
+ const auto found = std::find_if(operators_.begin(), operators_.end(), [&] (const OperatorDef& opDef) {
+ return opDef.schema.name() == operator_name && opDef.schema.overload_name() == overload_name;
+ });
+
+ if (found == operators_.end()) {
+ return c10::nullopt;
+ }
+
+ return OperatorHandle(found);
+}
+
+OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema) {
+ const auto found = findSchema(schema.name().c_str(), schema.overload_name().c_str());
+ if (found != c10::nullopt) {
+ if (found->schema() != schema) {
+ std::ostringstream str;
+ str << schema << " vs " << found->schema();
+ AT_ERROR("Tried to register multiple operators with the same name and the same overload name but different schemas: ", str.str());
+ }
+ return *found;
+ }
+
+ operators_.emplace_back(std::move(schema));
+ return OperatorHandle(--operators_.end());
+}
+
// Registers (or re-registers) an operator schema and bumps its refcount.
// Listeners are notified only on the 0 -> 1 transition, i.e. the first time
// this particular schema becomes visible.
OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
  // Guard against concurrent mutation of operators_.
  std::lock_guard<std::mutex> lock(mutex_);

  auto op = findOrRegisterSchema_(std::move(schema));

  const bool isFirstRegistration = (++op.operatorDefIterator_->refcount == 1);
  if (isFirstRegistration) {
    // Listeners run *after* the operator is added, so the dispatcher is
    // already in a valid state for the new op.
    listeners_->callOnOperatorRegistered(op);
  }
  return op;
}
// we need a lock to avoid concurrent writes
std::lock_guard<std::mutex> lock(mutex_);
- if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
- AT_ERROR("Tried to deregister op schema that still has kernels registered");
- }
+ // reduce refcount and actually deregister if no references left
+ AT_ASSERT(op.operatorDefIterator_->refcount > 0);
+ --op.operatorDefIterator_->refcount;
+ if (0 == op.operatorDefIterator_->refcount) {
+ if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
+ std::ostringstream str;
+ str << op.schema();
+ AT_ERROR("Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", str.str());
+ }
- // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
- listeners_->callOnOperatorDeregistered(op);
+ // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
+ listeners_->callOnOperatorDeregistered(op);
- operators_.erase(op.operatorDefIterator_);
+ operators_.erase(op.operatorDefIterator_);
+ }
}
void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
struct OperatorDef final {
explicit OperatorDef(FunctionSchema schema_)
: dispatchTable(schema_)
- , schema(std::move(schema_)) {}
+ , schema(std::move(schema_))
+ , refcount(0) {}
DispatchTable dispatchTable;
FunctionSchema schema;
+ size_t refcount;
};
friend class OperatorHandle;
/**
* Register a new operator schema. The handle returned can be used to register
* kernels to this operator or to call it.
+ *
+ * If a schema with the same operator name and overload name already exists,
+ * this function will check that both schemas are exactly identical and then
+ * return the existing schema.
+ *
+ * Each call to registerSchema() should have a corresponding call to
+ * deregisterSchema(), even if multiple calls register (or deregister)
+ * schemas with the same operator name and overload name.
*/
OperatorHandle registerSchema(FunctionSchema schema);
/**
* Remove an operator from the dispatcher. Make sure you removed
- * all kernels for this operatorbefore calling this.
+ * all kernels for this operator before calling this.
+ *
+ * If a schema was registered multiple times (see above how registerSchema()
+ * handles registering schemas that already exist), it must be deregistered
+ * the exact same number of times before it is actually deregistered.
+ * That is, each call to registerSchema() should have a corresponding call
+ * to deregisterSchema().
*/
void deregisterSchema(const OperatorHandle& op);
/**
+ * Looks for an operator schema with the given name and overload name
+ * and returns it if it is registered.
+ * Returns nullopt otherwise.
+ */
+ c10::optional<OperatorHandle> findSchema(const char* operator_name, const char* overload_name);
+
+ /**
* Register an operator to the dispatch table for an operator.
*/
void registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func);
private:
Dispatcher();
+ OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
+
std::list<OperatorDef> operators_;
std::unique_ptr<detail::RegistrationListenerList> listeners_;
std::mutex mutex_;
c10::optional<AliasInfo> alias_info_;
};
+namespace detail {
+inline bool defaultValueEquals_(const c10::optional<IValue>& lhs, const c10::optional<IValue>& rhs) {
+ if (lhs.has_value()) {
+ return rhs.has_value() && shallowEquals(*lhs, *rhs);
+ } else {
+ return !rhs.has_value();
+ }
+}
+}
+
+inline bool operator==(const Argument& lhs, const Argument& rhs) {
+ return lhs.name() == rhs.name()
+ && lhs.type() == rhs.type()
+ && lhs.N() == rhs.N()
+ && detail::defaultValueEquals_(lhs.default_value(), rhs.default_value())
+ && lhs.kwarg_only() == rhs.kwarg_only()
+ && lhs.alias_info() == rhs.alias_info();
+}
+
struct FunctionSchema {
FunctionSchema(
std::string name,
}
};
+inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+ return lhs.name() == rhs.name()
+ && lhs.overload_name() == rhs.overload_name()
+ && lhs.arguments() == rhs.arguments()
+ && lhs.returns() == rhs.returns()
+ && lhs.is_vararg() == rhs.is_vararg()
+ && lhs.is_varret() == rhs.is_varret();
+}
+
+inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
+ return !(lhs == rhs);
+}
+
// for debugging, make sure we can describe the call site
inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
return out << arg.type()->str() << " " << arg.name() << (arg.default_value() ? "=<default>" : "");
&& this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
}
}
+
+inline bool shallowEquals(const IValue& lhs, const IValue& rhs) {
+ if (lhs.isNone()) {
+ return rhs.isNone();
+ } else if (lhs.isInt()) {
+ return rhs.isInt() && lhs.toInt() == rhs.toInt();
+ } else if (lhs.isString()) {
+ return rhs.isString() && lhs.toStringRef() == rhs.toStringRef();
+ } else if (lhs.isDouble()) {
+ return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
+ } else if (lhs.isBool()) {
+ return rhs.isBool() && lhs.toBool() == rhs.toBool();
+ } else {
+ AT_ERROR("shallowEquals(IValue, IValue) not implemented for type ", lhs.tagKind());
+ }
+}
+
} // namespace c10
inline size_t at::ivalue::DictHash::operator()(
inline bool at::ivalue::DictEqualTo::operator()(
const c10::IValue& lhs,
const c10::IValue& rhs) const {
- if (lhs.isInt()) {
+ if (lhs.isNone()) {
+ return rhs.isNone();
+ } else if (lhs.isInt()) {
return lhs.toInt() == rhs.toInt();
} else if (lhs.isString()) {
return lhs.toStringRef() == rhs.toStringRef();
C10_DEFINE_OP_SCHEMA(
UniformFill,
FunctionSchema(
- "_c10_experimental::ConstantFill",
+ "_c10_experimental::UniformFill",
"",
(std::vector<c10::Argument>{
c10::Argument("inputs", ListType::ofTensors()),
C10_DEFINE_OP_SCHEMA(
GivenTensorFill,
FunctionSchema(
- "_c10_experimental::ConstantFill",
+ "_c10_experimental::GivenTensorFill",
"",
(std::vector<c10::Argument>{
c10::Argument("inputs", ListType::ofTensors()),
C10_DEFINE_OP_SCHEMA(
GivenTensorIntFill,
FunctionSchema(
- "_c10_experimental::ConstantFill",
+ "_c10_experimental::GivenTensorIntFill",
"",
(std::vector<c10::Argument>{
c10::Argument("inputs", ListType::ofTensors()),
C10_DEFINE_OP_SCHEMA(
GivenTensorInt64Fill,
FunctionSchema(
- "_c10_experimental::ConstantFill",
+ "_c10_experimental::GivenTensorInt64Fill",
"",
(std::vector<c10::Argument>{
c10::Argument("inputs", ListType::ofTensors()),