.. code:: c
Expr VisitExpr_(const CallNode* call) final {
- static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+ static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
+#include <tvm/node/attr_registry_map.h>
#include <tvm/runtime/registry.h>
#include <string>
namespace tvm {
// forward declare name.
-template <typename ValueType>
-class OpMap;
-class GenericOpMap;
-class OpRegistry;
+template <typename>
+class OpAttrMap;
// TODO(tvm-team): migrate low-level intrinsics to use Op
/*!
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
private:
+ /*! \return the internal attr registry index. */
+ uint32_t AttrRegistryIndex() const { return index_; }
+ /*! \brief repr to be printed in registry*/
+ std::string AttrRegistryName() const { return name; }
+
// friend class
- friend class GenericOpMap;
- friend class OpRegistry;
+ template <typename>
+ friend class AttrRegistryMapContainerMap;
+ template <typename, typename>
+ friend class AttrRegistry;
+ friend class OpRegEntry;
+
friend bool IsPrimitiveOp(const RelayExpr&);
// Program internal unique index of operator.
// Used to help index the program.
inline const OpNode* operator->() const;
/*!
* \brief Get additional registered attribute about operators.
- * If nothing has been registered, an empty OpMap will be returned.
+ * If nothing has been registered, an empty OpAttrMap will be returned.
* \param attr_name The name of the attribute.
- * \return An OpMap of specified attr_name.
+ * \return An OpAttrMap of specified attr_name.
* \tparam ValueType The type of the attribute.
*/
template <typename ValueType>
- inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
+ inline static OpAttrMap<ValueType> GetAttrMap(const std::string& attr_name);
/*!
- * \brief Checks if an attr is present in the registry.
+ * \brief Checks if an attr map is present in the registry.
* \param attr_name The name of the attribute.
* \return bool True if the attr is present.
*/
- inline static bool HasAttr(const std::string& attr_name);
+ TVM_DLL static bool HasAttrMap(const String& attr_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
/*!
* \brief Get generic attrmap given attr name
* \param key The attribute key
- * \return reference to GenericOpMap
+ * \return The attr map.
*/
- TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key);
- /*!
- * \brief Checks if the key is present in the registry
- * \param key The attribute key
- * \return bool True if the key is present
- */
- TVM_DLL static bool HasGenericAttr(const String& key);
+ TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
};
/*!
* \brief Helper structure to register operators
* \sa TVM_REGISTER_OP
*/
-class OpRegistry {
+class OpRegEntry {
public:
/*! \return the operator */
const Op& op() const { return op_; }
* \param descr the description string.
* \return reference to self.
*/
- inline OpRegistry& describe(const std::string& descr); // NOLINT(*)
+ inline OpRegEntry& describe(const std::string& descr); // NOLINT(*)
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
- inline OpRegistry& add_argument(const std::string& name, const std::string& type,
+ inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
const std::string& description);
/*!
* \brief Attach the type function corresponding to the return type.
* relation on variables.
* \return reference to self.
*/
- inline OpRegistry& add_type_rel(
+ inline OpRegEntry& add_type_rel(
const std::string& rel_name,
runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
type_rel_func);
* \return reference to self.
*/
template <typename AttrsType>
- inline OpRegistry& set_attrs_type();
+ inline OpRegEntry& set_attrs_type();
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \return reference to self.
*/
- inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*)
+ inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*)
/*!
* \brief Set the support level of op.
* \param level The support level.
* \return reference to self.
*/
- inline OpRegistry& set_support_level(int32_t level); // NOLINT(*)
+ inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \tparam ValueType The type of the value to be set.
*/
template <typename ValueType>
- inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
+ inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);
/*!
inline void reset_attr(const std::string& attr_name);
// set the name of the op to be the same as registry
- inline OpRegistry& set_name() { // NOLINT(*)
+ inline OpRegEntry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
get()->name = name;
}
return *this;
}
- /*! \return The global single registry */
- TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
+ /*!
+ * \brief Register or get a new entry.
+ * \param name The name of the operator.
+ * \return the corresponding entry.
+ */
+ TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
private:
- friend class ::dmlc::Registry<OpRegistry>;
+ template <typename, typename>
+ friend class AttrRegistry;
// the name
std::string name;
/*! \brief The operator */
Op op_;
// private constructor
- TVM_DLL OpRegistry();
+ TVM_DLL OpRegEntry(uint32_t reg_index);
// return internal pointer to op.
inline OpNode* get();
- // update the attribute OpMap
-
+ // update the attribute OpAttrMap
TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};
/*!
- * \brief Generic map to store additional information of Op.
- */
-class GenericOpMap {
- public:
- /*!
- * \brief Check if the map has op as key.
- * \param op The key to the map
- * \return 1 if op is contained in map, 0 otherwise.
- */
- inline int count(const Op& op) const;
- /*!
- * \brief get the corresponding value element at op
- * \param op The key to the map
- * \return the const reference to the content value.
- */
- inline const runtime::TVMRetValue& operator[](const Op& op) const;
- /*!
- * \brief get the corresponding value element at op with default value.
- * \param op The key to the map
- * \param def_value The default value when the key does not exist.
- * \return the const reference to the content value.
- * \tparam ValueType The content value type.
- */
- template <typename ValueType>
- inline ValueType get(const Op& op, ValueType def_value) const;
- /*!
- * \brief get the corresponding value element at op with default value.
- * \param expr The key to the map
- * \param def_value The default value when the key does not exist
- * or if expr is not an Op.
- * \return the const reference to the content value.
- * \tparam ValueType The content value type.
- */
- template <typename ValueType>
- inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
-
- private:
- friend class OpRegistry;
- // the attribute field.
- std::string attr_name_;
- // internal data
- std::vector<std::pair<runtime::TVMRetValue, int> > data_;
- // The value
- GenericOpMap() = default;
-};
-
-/*!
* \brief Map<Op,ValueType> used to store meta-information about Op.
* \tparam ValueType The type of the value stored in map.
*/
template <typename ValueType>
-class OpMap {
+class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
public:
/*!
- * \brief Check if the map has op as key.
- * \param op The key to the map
- * \return 1 if op is contained in map, 0 otherwise.
- */
- inline int count(const Op& op) const;
- /*!
- * \brief get the corresponding value element at op
- * \param op The key to the map
- * \return the const reference to the content value.
- */
- inline ValueType operator[](const Op& op) const;
- /*!
- * \brief get the corresponding value element at op with default value.
- * \param op The key to the map
- * \param def_value The default value when the key does not exist.
- * \return the const reference to the content value.
- */
- inline ValueType get(const Op& op, ValueType def_value) const;
- /*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
*/
inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
+ using TParent = AttrRegistryMap<Op, ValueType>;
+ using TParent::count;
+ using TParent::get;
+ using TParent::operator[];
+
private:
friend class Op;
// constructor
- explicit OpMap(const GenericOpMap& map) : map_(map) {}
- /*! \brief The internal map field */
- const GenericOpMap& map_;
+ explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
};
#define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)
// internal macros to make
-#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op
+#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
/*!
* \def TVM_REGISTER_OP
*/
#define TVM_REGISTER_OP(OpName) \
TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
- ::tvm::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name()
+ ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
// implementations
inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); }
template <typename ValueType>
-inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
- return OpMap<ValueType>(Op::GetGenericAttr(key));
+inline OpAttrMap<ValueType> Op::GetAttrMap(const std::string& key) {
+ return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
}
-inline bool Op::HasAttr(const std::string& key) { return Op::HasGenericAttr(key); }
-
-inline OpNode* OpRegistry::get() { return const_cast<OpNode*>(op_.operator->()); }
+inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
-inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*)
+inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*)
get()->description = descr;
return *this;
}
-inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type,
+inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
const std::string& description) {
auto n = make_object<AttrFieldInfoNode>();
n->name = name;
return *this;
}
-inline OpRegistry& OpRegistry::add_type_rel(
+inline OpRegEntry& OpRegEntry::add_type_rel(
const std::string& rel_name,
runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
type_rel_func) {
return *this;
}
-inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*)
get()->num_inputs = n;
return *this;
}
template <typename AttrsType>
-inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
get()->attrs_type_key = AttrsType::_type_key;
get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
return *this;
}
-inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
get()->support_level = n;
return *this;
}
template <typename ValueType>
-inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value, int plevel) {
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
runtime::TVMRetValue rv;
return *this;
}
-// member functions of OpMap
-inline int GenericOpMap::count(const Op& op) const {
- if (op.defined()) {
- const uint32_t idx = op->index_;
- return idx < data_.size() ? (data_[idx].second != 0) : 0;
- } else {
- return 0;
- }
-}
-
-inline const runtime::TVMRetValue& GenericOpMap::operator[](const Op& op) const {
- CHECK(op.defined());
- const uint32_t idx = op->index_;
- CHECK(idx < data_.size() && data_[idx].second != 0)
- << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name;
- return data_[idx].first;
-}
-
-template <typename ValueType>
-inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
- CHECK(op.defined());
- const uint32_t idx = op->index_;
- if (idx < data_.size() && data_[idx].second != 0) {
- return data_[idx].first;
- } else {
- return value;
- }
-}
+// member functions of OpAttrMap
template <typename ValueType>
-inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const {
+inline ValueType OpAttrMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
CHECK(expr.defined());
if (const OpNode* op = expr.as<OpNode>()) {
- const uint32_t idx = op->index_;
- if (idx < data_.size() && data_[idx].second != 0) {
- return data_[idx].first;
- } else {
- return value;
- }
+ return this->map_.get(GetRef<Op>(op), def_value);
} else {
- return value;
+ return def_value;
}
}
-template <typename ValueType>
-inline int OpMap<ValueType>::count(const Op& op) const {
- return map_.count(op);
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
- return map_[op];
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const Op& op, ValueType def_value) const {
- return map_.get<ValueType>(op, def_value);
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
- return map_.get<ValueType>(expr, def_value);
-}
-
/*!
* \brief Check that an expression is a "primitive operator".
*
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/attr_registry_map.h
+ * \brief Attribute map used in registry.
+ */
+#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_
+#define TVM_NODE_ATTR_REGISTRY_MAP_H_
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+
+/*!
+ * \brief Generic attribute map.
+ * \tparam KeyType the type of the key.
+ */
+template <typename KeyType>
+class AttrRegistryMapContainerMap {
+ public:
+ /*!
+ * \brief Check if the map has key.
+ * \param key The key to the map
+ * \return 1 if key is contained in map, 0 otherwise.
+ */
+ int count(const KeyType& key) const {
+ if (key.defined()) {
+ const uint32_t idx = key->AttrRegistryIndex();
+ return idx < data_.size() ? (data_[idx].second != 0) : 0;
+ } else {
+ return 0;
+ }
+ }
+ /*!
+ * \brief get the corresponding value element at key.
+ * \param key The key to the map
+ * \return the const reference to the content value.
+ */
+ const runtime::TVMRetValue& operator[](const KeyType& key) const {
+ CHECK(key.defined());
+ const uint32_t idx = key->AttrRegistryIndex();
+ CHECK(idx < data_.size() && data_[idx].second != 0)
+ << "Attribute " << attr_name_ << " has not been registered for " << key->name;
+ return data_[idx].first;
+ }
+ /*!
+ * \brief get the corresponding value element at key with default value.
+ * \param key The key to the map
+ * \param def_value The default value when the key does not exist.
+ * \return the const reference to the content value.
+ * \tparam ValueType The content value type.
+ */
+ template <typename ValueType>
+ ValueType get(const KeyType& key, ValueType def_value) const {
+ CHECK(key.defined());
+ const uint32_t idx = key->AttrRegistryIndex();
+ if (idx < data_.size() && data_[idx].second != 0) {
+ return data_[idx].first;
+ } else {
+ return def_value;
+ }
+ }
+
+ private:
+ /*! \brief The name of the attr field */
+ String attr_name_;
+ /*! \brief The internal data. */
+ std::vector<std::pair<runtime::TVMRetValue, int>> data_;
+ /*! \brief The constructor */
+ AttrRegistryMapContainerMap() = default;
+ template <typename, typename>
+ friend class AttrRegistry;
+ friend class OpRegEntry;
+};
+
+/*!
+ * \brief Map<Key, ValueType> used to store meta-data.
+ * \tparam KeyType The type of the key
+ * \tparam ValueType The type of the value stored in map.
+ */
+template <typename KeyType, typename ValueType>
+class AttrRegistryMap {
+ public:
+ /*!
+ * \brief constructor
+ * \param map The internal map.
+ */
+ explicit AttrRegistryMap(const AttrRegistryMapContainerMap<KeyType>& map) : map_(map) {}
+ /*!
+ * \brief Check if the map has op as key.
+ * \param key The key to the map
+ * \return 1 if op is contained in map, 0 otherwise.
+ */
+ int count(const KeyType& key) const { return map_.count(key); }
+ /*!
+ * \brief get the corresponding value element at key.
+ * \param key The key to the map
+ * \return the const reference to the content value.
+ */
+ ValueType operator[](const KeyType& key) const { return map_[key]; }
+ /*!
+ * \brief get the corresponding value element at key with default value.
+ * \param key The key to the map
+ * \param def_value The default value when the key does not exist.
+ * \return the const reference to the content value.
+ */
+ ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); }
+
+ protected:
+ /*! \brief The internal map field */
+ const AttrRegistryMapContainerMap<KeyType>& map_;
+};
+
+} // namespace tvm
+#endif // TVM_NODE_ATTR_REGISTRY_MAP_H_
#include <tvm/runtime/packed_func.h>
#include <memory>
-#include <mutex>
-namespace dmlc {
-// enable registry
-DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
-} // namespace dmlc
+#include "../node/attr_registry.h"
namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
-::dmlc::Registry<OpRegistry>* OpRegistry::Registry() { return ::dmlc::Registry<OpRegistry>::Get(); }
-
-// single manager of operator information.
-struct OpManager {
- // mutex to avoid registration from multiple threads.
- std::mutex mutex;
- // global operator counter
- std::atomic<int> op_counter{0};
- // storage of additional attribute table.
- std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
- // frontend functions
- std::vector<PackedFunc*> frontend_funcs;
- // get singleton of the op manager
- static OpManager* Global() {
- static OpManager* inst = new OpManager();
- return inst;
- }
-};
+using OpRegistry = AttrRegistry<OpRegEntry, Op>;
// find operator by name
const Op& Op::Get(const String& name) {
- const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
+ const OpRegEntry* reg = OpRegistry::Global()->Get(name);
CHECK(reg != nullptr) << "Operator " << name << " is not registered";
return reg->op();
}
-OpRegistry::OpRegistry() {
- OpManager* mgr = OpManager::Global();
+OpRegEntry::OpRegEntry(uint32_t reg_index) {
ObjectPtr<OpNode> n = make_object<OpNode>();
- n->index_ = mgr->op_counter++;
+ n->index_ = reg_index;
op_ = Op(n);
}
+OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
+ return OpRegistry::Global()->RegisterOrGet(name);
+}
+
// Get attribute map by key
-const GenericOpMap& Op::GetGenericAttr(const String& key) {
- OpManager* mgr = OpManager::Global();
- std::lock_guard<std::mutex> lock(mgr->mutex);
- auto it = mgr->attr.find(key);
- if (it == mgr->attr.end()) {
- LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
- }
- return *it->second.get();
+const AttrRegistryMapContainerMap<Op>& Op::GetAttrMapContainer(const String& attr_name) {
+ return OpRegistry::Global()->GetAttrMap(attr_name);
}
// Check if a key is present in the registry.
-bool Op::HasGenericAttr(const String& key) {
- OpManager* mgr = OpManager::Global();
- std::lock_guard<std::mutex> lock(mgr->mutex);
- auto it = mgr->attr.find(key);
- if (it == mgr->attr.end()) {
- return false;
- }
- return true;
-}
+bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); }
-// Resets attr of the OpMap.
-void OpRegistry::reset_attr(const std::string& key) {
- OpManager* mgr = OpManager::Global();
- std::lock_guard<std::mutex> lock(mgr->mutex);
- std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
- if (op_map == nullptr) {
- return;
- }
- uint32_t index = op_->index_;
- if (op_map->data_.size() > index) {
- op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
- }
+// Resets attr of the OpAttrMap.
+void OpRegEntry::reset_attr(const std::string& attr_name) {
+ OpRegistry::Global()->ResetAttr(attr_name, op_);
}
-void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
- OpManager* mgr = OpManager::Global();
- std::lock_guard<std::mutex> lock(mgr->mutex);
- std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
- if (op_map == nullptr) {
- op_map.reset(new GenericOpMap());
- op_map->attr_name_ = key;
- }
- uint32_t index = op_->index_;
- if (op_map->data_.size() <= index) {
- op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
- }
- std::pair<TVMRetValue, int>& p = op_map->data_[index];
- CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name
- << " is already registered with same plevel=" << plevel;
- CHECK(value.type_code() != kTVMNullptr)
- << "Registered packed_func is Null for " << key << " of operator " << this->name;
- if (p.second < plevel && value.type_code() != kTVMNullptr) {
- op_map->data_[index] = std::make_pair(value, plevel);
- }
+void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
+ OpRegistry::Global()->UpdateAttr(key, op_, value, plevel);
}
// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
- Array<runtime::String> ret;
- for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
- ret.push_back(name);
- }
- return ret;
+ return OpRegistry::Global()->ListAllNames();
});
TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
- auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+ auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
if (op_map.count(op)) {
*rv = op_map[op];
}
std::string attr_name = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
- auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
reg.set_attr(attr_name, value, plevel);
});
TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
Op op = args[0];
std::string attr_name = args[1];
- auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op->name);
reg.reset_attr(attr_name);
});
std::string attr_key = args[1];
runtime::TVMArgValue value = args[2];
int plevel = args[3];
- auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
// enable resgiteration and override of certain properties
if (attr_key == "num_inputs" && plevel > 128) {
reg.set_num_inputs(value);
// do an eager copy of the PackedFunc
PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it.
- OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
- reg.set_attr(attr_key, f, plevel);
+ auto* fcopy = new PackedFunc(f);
+ reg.set_attr(attr_key, *fcopy, plevel);
} else {
reg.set_attr(attr_key, args[2], plevel);
}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/node/attr_registry.h
+ * \brief Common global registry for objects that also have additional attrs.
+ */
+#ifndef TVM_NODE_ATTR_REGISTRY_H_
+#define TVM_NODE_ATTR_REGISTRY_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+
+/*!
+ * \breif Implementation of registry with attributes.
+ *
+ * \tparam EntryType Tye type of the registry entry.
+ * \tparam KeyType The actual key that is used to lookup the attributes.
+ * each entry has a corresponding key by default.
+ */
+template <typename EntryType, typename KeyType>
+class AttrRegistry {
+ public:
+ using TSelf = AttrRegistry<EntryType, KeyType>;
+ /*!
+ * \brief Get an entry from the registry.
+ * \param name The name of the item.
+ * \return The corresponding entry.
+ */
+ const EntryType* Get(const String& name) const {
+ auto it = entry_map_.find(name);
+ if (it != entry_map_.end()) return it->second;
+ return nullptr;
+ }
+
+ /*!
+ * \brief Get an entry or register a new one.
+ * \param name The name of the item.
+ * \return The corresponding entry.
+ */
+ EntryType& RegisterOrGet(const String& name) {
+ auto it = entry_map_.find(name);
+ if (it != entry_map_.end()) return *it->second;
+ uint32_t registry_index = static_cast<uint32_t>(entries_.size());
+ auto entry = std::unique_ptr<EntryType>(new EntryType(registry_index));
+ auto* eptr = entry.get();
+ eptr->name = name;
+ entry_map_[name] = eptr;
+ entries_.emplace_back(std::move(entry));
+ return *eptr;
+ }
+
+ /*!
+ * \brief List all the entry names in the registry.
+ * \return The entry names.
+ */
+ Array<String> ListAllNames() const {
+ Array<String> names;
+ for (const auto& kv : entry_map_) {
+ names.push_back(kv.first);
+ }
+ return names;
+ }
+
+ /*!
+ * \brief Update the attribute stable.
+ * \param attr_name The name of the attribute.
+ * \param key The key to the attribute table.
+ * \param value The value to be set.
+ * \param plevel The support level.
+ */
+ void UpdateAttr(const String& attr_name, const KeyType& key, runtime::TVMRetValue value,
+ int plevel) {
+ using runtime::TVMRetValue;
+ std::lock_guard<std::mutex> lock(mutex_);
+ auto& op_map = attrs_[attr_name];
+ if (op_map == nullptr) {
+ op_map.reset(new AttrRegistryMapContainerMap<KeyType>());
+ op_map->attr_name_ = attr_name;
+ }
+
+ uint32_t index = key->AttrRegistryIndex();
+ if (op_map->data_.size() <= index) {
+ op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
+ }
+ std::pair<TVMRetValue, int>& p = op_map->data_[index];
+ CHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName()
+ << " is already registered with same plevel=" << plevel;
+ CHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name
+ << " of operator " << key->AttrRegistryName();
+ if (p.second < plevel && value.type_code() != kTVMNullptr) {
+ op_map->data_[index] = std::make_pair(value, plevel);
+ }
+ }
+
+ /*!
+ * \brief Reset an attribute table entry.
+ * \param attr_name The name of the attribute.
+ * \param key The key to the attribute table.
+ */
+ void ResetAttr(const String& attr_name, const KeyType& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ auto& op_map = attrs_[attr_name];
+ if (op_map == nullptr) {
+ return;
+ }
+ uint32_t index = key->AttrRegistryIndex();
+ if (op_map->data_.size() > index) {
+ op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
+ }
+ }
+
+ /*!
+ * \brief Get an internal attribute map.
+ * \param attr_name The name of the attribute.
+ * \return The result attribute map.
+ */
+ const AttrRegistryMapContainerMap<KeyType>& GetAttrMap(const String& attr_name) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ auto it = attrs_.find(attr_name);
+ if (it == attrs_.end()) {
+ LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered";
+ }
+ return *it->second.get();
+ }
+
+ /*!
+ * \brief Check of attribute has been registered.
+ * \param attr_name The name of the attribute.
+ * \return The check result.
+ */
+ bool HasAttrMap(const String& attr_name) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return attrs_.count(attr_name);
+ }
+
+ /*!
+ * \return a global singleton of the registry.
+ */
+ static TSelf* Global() {
+ static TSelf* inst = new TSelf();
+ return inst;
+ }
+
+ private:
+ // mutex to avoid registration from multiple threads.
+ std::mutex mutex_;
+ // entries in the registry
+ std::vector<std::unique_ptr<EntryType>> entries_;
+ // map from name to entries.
+ std::unordered_map<String, EntryType*> entry_map_;
+ // storage of additional attribute table.
+ std::unordered_map<String, std::unique_ptr<AttrRegistryMapContainerMap<KeyType>>> attrs_;
+};
+
+} // namespace tvm
+#endif // TVM_NODE_ATTR_REGISTRY_H_
private:
void VisitExpr_(const CallNode* call_node) final {
- static const auto& fprep = Op::GetAttr<FMacCount>("FMacCount");
+ static const auto& fprep = Op::GetAttrMap<FMacCount>("FMacCount");
auto f = fprep.get(call_node->op, nullptr);
if (f != nullptr) count_ += f(GetRef<Call>(call_node));
ExprVisitor::VisitExpr_(call_node);
TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
bool IsDataDependant(const CallNode* call) {
- static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+ static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
Op op = Downcast<Op>(call->op);
if (!tshape_data_dependant.count(op)) {
}
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
- static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
CHECK(flower_call) << "relay.backend.lower_call is not registered.";
}
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
- static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
- static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+ static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
+ static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
CHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
CHECK(data_dependants_.empty() || !data_dependants_.back())
for (auto kv : attributes) {
auto attr_name = kv.first;
auto attr_value = kv.second;
- auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+ auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
if (op_map.count(op)) {
switch (op_map[op].type_code()) {
case kDLInt:
* \return The new Call after calling the packed func.
*/
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
- static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
+ static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
Op op = Downcast<Op>(ref_call->op);
Expr new_e;
Op op = Downcast<Op>(pre->op);
CHECK(op.defined());
for (const auto& target : this->targets_) {
- if (!Op::HasAttr("target." + std::string(target))) {
+ if (!Op::HasAttrMap("target." + std::string(target))) {
continue;
}
- auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
+ auto fannotate = Op::GetAttrMap<FTVMAnnotateTarget>("target." + std::string(target));
if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) {
supported_targets.push_back(target);
}
CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
Expr VisitExpr_(const CallNode* call) {
- static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
if (const OpNode* opnode = call->op.as<OpNode>()) {
auto pattern = fpattern[GetRef<Op>(opnode)];
// Create a branch starting from op.
Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
- auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+ auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always op
Branch branch{op};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
* \return The new Call after calling the packed func.
*/
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
- static auto fconvert_layout = Op::GetAttr<FTVMConvertOpLayout>("FTVMConvertOpLayout");
+ static auto fconvert_layout = Op::GetAttrMap<FTVMConvertOpLayout>("FTVMConvertOpLayout");
Op op = Downcast<Op>(ref_call->op);
Expr new_e;
explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {}
Expr VisitExpr_(const CallNode* call) final {
- static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+ static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr new_expr = ExprMutator::VisitExpr_(call);
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
}
Expr VisitExpr_(const CallNode* call) final {
- static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+ static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
ExprVisitor::VisitExpr_(call);
// function to be lazily invoked
auto flazy = [this, call]() {
- static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+ static const auto& fprep = Op::GetAttrMap<FForwardPrep>("FScaleAxisForwardPrep");
// find the message send to this node.
auto it = message_.find(call);
Message out_message;
// Visit the expression.
void VisitExpr_(const CallNode* call) {
ExprVisitor::VisitExpr_(call);
- static const auto& fprep = Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
+ static const auto& fprep = Op::GetAttrMap<FBackwardPrep>("FScaleAxisBackwardPrep");
auto f = fprep.get(call->op, nullptr);
if (f == nullptr) return;
auto rit = ref_counter_.find(call);
};
Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) {
- static const auto& ftransform = Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
+ static const auto& ftransform = Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
auto f = ftransform.get(call_node->op, nullptr);
if (f != nullptr) {
const Call call = GetRef<Call>(call_node);
class ForwardRewriter : private MixedModeMutator {
public:
- ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
+ ForwardRewriter(const OpAttrMap<FForwardRewrite>* rewrite_map,
std::function<ObjectRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger)
: rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {}
private:
// The rewrite rule.
- const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
+ const OpAttrMap<FForwardRewrite>* rewrite_map_{nullptr};
const FForwardRewrite* rewrite_func_{nullptr};
// The context.const
std::function<ObjectRef(const Call&)> fcontext_{nullptr};
Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name,
std::function<ObjectRef(const Call&)> fcontext,
std::function<Expr(const Expr&)> fmulti_ref_trigger) {
- auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
+ auto rewrite_map = Op::GetAttrMap<FForwardRewrite>(rewrite_map_name);
return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}
void VisitExpr_(const CallNode* call) final {
CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
- static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+ static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
// Now we set the pattern of this call.
//
// If we see a call mentioning an operator we should mark it with its
// Transform calls.
Expr VisitExpr_(const CallNode* call) {
if (call->op.as<OpNode>()) {
- static auto fnoncomputational = Op::GetAttr<TNonComputational>("TNonComputational");
+ static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
return ExprMutator::VisitExpr_(call);
};
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
- const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+ const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
std::vector<std::function<void(LetList* ll)>> backprop_actions;
// we assume no closure so no need for lexical scoping
std::unordered_map<Var, ADValue, ObjectHash, ObjectEqual> env;
Var bp;
std::shared_ptr<ADVarMap> ad_vars;
- const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+ const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
bool MissingGrad(const Expr& e) {
struct MGVisitor : ExprVisitor {
- const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+ const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
std::unordered_set<std::string> op_names;
void VisitExpr_(const OpNode* op) final {
static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
- static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
+ static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
if (!call->op.as<OpNode>()) {
return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
}
// Get the new_call node without any changes to current call node.
Call new_call = Downcast<Call>(post);
- // Check if the string is registered in the OpRegistry.
- if (!Op::HasAttr(legalize_map_attr_name_)) {
+ // Check if the string is registered.
+ if (!Op::HasAttrMap(legalize_map_attr_name_)) {
return post;
}
// Collect the registered legalize function.
- auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
+ auto fop_legalize = Op::GetAttrMap<FTVMLegalize>(legalize_map_attr_name_);
auto call_op = call_node->op;
if (call_op.as<OpNode>()) {
Op op = Downcast<Op>(call_node->op);
enum struct MatchStatus { Match, NoMatch, Unknown };
bool StatefulOp(const Expr& e) {
- static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+ static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
struct StatefulOpVisitor : ExprVisitor {
bool stateful = false;
void VisitExpr_(const OpNode* op) {
TVM_REGISTER_GLOBAL("relay.backend.lower_call")
.set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
const Target& target) {
- static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
+ static auto fstrategy = Op::GetAttrMap<relay::FTVMStrategy>("FTVMStrategy");
Op op = Downcast<Op>(call->op);
auto out_type = call->checked_type();
OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);