From a8e44710c6472a2ee5cb66283c7f5e77f4e4ca0d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 17 May 2020 15:43:11 -0700 Subject: [PATCH] [REFACTOR][IR] Streamline ir/op Registry (#5609) * [REFACTOR][IR] Streamline ir/op Registry This PR refactors the attrregistry mechanism in the ir/op into a separate common base. The common base will provide a foundation for other attr related registries such as target and pass. We also streamlines the terminology of the registry API. - Use AttrMap for the column maps returned by the registry - Use RegEntry to refer to the registry entry. * Address review comments --- docs/dev/relay_add_pass.rst | 2 +- include/tvm/ir/op.h | 237 ++++++----------------- include/tvm/node/attr_registry_map.h | 132 +++++++++++++ src/ir/op.cc | 111 +++-------- src/node/attr_registry.h | 181 +++++++++++++++++ src/relay/analysis/mac_count.cc | 2 +- src/relay/analysis/util.cc | 2 +- src/relay/backend/compile_engine.cc | 6 +- src/relay/ir/dataflow_matcher.cc | 2 +- src/relay/transforms/alter_op_layout.cc | 2 +- src/relay/transforms/annotate_target.cc | 4 +- src/relay/transforms/canonicalize_cast.cc | 2 +- src/relay/transforms/combine_parallel_op.cc | 2 +- src/relay/transforms/convert_layout.cc | 2 +- src/relay/transforms/eliminate_common_subexpr.cc | 2 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/fold_scale_axis.cc | 6 +- src/relay/transforms/forward_rewrite.cc | 6 +- src/relay/transforms/fuse_ops.cc | 4 +- src/relay/transforms/gradient.cc | 6 +- src/relay/transforms/infer_layout_util.h | 2 +- src/relay/transforms/legalize.cc | 6 +- src/relay/transforms/partial_eval.cc | 2 +- tests/cpp/relay_build_module_test.cc | 2 +- 24 files changed, 432 insertions(+), 293 deletions(-) create mode 100644 include/tvm/node/attr_registry_map.h create mode 100644 src/node/attr_registry.h diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst index 8a6f8be..3eb9586 100644 --- a/docs/dev/relay_add_pass.rst +++ b/docs/dev/relay_add_pass.rst @@ -292,7 +292,7 @@ pointed to by ``op->index``. The reason we need to check is because .. code:: c Expr VisitExpr_(const CallNode* call) final { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); Expr res = ExprMutator::VisitExpr_(call); call = res.as(); // We don't constant fold function with zero arguments. diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index aeda4fa..f86aeba 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -39,10 +40,8 @@ namespace tvm { // forward declare name. -template -class OpMap; -class GenericOpMap; -class OpRegistry; +template +class OpAttrMap; // TODO(tvm-team): migrate low-level intrinsics to use Op /*! @@ -126,9 +125,18 @@ class OpNode : public RelayExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode); private: + /*! \return the internal attr registry index. */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! \brief repr to be printed in registry*/ + std::string AttrRegistryName() const { return name; } + // friend class - friend class GenericOpMap; - friend class OpRegistry; + template + friend class AttrRegistryMapContainerMap; + template + friend class AttrRegistry; + friend class OpRegEntry; + friend bool IsPrimitiveOp(const RelayExpr&); // Program internal unique index of operator. // Used to help index the program. @@ -166,19 +174,19 @@ class Op : public RelayExpr { inline const OpNode* operator->() const; /*! * \brief Get additional registered attribute about operators. - * If nothing has been registered, an empty OpMap will be returned. + * If nothing has been registered, an empty OpAttrMap will be returned. * \param attr_name The name of the attribute. - * \return An OpMap of specified attr_name. + * \return An OpAttrMap of specified attr_name. * \tparam ValueType The type of the attribute. */ template - inline static OpMap GetAttr(const std::string& attr_name); + inline static OpAttrMap GetAttrMap(const std::string& attr_name); /*! - * \brief Checks if an attr is present in the registry. + * \brief Checks if an attr map is present in the registry. * \param attr_name The name of the attribute. * \return bool True if the attr is present. */ - inline static bool HasAttr(const std::string& attr_name); + TVM_DLL static bool HasAttrMap(const String& attr_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. @@ -194,22 +202,16 @@ class Op : public RelayExpr { /*! * \brief Get generic attrmap given attr name * \param key The attribute key - * \return reference to GenericOpMap + * \return The attr map. */ - TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key); - /*! - * \brief Checks if the key is present in the registry - * \param key The attribute key - * \return bool True if the key is present - */ - TVM_DLL static bool HasGenericAttr(const String& key); + TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer(const String& key); }; /*! * \brief Helper structure to register operators * \sa TVM_REGISTER_OP */ -class OpRegistry { +class OpRegEntry { public: /*! \return the operator */ const Op& op() const { return op_; } @@ -219,7 +221,7 @@ class OpRegistry { * \param descr the description string. * \return reference to self. */ - inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + inline OpRegEntry& describe(const std::string& descr); // NOLINT(*) /*! * \brief Add argument information to the function. * \param name Name of the argument. @@ -227,7 +229,7 @@ class OpRegistry { * \param description Description of the argument. * \return reference to self. */ - inline OpRegistry& add_argument(const std::string& name, const std::string& type, + inline OpRegEntry& add_argument(const std::string& name, const std::string& type, const std::string& description); /*! * \brief Attach the type function corresponding to the return type. @@ -236,7 +238,7 @@ class OpRegistry { * relation on variables. * \return reference to self. */ - inline OpRegistry& add_type_rel( + inline OpRegEntry& add_type_rel( const std::string& rel_name, runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> type_rel_func); @@ -246,19 +248,19 @@ class OpRegistry { * \return reference to self. */ template - inline OpRegistry& set_attrs_type(); + inline OpRegEntry& set_attrs_type(); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. * \return reference to self. */ - inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*) /*! * \brief Set the support level of op. * \param level The support level. * \return reference to self. */ - inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. @@ -273,7 +275,7 @@ class OpRegistry { * \tparam ValueType The type of the value to be set. */ template - inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*) const ValueType& value, int plevel = 10); /*! @@ -283,104 +285,42 @@ class OpRegistry { inline void reset_attr(const std::string& attr_name); // set the name of the op to be the same as registry - inline OpRegistry& set_name() { // NOLINT(*) + inline OpRegEntry& set_name() { // NOLINT(*) if (get()->name.length() == 0) { get()->name = name; } return *this; } - /*! \return The global single registry */ - TVM_DLL static ::dmlc::Registry* Registry(); + /*! + * \brief Register or get a new entry. + * \param name The name of the operator. + * \return the corresponding entry. + */ + TVM_DLL static OpRegEntry& RegisterOrGet(const String& name); private: - friend class ::dmlc::Registry; + template + friend class AttrRegistry; // the name std::string name; /*! \brief The operator */ Op op_; // private constructor - TVM_DLL OpRegistry(); + TVM_DLL OpRegEntry(uint32_t reg_index); // return internal pointer to op. inline OpNode* get(); - // update the attribute OpMap - + // update the attribute OpAttrMap TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel); }; /*! - * \brief Generic map to store additional information of Op. - */ -class GenericOpMap { - public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline const runtime::TVMRetValue& operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const Op& op, ValueType def_value) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param expr The key to the map - * \param def_value The default value when the key does not exist - * or if expr is not an Op. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const RelayExpr& expr, ValueType def_value) const; - - private: - friend class OpRegistry; - // the attribute field. - std::string attr_name_; - // internal data - std::vector > data_; - // The value - GenericOpMap() = default; -}; - -/*! * \brief Map used to store meta-information about Op. * \tparam ValueType The type of the value stored in map. */ template -class OpMap { +class OpAttrMap : public AttrRegistryMap { public: /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline ValueType operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - */ - inline ValueType get(const Op& op, ValueType def_value) const; - /*! * \brief get the corresponding value element at op with default value. * \param expr The key to the map * \param def_value The default value when the key does not exist @@ -389,12 +329,15 @@ class OpMap { */ inline ValueType get(const RelayExpr& expr, ValueType def_value) const; + using TParent = AttrRegistryMap; + using TParent::count; + using TParent::get; + using TParent::operator[]; + private: friend class Op; // constructor - explicit OpMap(const GenericOpMap& map) : map_(map) {} - /*! \brief The internal map field */ - const GenericOpMap& map_; + explicit OpAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} }; #define TVM_STRINGIZE_DETAIL(x) #x @@ -406,7 +349,7 @@ class OpMap { #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) // internal macros to make -#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op +#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op /*! * \def TVM_REGISTER_OP @@ -425,26 +368,24 @@ class OpMap { */ #define TVM_REGISTER_OP(OpName) \ TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name() + ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name() // implementations inline const OpNode* Op::operator->() const { return static_cast(get()); } template -inline OpMap Op::GetAttr(const std::string& key) { - return OpMap(Op::GetGenericAttr(key)); +inline OpAttrMap Op::GetAttrMap(const std::string& key) { + return OpAttrMap(Op::GetAttrMapContainer(key)); } -inline bool Op::HasAttr(const std::string& key) { return Op::HasGenericAttr(key); } - -inline OpNode* OpRegistry::get() { return const_cast(op_.operator->()); } +inline OpNode* OpRegEntry::get() { return const_cast(op_.operator->()); } -inline OpRegistry& OpRegistry::describe(const std::string& descr) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*) get()->description = descr; return *this; } -inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type, +inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, const std::string& description) { auto n = make_object(); n->name = name; @@ -454,7 +395,7 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std:: return *this; } -inline OpRegistry& OpRegistry::add_type_rel( +inline OpRegEntry& OpRegEntry::add_type_rel( const std::string& rel_name, runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> type_rel_func) { @@ -508,25 +449,25 @@ inline OpRegistry& OpRegistry::add_type_rel( return *this; } -inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; return *this; } template -inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*) get()->attrs_type_key = AttrsType::_type_key; get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); return *this; } -inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*) get()->support_level = n; return *this; } template -inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) const std::string& attr_name, const ValueType& value, int plevel) { CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; runtime::TVMRetValue rv; @@ -535,70 +476,18 @@ inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) return *this; } -// member functions of OpMap -inline int GenericOpMap::count(const Op& op) const { - if (op.defined()) { - const uint32_t idx = op->index_; - return idx < data_.size() ? (data_[idx].second != 0) : 0; - } else { - return 0; - } -} - -inline const runtime::TVMRetValue& GenericOpMap::operator[](const Op& op) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; - return data_[idx].first; -} - -template -inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } -} +// member functions of OpAttrMap template -inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const { +inline ValueType OpAttrMap::get(const RelayExpr& expr, ValueType def_value) const { CHECK(expr.defined()); if (const OpNode* op = expr.as()) { - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } + return this->map_.get(GetRef(op), def_value); } else { - return value; + return def_value; } } -template -inline int OpMap::count(const Op& op) const { - return map_.count(op); -} - -template -inline ValueType OpMap::operator[](const Op& op) const { - return map_[op]; -} - -template -inline ValueType OpMap::get(const Op& op, ValueType def_value) const { - return map_.get(op, def_value); -} - -template -inline ValueType OpMap::get(const RelayExpr& expr, ValueType def_value) const { - return map_.get(expr, def_value); -} - /*! * \brief Check that an expression is a "primitive operator". * diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h new file mode 100644 index 0000000..748b3a8 --- /dev/null +++ b/include/tvm/node/attr_registry_map.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/attr_registry_map.h + * \brief Attribute map used in registry. + */ +#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ +#define TVM_NODE_ATTR_REGISTRY_MAP_H_ + +#include +#include + +namespace tvm { + +/*! + * \brief Generic attribute map. + * \tparam KeyType the type of the key. + */ +template +class AttrRegistryMapContainerMap { + public: + /*! + * \brief Check if the map has key. + * \param key The key to the map + * \return 1 if key is contained in map, 0 otherwise. + */ + int count(const KeyType& key) const { + if (key.defined()) { + const uint32_t idx = key->AttrRegistryIndex(); + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } + } + /*! + * \brief get the corresponding value element at key. + * \param key The key to the map + * \return the const reference to the content value. + */ + const runtime::TVMRetValue& operator[](const KeyType& key) const { + CHECK(key.defined()); + const uint32_t idx = key->AttrRegistryIndex(); + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ << " has not been registered for " << key->name; + return data_[idx].first; + } + /*! + * \brief get the corresponding value element at key with default value. + * \param key The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + ValueType get(const KeyType& key, ValueType def_value) const { + CHECK(key.defined()); + const uint32_t idx = key->AttrRegistryIndex(); + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return def_value; + } + } + + private: + /*! \brief The name of the attr field */ + String attr_name_; + /*! \brief The internal data. */ + std::vector> data_; + /*! \brief The constructor */ + AttrRegistryMapContainerMap() = default; + template + friend class AttrRegistry; + friend class OpRegEntry; +}; + +/*! + * \brief Map used to store meta-data. + * \tparam KeyType The type of the key + * \tparam ValueType The type of the value stored in map. + */ +template +class AttrRegistryMap { + public: + /*! + * \brief constructor + * \param map The internal map. + */ + explicit AttrRegistryMap(const AttrRegistryMapContainerMap& map) : map_(map) {} + /*! + * \brief Check if the map has op as key. + * \param key The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + int count(const KeyType& key) const { return map_.count(key); } + /*! + * \brief get the corresponding value element at key. + * \param key The key to the map + * \return the const reference to the content value. + */ + ValueType operator[](const KeyType& key) const { return map_[key]; } + /*! + * \brief get the corresponding value element at key with default value. + * \param key The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); } + + protected: + /*! \brief The internal map field */ + const AttrRegistryMapContainerMap& map_; +}; + +} // namespace tvm +#endif // TVM_NODE_ATTR_REGISTRY_MAP_H_ diff --git a/src/ir/op.cc b/src/ir/op.cc index 3a6bcbc..b81e358 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -28,12 +28,8 @@ #include #include -#include -namespace dmlc { -// enable registry -DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); -} // namespace dmlc +#include "../node/attr_registry.h" namespace tvm { @@ -41,104 +37,45 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -::dmlc::Registry* OpRegistry::Registry() { return ::dmlc::Registry::Get(); } - -// single manager of operator information. -struct OpManager { - // mutex to avoid registration from multiple threads. - std::mutex mutex; - // global operator counter - std::atomic op_counter{0}; - // storage of additional attribute table. - std::unordered_map> attr; - // frontend functions - std::vector frontend_funcs; - // get singleton of the op manager - static OpManager* Global() { - static OpManager* inst = new OpManager(); - return inst; - } -}; +using OpRegistry = AttrRegistry; // find operator by name const Op& Op::Get(const String& name) { - const OpRegistry* reg = dmlc::Registry::Find(name); + const OpRegEntry* reg = OpRegistry::Global()->Get(name); CHECK(reg != nullptr) << "Operator " << name << " is not registered"; return reg->op(); } -OpRegistry::OpRegistry() { - OpManager* mgr = OpManager::Global(); +OpRegEntry::OpRegEntry(uint32_t reg_index) { ObjectPtr n = make_object(); - n->index_ = mgr->op_counter++; + n->index_ = reg_index; op_ = Op(n); } +OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) { + return OpRegistry::Global()->RegisterOrGet(name); +} + // Get attribute map by key -const GenericOpMap& Op::GetGenericAttr(const String& key) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - auto it = mgr->attr.find(key); - if (it == mgr->attr.end()) { - LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered"; - } - return *it->second.get(); +const AttrRegistryMapContainerMap& Op::GetAttrMapContainer(const String& attr_name) { + return OpRegistry::Global()->GetAttrMap(attr_name); } // Check if a key is present in the registry. -bool Op::HasGenericAttr(const String& key) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - auto it = mgr->attr.find(key); - if (it == mgr->attr.end()) { - return false; - } - return true; -} +bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); } -// Resets attr of the OpMap. -void OpRegistry::reset_attr(const std::string& key) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - std::unique_ptr& op_map = mgr->attr[key]; - if (op_map == nullptr) { - return; - } - uint32_t index = op_->index_; - if (op_map->data_.size() > index) { - op_map->data_[index] = std::make_pair(TVMRetValue(), 0); - } +// Resets attr of the OpAttrMap. +void OpRegEntry::reset_attr(const std::string& attr_name) { + OpRegistry::Global()->ResetAttr(attr_name, op_); } -void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - std::unique_ptr& op_map = mgr->attr[key]; - if (op_map == nullptr) { - op_map.reset(new GenericOpMap()); - op_map->attr_name_ = key; - } - uint32_t index = op_->index_; - if (op_map->data_.size() <= index) { - op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); - } - std::pair& p = op_map->data_[index]; - CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name - << " is already registered with same plevel=" << plevel; - CHECK(value.type_code() != kTVMNullptr) - << "Registered packed_func is Null for " << key << " of operator " << this->name; - if (p.second < plevel && value.type_code() != kTVMNullptr) { - op_map->data_[index] = std::make_pair(value, plevel); - } +void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { + OpRegistry::Global()->UpdateAttr(key, op_, value, plevel); } // Frontend APIs TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() { - Array ret; - for (const std::string& name : dmlc::Registry::ListAllNames()) { - ret.push_back(name); - } - return ret; + return OpRegistry::Global()->ListAllNames(); }); TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op { @@ -148,7 +85,7 @@ TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op { TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { Op op = args[0]; std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); + auto op_map = Op::GetAttrMap(attr_name); if (op_map.count(op)) { *rv = op_map[op]; } @@ -159,14 +96,14 @@ TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue std::string attr_name = args[1]; runtime::TVMArgValue value = args[2]; int plevel = args[3]; - auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attr(attr_name, value, plevel); }); TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { Op op = args[0]; std::string attr_name = args[1]; - auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name); + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); reg.reset_attr(attr_name); }); @@ -175,7 +112,7 @@ TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue* std::string attr_key = args[1]; runtime::TVMArgValue value = args[2]; int plevel = args[3]; - auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); + auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); // enable resgiteration and override of certain properties if (attr_key == "num_inputs" && plevel > 128) { reg.set_num_inputs(value); @@ -187,8 +124,8 @@ TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue* // do an eager copy of the PackedFunc PackedFunc f = args[2]; // If we get a function from frontend, avoid deleting it. - OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); - reg.set_attr(attr_key, f, plevel); + auto* fcopy = new PackedFunc(f); + reg.set_attr(attr_key, *fcopy, plevel); } else { reg.set_attr(attr_key, args[2], plevel); } diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h new file mode 100644 index 0000000..9cc5b4d --- /dev/null +++ b/src/node/attr_registry.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/node/attr_registry.h + * \brief Common global registry for objects that also have additional attrs. + */ +#ifndef TVM_NODE_ATTR_REGISTRY_H_ +#define TVM_NODE_ATTR_REGISTRY_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \breif Implementation of registry with attributes. + * + * \tparam EntryType Tye type of the registry entry. + * \tparam KeyType The actual key that is used to lookup the attributes. + * each entry has a corresponding key by default. + */ +template +class AttrRegistry { + public: + using TSelf = AttrRegistry; + /*! + * \brief Get an entry from the registry. + * \param name The name of the item. + * \return The corresponding entry. + */ + const EntryType* Get(const String& name) const { + auto it = entry_map_.find(name); + if (it != entry_map_.end()) return it->second; + return nullptr; + } + + /*! + * \brief Get an entry or register a new one. + * \param name The name of the item. + * \return The corresponding entry. + */ + EntryType& RegisterOrGet(const String& name) { + auto it = entry_map_.find(name); + if (it != entry_map_.end()) return *it->second; + uint32_t registry_index = static_cast(entries_.size()); + auto entry = std::unique_ptr(new EntryType(registry_index)); + auto* eptr = entry.get(); + eptr->name = name; + entry_map_[name] = eptr; + entries_.emplace_back(std::move(entry)); + return *eptr; + } + + /*! + * \brief List all the entry names in the registry. + * \return The entry names. + */ + Array ListAllNames() const { + Array names; + for (const auto& kv : entry_map_) { + names.push_back(kv.first); + } + return names; + } + + /*! + * \brief Update the attribute stable. + * \param attr_name The name of the attribute. + * \param key The key to the attribute table. + * \param value The value to be set. + * \param plevel The support level. + */ + void UpdateAttr(const String& attr_name, const KeyType& key, runtime::TVMRetValue value, + int plevel) { + using runtime::TVMRetValue; + std::lock_guard lock(mutex_); + auto& op_map = attrs_[attr_name]; + if (op_map == nullptr) { + op_map.reset(new AttrRegistryMapContainerMap()); + op_map->attr_name_ = attr_name; + } + + uint32_t index = key->AttrRegistryIndex(); + if (op_map->data_.size() <= index) { + op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); + } + std::pair& p = op_map->data_[index]; + CHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName() + << " is already registered with same plevel=" << plevel; + CHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name + << " of operator " << key->AttrRegistryName(); + if (p.second < plevel && value.type_code() != kTVMNullptr) { + op_map->data_[index] = std::make_pair(value, plevel); + } + } + + /*! + * \brief Reset an attribute table entry. + * \param attr_name The name of the attribute. + * \param key The key to the attribute table. + */ + void ResetAttr(const String& attr_name, const KeyType& key) { + std::lock_guard lock(mutex_); + auto& op_map = attrs_[attr_name]; + if (op_map == nullptr) { + return; + } + uint32_t index = key->AttrRegistryIndex(); + if (op_map->data_.size() > index) { + op_map->data_[index] = std::make_pair(TVMRetValue(), 0); + } + } + + /*! + * \brief Get an internal attribute map. + * \param attr_name The name of the attribute. + * \return The result attribute map. + */ + const AttrRegistryMapContainerMap& GetAttrMap(const String& attr_name) { + std::lock_guard lock(mutex_); + auto it = attrs_.find(attr_name); + if (it == attrs_.end()) { + LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered"; + } + return *it->second.get(); + } + + /*! + * \brief Check of attribute has been registered. + * \param attr_name The name of the attribute. + * \return The check result. + */ + bool HasAttrMap(const String& attr_name) { + std::lock_guard lock(mutex_); + return attrs_.count(attr_name); + } + + /*! + * \return a global singleton of the registry. + */ + static TSelf* Global() { + static TSelf* inst = new TSelf(); + return inst; + } + + private: + // mutex to avoid registration from multiple threads. + std::mutex mutex_; + // entries in the registry + std::vector> entries_; + // map from name to entries. + std::unordered_map entry_map_; + // storage of additional attribute table. + std::unordered_map>> attrs_; +}; + +} // namespace tvm +#endif // TVM_NODE_ATTR_REGISTRY_H_ diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc index 882bba9..d2e62b7 100644 --- a/src/relay/analysis/mac_count.cc +++ b/src/relay/analysis/mac_count.cc @@ -178,7 +178,7 @@ class MacCounter : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { - static const auto& fprep = Op::GetAttr("FMacCount"); + static const auto& fprep = Op::GetAttrMap("FMacCount"); auto f = fprep.get(call_node->op, nullptr); if (f != nullptr) count_ += f(GetRef(call_node)); ExprVisitor::VisitExpr_(call_node); diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index af23836..6d246c0 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -436,7 +436,7 @@ bool IsDynamic(const Type& ty) { TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic); bool IsDataDependant(const CallNode* call) { - static auto tshape_data_dependant = Op::GetAttr("TShapeDataDependant"); + static auto tshape_data_dependant = Op::GetAttrMap("TShapeDataDependant"); Op op = Downcast(call->op); if (!tshape_data_dependant.count(op)) { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 12a5add..f143479 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -191,7 +191,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttrMap("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); CHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -454,8 +454,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } Array VisitExpr_(const CallNode* call_node) final { - static auto fshape_func = Op::GetAttr("FShapeFunc"); - static auto tshape_data_dependant = Op::GetAttr("TShapeDataDependant"); + static auto fshape_func = Op::GetAttrMap("FShapeFunc"); + static auto tshape_data_dependant = Op::GetAttrMap("TShapeDataDependant"); CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); CHECK(data_dependants_.empty() || !data_dependants_.back()) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 81fc4f0..7c70f32 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -109,7 +109,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons for (auto kv : attributes) { auto attr_name = kv.first; auto attr_value = kv.second; - auto op_map = Op::GetAttr(attr_name); + auto op_map = Op::GetAttrMap(attr_name); if (op_map.count(op)) { switch (op_map[op].type_code()) { case kDLInt: diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index 7b91e8c..3d242cd 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -72,7 +72,7 @@ class AlterTransformMemorizer : public TransformMemorizer { * \return The new Call after calling the packed func. */ Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { - static auto falter_layout = Op::GetAttr("FTVMAlterOpLayout"); + static auto falter_layout = Op::GetAttrMap("FTVMAlterOpLayout"); Op op = Downcast(ref_call->op); Expr new_e; diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 3635947..0d97005 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -145,10 +145,10 @@ class AnnotateTargetRewriter : public ExprRewriter { Op op = Downcast(pre->op); CHECK(op.defined()); for (const auto& target : this->targets_) { - if (!Op::HasAttr("target." + std::string(target))) { + if (!Op::HasAttrMap("target." + std::string(target))) { continue; } - auto fannotate = Op::GetAttr("target." + std::string(target)); + auto fannotate = Op::GetAttrMap("target." + std::string(target)); if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) { supported_targets.push_back(target); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index f478107..055ab14 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -66,7 +66,7 @@ class CastCanonicalizer : public ExprMutator { CastCanonicalizer() : cast_op_(Op::Get("cast")) {} Expr VisitExpr_(const CallNode* call) { - static auto fpattern = Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttrMap("TOpPattern"); if (const OpNode* opnode = call->op.as()) { auto pattern = fpattern[GetRef(opnode)]; diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc index 854a1ae..7ca2ce8 100644 --- a/src/relay/transforms/combine_parallel_op.cc +++ b/src/relay/transforms/combine_parallel_op.cc @@ -81,7 +81,7 @@ std::vector BranchGroupFinder::Find(const Expr& expr) { // Create a branch starting from op. Branch BranchGroupFinder::CreateBranch(const CallNode* op) { - auto fpattern = Op::GetAttr("TOpPattern"); + auto fpattern = Op::GetAttrMap("TOpPattern"); // each branch has at least one element, the first element is always op Branch branch{op}; auto it = children_map_.find(GetRef(branch.back())); diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index 7d42125..4a18925 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -82,7 +82,7 @@ class ConvertTransformMemorizer : public TransformMemorizer { * \return The new Call after calling the packed func. */ Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { - static auto fconvert_layout = Op::GetAttr("FTVMConvertOpLayout"); + static auto fconvert_layout = Op::GetAttrMap("FTVMConvertOpLayout"); Op op = Downcast(ref_call->op); Expr new_e; diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 2861f32..27e5344 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -42,7 +42,7 @@ class CommonSubexprEliminator : public ExprMutator { explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} Expr VisitExpr_(const CallNode* call) final { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); Expr new_expr = ExprMutator::VisitExpr_(call); const CallNode* new_call = new_expr.as(); CHECK(new_call); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 70df0ed..e2ab35b 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -104,7 +104,7 @@ class ConstantFolder : public ExprMutator { } Expr VisitExpr_(const CallNode* call) final { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); std::unordered_set skip_list{"zeros_like", "ones_like", "full_like", "full"}; diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 4083d08..a3765f3 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -255,7 +255,7 @@ class ForwardPrep : private ExprVisitor { ExprVisitor::VisitExpr_(call); // function to be lazily invoked auto flazy = [this, call]() { - static const auto& fprep = Op::GetAttr("FScaleAxisForwardPrep"); + static const auto& fprep = Op::GetAttrMap("FScaleAxisForwardPrep"); // find the message send to this node. auto it = message_.find(call); Message out_message; @@ -625,7 +625,7 @@ class BackwardPrep : private ExprVisitor { // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); - static const auto& fprep = Op::GetAttr("FScaleAxisBackwardPrep"); + static const auto& fprep = Op::GetAttrMap("FScaleAxisBackwardPrep"); auto f = fprep.get(call->op, nullptr); if (f == nullptr) return; auto rit = ref_counter_.find(call); @@ -727,7 +727,7 @@ class BackwardTransformer : public ObjectRef { }; Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = Op::GetAttr("FScaleAxisBackwardTransform"); + static const auto& ftransform = Op::GetAttrMap("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); if (f != nullptr) { const Call call = GetRef(call_node); diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index 226b338..d872116 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -53,7 +53,7 @@ class TempRealizer : private MixedModeMutator { class ForwardRewriter : private MixedModeMutator { public: - ForwardRewriter(const OpMap* rewrite_map, + ForwardRewriter(const OpAttrMap* rewrite_map, std::function fcontext, std::function fmulti_ref_trigger) : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} @@ -73,7 +73,7 @@ class ForwardRewriter : private MixedModeMutator { private: // The rewrite rule. - const OpMap* rewrite_map_{nullptr}; + const OpAttrMap* rewrite_map_{nullptr}; const FForwardRewrite* rewrite_func_{nullptr}; // The context.const std::function fcontext_{nullptr}; @@ -175,7 +175,7 @@ class ForwardRewriter : private MixedModeMutator { Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name, std::function fcontext, std::function fmulti_ref_trigger) { - auto rewrite_map = Op::GetAttr(rewrite_map_name); + auto rewrite_map = Op::GetAttrMap(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 054244d..01f1eee 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -226,7 +226,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); - static auto fpattern = Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttrMap("TOpPattern"); // Now we set the pattern of this call. // // If we see a call mentioning an operator we should mark it with its @@ -824,7 +824,7 @@ class FuseMutator : private ExprMutator { // Transform calls. Expr VisitExpr_(const CallNode* call) { if (call->op.as()) { - static auto fnoncomputational = Op::GetAttr("TNonComputational"); + static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); if (fnoncomputational.get(Downcast(call->op), false)) { return ExprMutator::VisitExpr_(call); diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index 67c62f3..afe5568 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -130,7 +130,7 @@ struct ADFunction : ADValueNode { }; struct FirstOrderReverseAD : ExprFunctor { - const OpMap rev_map = Op::GetAttr("FPrimalGradient"); + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); std::vector> backprop_actions; // we assume no closure so no need for lexical scoping std::unordered_map env; @@ -354,7 +354,7 @@ struct ReverseAD : ExprMutator { Var bp; std::shared_ptr ad_vars; - const OpMap rev_map = Op::GetAttr("FPrimalGradient"); + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) : bp(bp), ad_vars(ad_vars) {} @@ -456,7 +456,7 @@ struct ReverseAD : ExprMutator { bool MissingGrad(const Expr& e) { struct MGVisitor : ExprVisitor { - const OpMap rev_map = Op::GetAttr("FPrimalGradient"); + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); std::unordered_set op_names; void VisitExpr_(const OpNode* op) final { diff --git a/src/relay/transforms/infer_layout_util.h b/src/relay/transforms/infer_layout_util.h index e4df647..7ced51d 100644 --- a/src/relay/transforms/infer_layout_util.h +++ b/src/relay/transforms/infer_layout_util.h @@ -208,7 +208,7 @@ inline Array> BinaryBroadcastLayout(const Attrs& attrs, static inline std::tuple, Array, bool> InferCorrectLayouts( const Call& call, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); + static auto finfer_layout = Op::GetAttrMap("FInferCorrectLayout"); if (!call->op.as()) { return std::make_tuple<>(Array(nullptr), Array(nullptr), false); } diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index 25919b4..c1f037f 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -44,13 +44,13 @@ class Legalizer : public ExprRewriter { // Get the new_call node without any changes to current call node. Call new_call = Downcast(post); - // Check if the string is registered in the OpRegistry. - if (!Op::HasAttr(legalize_map_attr_name_)) { + // Check if the string is registered. + if (!Op::HasAttrMap(legalize_map_attr_name_)) { return post; } // Collect the registered legalize function. - auto fop_legalize = Op::GetAttr(legalize_map_attr_name_); + auto fop_legalize = Op::GetAttrMap(legalize_map_attr_name_); auto call_op = call_node->op; if (call_op.as()) { Op op = Downcast(call_node->op); diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index a27cb79..3e27b87 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -511,7 +511,7 @@ PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object( enum struct MatchStatus { Match, NoMatch, Unknown }; bool StatefulOp(const Expr& e) { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); struct StatefulOpVisitor : ExprVisitor { bool stateful = false; void VisitExpr_(const OpNode* op) { diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index d7ce0c0..3a2b2d9 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -59,7 +59,7 @@ TVM_REGISTER_GLOBAL("test.strategy") TVM_REGISTER_GLOBAL("relay.backend.lower_call") .set_body_typed([](const relay::Call& call, const Array& inputs, const Target& target) { - static auto fstrategy = Op::GetAttr("FTVMStrategy"); + static auto fstrategy = Op::GetAttrMap("FTVMStrategy"); Op op = Downcast(call->op); auto out_type = call->checked_type(); OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); -- 2.7.4