[REFACTOR][IR] Streamline ir/op Registry (#5609)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sun, 17 May 2020 22:43:11 +0000 (15:43 -0700)
committerGitHub <noreply@github.com>
Sun, 17 May 2020 22:43:11 +0000 (15:43 -0700)
* [REFACTOR][IR] Streamline ir/op Registry

This PR refactors the attrregistry mechanism in the ir/op into
a separate common base. The common base will provide a foundation
for other attr related registries such as target and pass.

We also streamlines the terminology of the registry API.

- Use AttrMap for the column maps returned by the registry
- Use RegEntry to refer to the registry entry.

* Address review comments

24 files changed:
docs/dev/relay_add_pass.rst
include/tvm/ir/op.h
include/tvm/node/attr_registry_map.h [new file with mode: 0644]
src/ir/op.cc
src/node/attr_registry.h [new file with mode: 0644]
src/relay/analysis/mac_count.cc
src/relay/analysis/util.cc
src/relay/backend/compile_engine.cc
src/relay/ir/dataflow_matcher.cc
src/relay/transforms/alter_op_layout.cc
src/relay/transforms/annotate_target.cc
src/relay/transforms/canonicalize_cast.cc
src/relay/transforms/combine_parallel_op.cc
src/relay/transforms/convert_layout.cc
src/relay/transforms/eliminate_common_subexpr.cc
src/relay/transforms/fold_constant.cc
src/relay/transforms/fold_scale_axis.cc
src/relay/transforms/forward_rewrite.cc
src/relay/transforms/fuse_ops.cc
src/relay/transforms/gradient.cc
src/relay/transforms/infer_layout_util.h
src/relay/transforms/legalize.cc
src/relay/transforms/partial_eval.cc
tests/cpp/relay_build_module_test.cc

index 8a6f8be..3eb9586 100644 (file)
@@ -292,7 +292,7 @@ pointed to by ``op->index``. The reason we need to check is because
 .. code:: c
 
     Expr VisitExpr_(const CallNode* call) final {
-      static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+      static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
       Expr res = ExprMutator::VisitExpr_(call);
       call = res.as<CallNode>();
       // We don't constant fold function with zero arguments.
index aeda4fa..f86aeba 100644 (file)
@@ -30,6 +30,7 @@
 #include <tvm/ir/expr.h>
 #include <tvm/ir/type.h>
 #include <tvm/ir/type_relation.h>
+#include <tvm/node/attr_registry_map.h>
 #include <tvm/runtime/registry.h>
 
 #include <string>
 namespace tvm {
 
 // forward declare name.
-template <typename ValueType>
-class OpMap;
-class GenericOpMap;
-class OpRegistry;
+template <typename>
+class OpAttrMap;
 
 // TODO(tvm-team): migrate low-level intrinsics to use Op
 /*!
@@ -126,9 +125,18 @@ class OpNode : public RelayExprNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
 
  private:
+  /*! \return the internal attr registry index. */
+  uint32_t AttrRegistryIndex() const { return index_; }
+  /*! \brief repr to be printed in registry*/
+  std::string AttrRegistryName() const { return name; }
+
   // friend class
-  friend class GenericOpMap;
-  friend class OpRegistry;
+  template <typename>
+  friend class AttrRegistryMapContainerMap;
+  template <typename, typename>
+  friend class AttrRegistry;
+  friend class OpRegEntry;
+
   friend bool IsPrimitiveOp(const RelayExpr&);
   // Program internal unique index of operator.
   // Used to help index the program.
@@ -166,19 +174,19 @@ class Op : public RelayExpr {
   inline const OpNode* operator->() const;
   /*!
    * \brief Get additional registered attribute about operators.
-   *  If nothing has been registered, an empty OpMap will be returned.
+   *  If nothing has been registered, an empty OpAttrMap will be returned.
    * \param attr_name The name of the attribute.
-   * \return An OpMap of specified attr_name.
+   * \return An OpAttrMap of specified attr_name.
    * \tparam ValueType The type of the attribute.
    */
   template <typename ValueType>
-  inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
+  inline static OpAttrMap<ValueType> GetAttrMap(const std::string& attr_name);
   /*!
-   * \brief Checks if an attr is present in the registry.
+   * \brief Checks if an attr map is present in the registry.
    * \param attr_name The name of the attribute.
    * \return bool True if the attr is present.
    */
-  inline static bool HasAttr(const std::string& attr_name);
+  TVM_DLL static bool HasAttrMap(const String& attr_name);
   /*!
    * \brief Get an Op for a given operator name.
    *  Will raise an error if the op has not been registered.
@@ -194,22 +202,16 @@ class Op : public RelayExpr {
   /*!
    * \brief Get generic attrmap given attr name
    * \param key The attribute key
-   * \return reference to GenericOpMap
+   * \return The attr map.
    */
-  TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key);
-  /*!
-   * \brief Checks if the key is present in the registry
-   * \param key The attribute key
-   * \return bool True if the key is present
-   */
-  TVM_DLL static bool HasGenericAttr(const String& key);
+  TVM_DLL static const AttrRegistryMapContainerMap<Op>& GetAttrMapContainer(const String& key);
 };
 
 /*!
  * \brief Helper structure to register operators
  * \sa TVM_REGISTER_OP
  */
-class OpRegistry {
+class OpRegEntry {
  public:
   /*! \return the operator */
   const Op& op() const { return op_; }
@@ -219,7 +221,7 @@ class OpRegistry {
    * \param descr the description string.
    * \return reference to self.
    */
-  inline OpRegistry& describe(const std::string& descr);  // NOLINT(*)
+  inline OpRegEntry& describe(const std::string& descr);  // NOLINT(*)
   /*!
    * \brief Add argument information to the function.
    * \param name Name of the argument.
@@ -227,7 +229,7 @@ class OpRegistry {
    * \param description Description of the argument.
    * \return reference to self.
    */
-  inline OpRegistry& add_argument(const std::string& name, const std::string& type,
+  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
                                   const std::string& description);
   /*!
    * \brief Attach the type function corresponding to the return type.
@@ -236,7 +238,7 @@ class OpRegistry {
    * relation on variables.
    * \return reference to self.
    */
-  inline OpRegistry& add_type_rel(
+  inline OpRegEntry& add_type_rel(
       const std::string& rel_name,
       runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
           type_rel_func);
@@ -246,19 +248,19 @@ class OpRegistry {
    * \return reference to self.
    */
   template <typename AttrsType>
-  inline OpRegistry& set_attrs_type();
+  inline OpRegEntry& set_attrs_type();
   /*!
    * \brief Set the num_inputs
    * \param n The number of inputs to be set.
    * \return reference to self.
    */
-  inline OpRegistry& set_num_inputs(int32_t n);  // NOLINT(*)
+  inline OpRegEntry& set_num_inputs(int32_t n);  // NOLINT(*)
   /*!
    * \brief Set the support level of op.
    * \param level The support level.
    * \return reference to self.
    */
-  inline OpRegistry& set_support_level(int32_t level);  // NOLINT(*)
+  inline OpRegEntry& set_support_level(int32_t level);  // NOLINT(*)
   /*!
    * \brief Register additional attributes to operator.
    * \param attr_name The name of the attribute.
@@ -273,7 +275,7 @@ class OpRegistry {
    * \tparam ValueType The type of the value to be set.
    */
   template <typename ValueType>
-  inline OpRegistry& set_attr(const std::string& attr_name,  // NOLINT(*)
+  inline OpRegEntry& set_attr(const std::string& attr_name,  // NOLINT(*)
                               const ValueType& value, int plevel = 10);
 
   /*!
@@ -283,104 +285,42 @@ class OpRegistry {
   inline void reset_attr(const std::string& attr_name);
 
   // set the name of the op to be the same as registry
-  inline OpRegistry& set_name() {  // NOLINT(*)
+  inline OpRegEntry& set_name() {  // NOLINT(*)
     if (get()->name.length() == 0) {
       get()->name = name;
     }
     return *this;
   }
-  /*! \return The global single registry */
-  TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
+  /*!
+   * \brief Register or get a new entry.
+   * \param name The name of the operator.
+   * \return the corresponding entry.
+   */
+  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
 
  private:
-  friend class ::dmlc::Registry<OpRegistry>;
+  template <typename, typename>
+  friend class AttrRegistry;
   // the name
   std::string name;
   /*! \brief The operator */
   Op op_;
   // private constructor
-  TVM_DLL OpRegistry();
+  TVM_DLL OpRegEntry(uint32_t reg_index);
   // return internal pointer to op.
   inline OpNode* get();
-  // update the attribute OpMap
-
+  // update the attribute OpAttrMap
   TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
 };
 
 /*!
- * \brief Generic map to store additional information of Op.
- */
-class GenericOpMap {
- public:
-  /*!
-   * \brief Check if the map has op as key.
-   * \param op The key to the map
-   * \return 1 if op is contained in map, 0 otherwise.
-   */
-  inline int count(const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op
-   * \param op The key to the map
-   * \return the const reference to the content value.
-   */
-  inline const runtime::TVMRetValue& operator[](const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param op The key to the map
-   * \param def_value The default value when the key does not exist.
-   * \return the const reference to the content value.
-   * \tparam ValueType The content value type.
-   */
-  template <typename ValueType>
-  inline ValueType get(const Op& op, ValueType def_value) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param expr The key to the map
-   * \param def_value The default value when the key does not exist
-   *         or if expr is not an Op.
-   * \return the const reference to the content value.
-   * \tparam ValueType The content value type.
-   */
-  template <typename ValueType>
-  inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
-
- private:
-  friend class OpRegistry;
-  // the attribute field.
-  std::string attr_name_;
-  // internal data
-  std::vector<std::pair<runtime::TVMRetValue, int> > data_;
-  // The value
-  GenericOpMap() = default;
-};
-
-/*!
  * \brief Map<Op,ValueType> used to store meta-information about Op.
  * \tparam ValueType The type of the value stored in map.
  */
 template <typename ValueType>
-class OpMap {
+class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
  public:
   /*!
-   * \brief Check if the map has op as key.
-   * \param op The key to the map
-   * \return 1 if op is contained in map, 0 otherwise.
-   */
-  inline int count(const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op
-   * \param op The key to the map
-   * \return the const reference to the content value.
-   */
-  inline ValueType operator[](const Op& op) const;
-  /*!
-   * \brief get the corresponding value element at op with default value.
-   * \param op The key to the map
-   * \param def_value The default value when the key does not exist.
-   * \return the const reference to the content value.
-   */
-  inline ValueType get(const Op& op, ValueType def_value) const;
-  /*!
    * \brief get the corresponding value element at op with default value.
    * \param expr The key to the map
    * \param def_value The default value when the key does not exist
@@ -389,12 +329,15 @@ class OpMap {
    */
   inline ValueType get(const RelayExpr& expr, ValueType def_value) const;
 
+  using TParent = AttrRegistryMap<Op, ValueType>;
+  using TParent::count;
+  using TParent::get;
+  using TParent::operator[];
+
  private:
   friend class Op;
   // constructor
-  explicit OpMap(const GenericOpMap& map) : map_(map) {}
-  /*! \brief The internal map field */
-  const GenericOpMap& map_;
+  explicit OpAttrMap(const AttrRegistryMapContainerMap<Op>& map) : TParent(map) {}
 };
 
 #define TVM_STRINGIZE_DETAIL(x) #x
@@ -406,7 +349,7 @@ class OpMap {
 #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)
 
 // internal macros to make
-#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op
+#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
 
 /*!
  * \def TVM_REGISTER_OP
@@ -425,26 +368,24 @@ class OpMap {
  */
 #define TVM_REGISTER_OP(OpName)                          \
   TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
-      ::tvm::OpRegistry::Registry()->__REGISTER_OR_GET__(OpName).set_name()
+      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
 
 // implementations
 inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); }
 
 template <typename ValueType>
-inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
-  return OpMap<ValueType>(Op::GetGenericAttr(key));
+inline OpAttrMap<ValueType> Op::GetAttrMap(const std::string& key) {
+  return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
 }
 
-inline bool Op::HasAttr(const std::string& key) { return Op::HasGenericAttr(key); }
-
-inline OpNode* OpRegistry::get() { return const_cast<OpNode*>(op_.operator->()); }
+inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
 
-inline OpRegistry& OpRegistry::describe(const std::string& descr) {  // NOLINT(*)
+inline OpRegEntry& OpRegEntry::describe(const std::string& descr) {  // NOLINT(*)
   get()->description = descr;
   return *this;
 }
 
-inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::string& type,
+inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
                                             const std::string& description) {
   auto n = make_object<AttrFieldInfoNode>();
   n->name = name;
@@ -454,7 +395,7 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, const std::
   return *this;
 }
 
-inline OpRegistry& OpRegistry::add_type_rel(
+inline OpRegEntry& OpRegEntry::add_type_rel(
     const std::string& rel_name,
     runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
         type_rel_func) {
@@ -508,25 +449,25 @@ inline OpRegistry& OpRegistry::add_type_rel(
   return *this;
 }
 
-inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) {  // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) {  // NOLINT(*)
   get()->num_inputs = n;
   return *this;
 }
 
 template <typename AttrsType>
-inline OpRegistry& OpRegistry::set_attrs_type() {  // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_attrs_type() {  // NOLINT(*)
   get()->attrs_type_key = AttrsType::_type_key;
   get()->attrs_type_index = AttrsType::RuntimeTypeIndex();
   return *this;
 }
 
-inline OpRegistry& OpRegistry::set_support_level(int32_t n) {  // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) {  // NOLINT(*)
   get()->support_level = n;
   return *this;
 }
 
 template <typename ValueType>
-inline OpRegistry& OpRegistry::set_attr(  // NOLINT(*)
+inline OpRegEntry& OpRegEntry::set_attr(  // NOLINT(*)
     const std::string& attr_name, const ValueType& value, int plevel) {
   CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
   runtime::TVMRetValue rv;
@@ -535,70 +476,18 @@ inline OpRegistry& OpRegistry::set_attr(  // NOLINT(*)
   return *this;
 }
 
-// member functions of OpMap
-inline int GenericOpMap::count(const Op& op) const {
-  if (op.defined()) {
-    const uint32_t idx = op->index_;
-    return idx < data_.size() ? (data_[idx].second != 0) : 0;
-  } else {
-    return 0;
-  }
-}
-
-inline const runtime::TVMRetValue& GenericOpMap::operator[](const Op& op) const {
-  CHECK(op.defined());
-  const uint32_t idx = op->index_;
-  CHECK(idx < data_.size() && data_[idx].second != 0)
-      << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name;
-  return data_[idx].first;
-}
-
-template <typename ValueType>
-inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
-  CHECK(op.defined());
-  const uint32_t idx = op->index_;
-  if (idx < data_.size() && data_[idx].second != 0) {
-    return data_[idx].first;
-  } else {
-    return value;
-  }
-}
+// member functions of OpAttrMap
 
 template <typename ValueType>
-inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const {
+inline ValueType OpAttrMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
   CHECK(expr.defined());
   if (const OpNode* op = expr.as<OpNode>()) {
-    const uint32_t idx = op->index_;
-    if (idx < data_.size() && data_[idx].second != 0) {
-      return data_[idx].first;
-    } else {
-      return value;
-    }
+    return this->map_.get(GetRef<Op>(op), def_value);
   } else {
-    return value;
+    return def_value;
   }
 }
 
-template <typename ValueType>
-inline int OpMap<ValueType>::count(const Op& op) const {
-  return map_.count(op);
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
-  return map_[op];
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const Op& op, ValueType def_value) const {
-  return map_.get<ValueType>(op, def_value);
-}
-
-template <typename ValueType>
-inline ValueType OpMap<ValueType>::get(const RelayExpr& expr, ValueType def_value) const {
-  return map_.get<ValueType>(expr, def_value);
-}
-
 /*!
  * \brief Check that an expression is a "primitive operator".
  *
diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h
new file mode 100644 (file)
index 0000000..748b3a8
--- /dev/null
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/attr_registry_map.h
+ * \brief Attribute map used in registry.
+ */
+#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_
+#define TVM_NODE_ATTR_REGISTRY_MAP_H_
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+
+/*!
+ * \brief Generic attribute map.
+ * \tparam KeyType the type of the key.
+ */
+template <typename KeyType>
+class AttrRegistryMapContainerMap {
+ public:
+  /*!
+   * \brief Check if the map has key.
+   * \param key The key to the map
+   * \return 1 if key is contained in map, 0 otherwise.
+   */
+  int count(const KeyType& key) const {
+    if (key.defined()) {
+      const uint32_t idx = key->AttrRegistryIndex();
+      return idx < data_.size() ? (data_[idx].second != 0) : 0;
+    } else {
+      return 0;
+    }
+  }
+  /*!
+   * \brief get the corresponding value element at key.
+   * \param key The key to the map
+   * \return the const reference to the content value.
+   */
+  const runtime::TVMRetValue& operator[](const KeyType& key) const {
+    CHECK(key.defined());
+    const uint32_t idx = key->AttrRegistryIndex();
+    CHECK(idx < data_.size() && data_[idx].second != 0)
+        << "Attribute " << attr_name_ << " has not been registered for " << key->name;
+    return data_[idx].first;
+  }
+  /*!
+   * \brief get the corresponding value element at key with default value.
+   * \param key The key to the map
+   * \param def_value The default value when the key does not exist.
+   * \return the const reference to the content value.
+   * \tparam ValueType The content value type.
+   */
+  template <typename ValueType>
+  ValueType get(const KeyType& key, ValueType def_value) const {
+    CHECK(key.defined());
+    const uint32_t idx = key->AttrRegistryIndex();
+    if (idx < data_.size() && data_[idx].second != 0) {
+      return data_[idx].first;
+    } else {
+      return def_value;
+    }
+  }
+
+ private:
+  /*! \brief The name of the attr field */
+  String attr_name_;
+  /*! \brief The internal data. */
+  std::vector<std::pair<runtime::TVMRetValue, int>> data_;
+  /*! \brief The constructor */
+  AttrRegistryMapContainerMap() = default;
+  template <typename, typename>
+  friend class AttrRegistry;
+  friend class OpRegEntry;
+};
+
+/*!
+ * \brief Map<Key, ValueType> used to store meta-data.
+ * \tparam KeyType The type of the key
+ * \tparam ValueType The type of the value stored in map.
+ */
+template <typename KeyType, typename ValueType>
+class AttrRegistryMap {
+ public:
+  /*!
+   * \brief constructor
+   * \param map The internal map.
+   */
+  explicit AttrRegistryMap(const AttrRegistryMapContainerMap<KeyType>& map) : map_(map) {}
+  /*!
+   * \brief Check if the map has op as key.
+   * \param key The key to the map
+   * \return 1 if op is contained in map, 0 otherwise.
+   */
+  int count(const KeyType& key) const { return map_.count(key); }
+  /*!
+   * \brief get the corresponding value element at key.
+   * \param key The key to the map
+   * \return the const reference to the content value.
+   */
+  ValueType operator[](const KeyType& key) const { return map_[key]; }
+  /*!
+   * \brief get the corresponding value element at key with default value.
+   * \param key The key to the map
+   * \param def_value The default value when the key does not exist.
+   * \return the const reference to the content value.
+   */
+  ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); }
+
+ protected:
+  /*! \brief The internal map field */
+  const AttrRegistryMapContainerMap<KeyType>& map_;
+};
+
+}  // namespace tvm
+#endif  // TVM_NODE_ATTR_REGISTRY_MAP_H_
index 3a6bcbc..b81e358 100644 (file)
 #include <tvm/runtime/packed_func.h>
 
 #include <memory>
-#include <mutex>
 
-namespace dmlc {
-// enable registry
-DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
-}  // namespace dmlc
+#include "../node/attr_registry.h"
 
 namespace tvm {
 
@@ -41,104 +37,45 @@ using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 
-::dmlc::Registry<OpRegistry>* OpRegistry::Registry() { return ::dmlc::Registry<OpRegistry>::Get(); }
-
-// single manager of operator information.
-struct OpManager {
-  // mutex to avoid registration from multiple threads.
-  std::mutex mutex;
-  // global operator counter
-  std::atomic<int> op_counter{0};
-  // storage of additional attribute table.
-  std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
-  // frontend functions
-  std::vector<PackedFunc*> frontend_funcs;
-  // get singleton of the op manager
-  static OpManager* Global() {
-    static OpManager* inst = new OpManager();
-    return inst;
-  }
-};
+using OpRegistry = AttrRegistry<OpRegEntry, Op>;
 
 // find operator by name
 const Op& Op::Get(const String& name) {
-  const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
+  const OpRegEntry* reg = OpRegistry::Global()->Get(name);
   CHECK(reg != nullptr) << "Operator " << name << " is not registered";
   return reg->op();
 }
 
-OpRegistry::OpRegistry() {
-  OpManager* mgr = OpManager::Global();
+OpRegEntry::OpRegEntry(uint32_t reg_index) {
   ObjectPtr<OpNode> n = make_object<OpNode>();
-  n->index_ = mgr->op_counter++;
+  n->index_ = reg_index;
   op_ = Op(n);
 }
 
+OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
+  return OpRegistry::Global()->RegisterOrGet(name);
+}
+
 // Get attribute map by key
-const GenericOpMap& Op::GetGenericAttr(const String& key) {
-  OpManager* mgr = OpManager::Global();
-  std::lock_guard<std::mutex> lock(mgr->mutex);
-  auto it = mgr->attr.find(key);
-  if (it == mgr->attr.end()) {
-    LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
-  }
-  return *it->second.get();
+const AttrRegistryMapContainerMap<Op>& Op::GetAttrMapContainer(const String& attr_name) {
+  return OpRegistry::Global()->GetAttrMap(attr_name);
 }
 
 // Check if a key is present in the registry.
-bool Op::HasGenericAttr(const String& key) {
-  OpManager* mgr = OpManager::Global();
-  std::lock_guard<std::mutex> lock(mgr->mutex);
-  auto it = mgr->attr.find(key);
-  if (it == mgr->attr.end()) {
-    return false;
-  }
-  return true;
-}
+bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); }
 
-// Resets attr of the OpMap.
-void OpRegistry::reset_attr(const std::string& key) {
-  OpManager* mgr = OpManager::Global();
-  std::lock_guard<std::mutex> lock(mgr->mutex);
-  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
-  if (op_map == nullptr) {
-    return;
-  }
-  uint32_t index = op_->index_;
-  if (op_map->data_.size() > index) {
-    op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
-  }
+// Resets attr of the OpAttrMap.
+void OpRegEntry::reset_attr(const std::string& attr_name) {
+  OpRegistry::Global()->ResetAttr(attr_name, op_);
 }
 
-void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
-  OpManager* mgr = OpManager::Global();
-  std::lock_guard<std::mutex> lock(mgr->mutex);
-  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
-  if (op_map == nullptr) {
-    op_map.reset(new GenericOpMap());
-    op_map->attr_name_ = key;
-  }
-  uint32_t index = op_->index_;
-  if (op_map->data_.size() <= index) {
-    op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
-  }
-  std::pair<TVMRetValue, int>& p = op_map->data_[index];
-  CHECK(p.second != plevel) << "Attribute " << key << " of operator " << this->name
-                            << " is already registered with same plevel=" << plevel;
-  CHECK(value.type_code() != kTVMNullptr)
-      << "Registered packed_func is Null for " << key << " of operator " << this->name;
-  if (p.second < plevel && value.type_code() != kTVMNullptr) {
-    op_map->data_[index] = std::make_pair(value, plevel);
-  }
+void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
+  OpRegistry::Global()->UpdateAttr(key, op_, value, plevel);
 }
 
 // Frontend APIs
 TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
-  Array<runtime::String> ret;
-  for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
-    ret.push_back(name);
-  }
-  return ret;
+  return OpRegistry::Global()->ListAllNames();
 });
 
 TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
@@ -148,7 +85,7 @@ TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
 TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
   Op op = args[0];
   std::string attr_name = args[1];
-  auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+  auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
   if (op_map.count(op)) {
     *rv = op_map[op];
   }
@@ -159,14 +96,14 @@ TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue
   std::string attr_name = args[1];
   runtime::TVMArgValue value = args[2];
   int plevel = args[3];
-  auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
+  auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
   reg.set_attr(attr_name, value, plevel);
 });
 
 TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) {
   Op op = args[0];
   std::string attr_name = args[1];
-  auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
+  auto& reg = OpRegistry::Global()->RegisterOrGet(op->name);
   reg.reset_attr(attr_name);
 });
 
@@ -175,7 +112,7 @@ TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue*
   std::string attr_key = args[1];
   runtime::TVMArgValue value = args[2];
   int plevel = args[3];
-  auto& reg = OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
+  auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
   // enable resgiteration and override of certain properties
   if (attr_key == "num_inputs" && plevel > 128) {
     reg.set_num_inputs(value);
@@ -187,8 +124,8 @@ TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue*
       // do an eager copy of the PackedFunc
       PackedFunc f = args[2];
       // If we get a function from frontend, avoid deleting it.
-      OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
-      reg.set_attr(attr_key, f, plevel);
+      auto* fcopy = new PackedFunc(f);
+      reg.set_attr(attr_key, *fcopy, plevel);
     } else {
       reg.set_attr(attr_key, args[2], plevel);
     }
diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h
new file mode 100644 (file)
index 0000000..9cc5b4d
--- /dev/null
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/node/attr_registry.h
+ * \brief Common global registry for objects that also have additional attrs.
+ */
+#ifndef TVM_NODE_ATTR_REGISTRY_H_
+#define TVM_NODE_ATTR_REGISTRY_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+
+/*!
+ * \breif Implementation of registry with attributes.
+ *
+ * \tparam EntryType Tye type of the registry entry.
+ * \tparam KeyType The actual key that is used to lookup the attributes.
+ *                 each entry has a corresponding key by default.
+ */
+template <typename EntryType, typename KeyType>
+class AttrRegistry {
+ public:
+  using TSelf = AttrRegistry<EntryType, KeyType>;
+  /*!
+   * \brief Get an entry from the registry.
+   * \param name The name of the item.
+   * \return The corresponding entry.
+   */
+  const EntryType* Get(const String& name) const {
+    auto it = entry_map_.find(name);
+    if (it != entry_map_.end()) return it->second;
+    return nullptr;
+  }
+
+  /*!
+   * \brief Get an entry or register a new one.
+   * \param name The name of the item.
+   * \return The corresponding entry.
+   */
+  EntryType& RegisterOrGet(const String& name) {
+    auto it = entry_map_.find(name);
+    if (it != entry_map_.end()) return *it->second;
+    uint32_t registry_index = static_cast<uint32_t>(entries_.size());
+    auto entry = std::unique_ptr<EntryType>(new EntryType(registry_index));
+    auto* eptr = entry.get();
+    eptr->name = name;
+    entry_map_[name] = eptr;
+    entries_.emplace_back(std::move(entry));
+    return *eptr;
+  }
+
+  /*!
+   * \brief List all the entry names in the registry.
+   * \return The entry names.
+   */
+  Array<String> ListAllNames() const {
+    Array<String> names;
+    for (const auto& kv : entry_map_) {
+      names.push_back(kv.first);
+    }
+    return names;
+  }
+
+  /*!
+   * \brief Update the attribute stable.
+   * \param attr_name The name of the attribute.
+   * \param key The key to the attribute table.
+   * \param value The value to be set.
+   * \param plevel The support level.
+   */
+  void UpdateAttr(const String& attr_name, const KeyType& key, runtime::TVMRetValue value,
+                  int plevel) {
+    using runtime::TVMRetValue;
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto& op_map = attrs_[attr_name];
+    if (op_map == nullptr) {
+      op_map.reset(new AttrRegistryMapContainerMap<KeyType>());
+      op_map->attr_name_ = attr_name;
+    }
+
+    uint32_t index = key->AttrRegistryIndex();
+    if (op_map->data_.size() <= index) {
+      op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
+    }
+    std::pair<TVMRetValue, int>& p = op_map->data_[index];
+    CHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName()
+                              << " is already registered with same plevel=" << plevel;
+    CHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name
+                                            << " of operator " << key->AttrRegistryName();
+    if (p.second < plevel && value.type_code() != kTVMNullptr) {
+      op_map->data_[index] = std::make_pair(value, plevel);
+    }
+  }
+
+  /*!
+   * \brief Reset an attribute table entry.
+   * \param attr_name The name of the attribute.
+   * \param key The key to the attribute table.
+   */
+  void ResetAttr(const String& attr_name, const KeyType& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto& op_map = attrs_[attr_name];
+    if (op_map == nullptr) {
+      return;
+    }
+    uint32_t index = key->AttrRegistryIndex();
+    if (op_map->data_.size() > index) {
+      op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
+    }
+  }
+
+  /*!
+   * \brief Get an internal attribute map.
+   * \param attr_name The name of the attribute.
+   * \return The result attribute map.
+   */
+  const AttrRegistryMapContainerMap<KeyType>& GetAttrMap(const String& attr_name) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto it = attrs_.find(attr_name);
+    if (it == attrs_.end()) {
+      LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered";
+    }
+    return *it->second.get();
+  }
+
+  /*!
+   * \brief Check of attribute has been registered.
+   * \param attr_name The name of the attribute.
+   * \return The check result.
+   */
+  bool HasAttrMap(const String& attr_name) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    return attrs_.count(attr_name);
+  }
+
+  /*!
+   * \return a global singleton of the registry.
+   */
+  static TSelf* Global() {
+    static TSelf* inst = new TSelf();
+    return inst;
+  }
+
+ private:
+  // mutex to avoid registration from multiple threads.
+  std::mutex mutex_;
+  // entries in the registry
+  std::vector<std::unique_ptr<EntryType>> entries_;
+  // map from name to entries.
+  std::unordered_map<String, EntryType*> entry_map_;
+  // storage of additional attribute table.
+  std::unordered_map<String, std::unique_ptr<AttrRegistryMapContainerMap<KeyType>>> attrs_;
+};
+
+}  // namespace tvm
+#endif  // TVM_NODE_ATTR_REGISTRY_H_
index 882bba9..d2e62b7 100644 (file)
@@ -178,7 +178,7 @@ class MacCounter : private ExprVisitor {
 
  private:
   void VisitExpr_(const CallNode* call_node) final {
-    static const auto& fprep = Op::GetAttr<FMacCount>("FMacCount");
+    static const auto& fprep = Op::GetAttrMap<FMacCount>("FMacCount");
     auto f = fprep.get(call_node->op, nullptr);
     if (f != nullptr) count_ += f(GetRef<Call>(call_node));
     ExprVisitor::VisitExpr_(call_node);
index af23836..6d246c0 100644 (file)
@@ -436,7 +436,7 @@ bool IsDynamic(const Type& ty) {
 TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
 
 bool IsDataDependant(const CallNode* call) {
-  static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+  static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
   Op op = Downcast<Op>(call->op);
 
   if (!tshape_data_dependant.count(op)) {
index 12a5add..f143479 100644 (file)
@@ -191,7 +191,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     CHECK(flower_call) << "relay.backend.lower_call is not registered.";
 
@@ -454,8 +454,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
-    static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
-    static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+    static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
+    static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
     CHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
     CHECK(data_dependants_.empty() || !data_dependants_.back())
index 81fc4f0..7c70f32 100644 (file)
@@ -109,7 +109,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
     for (auto kv : attributes) {
       auto attr_name = kv.first;
       auto attr_value = kv.second;
-      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+      auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
       if (op_map.count(op)) {
         switch (op_map[op].type_code()) {
           case kDLInt:
index 7b91e8c..3d242cd 100644 (file)
@@ -72,7 +72,7 @@ class AlterTransformMemorizer : public TransformMemorizer {
    * \return The new Call after calling the packed func.
    */
   Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
-    static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
+    static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
     Op op = Downcast<Op>(ref_call->op);
 
     Expr new_e;
index 3635947..0d97005 100644 (file)
@@ -145,10 +145,10 @@ class AnnotateTargetRewriter : public ExprRewriter {
       Op op = Downcast<Op>(pre->op);
       CHECK(op.defined());
       for (const auto& target : this->targets_) {
-        if (!Op::HasAttr("target." + std::string(target))) {
+        if (!Op::HasAttrMap("target." + std::string(target))) {
           continue;
         }
-        auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
+        auto fannotate = Op::GetAttrMap<FTVMAnnotateTarget>("target." + std::string(target));
         if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) {
           supported_targets.push_back(target);
         }
index f478107..055ab14 100644 (file)
@@ -66,7 +66,7 @@ class CastCanonicalizer : public ExprMutator {
   CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
 
   Expr VisitExpr_(const CallNode* call) {
-    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
 
     if (const OpNode* opnode = call->op.as<OpNode>()) {
       auto pattern = fpattern[GetRef<Op>(opnode)];
index 854a1ae..7ca2ce8 100644 (file)
@@ -81,7 +81,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
 
 // Create a branch starting from op.
 Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
-  auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+  auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
   // each branch has at least one element, the first element is always op
   Branch branch{op};
   auto it = children_map_.find(GetRef<Expr>(branch.back()));
index 7d42125..4a18925 100644 (file)
@@ -82,7 +82,7 @@ class ConvertTransformMemorizer : public TransformMemorizer {
    * \return The new Call after calling the packed func.
    */
   Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
-    static auto fconvert_layout = Op::GetAttr<FTVMConvertOpLayout>("FTVMConvertOpLayout");
+    static auto fconvert_layout = Op::GetAttrMap<FTVMConvertOpLayout>("FTVMConvertOpLayout");
     Op op = Downcast<Op>(ref_call->op);
 
     Expr new_e;
index 2861f32..27e5344 100644 (file)
@@ -42,7 +42,7 @@ class CommonSubexprEliminator : public ExprMutator {
   explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {}
 
   Expr VisitExpr_(const CallNode* call) final {
-    static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
     Expr new_expr = ExprMutator::VisitExpr_(call);
     const CallNode* new_call = new_expr.as<CallNode>();
     CHECK(new_call);
index 70df0ed..e2ab35b 100644 (file)
@@ -104,7 +104,7 @@ class ConstantFolder : public ExprMutator {
   }
 
   Expr VisitExpr_(const CallNode* call) final {
-    static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
 
     std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};
 
index 4083d08..a3765f3 100644 (file)
@@ -255,7 +255,7 @@ class ForwardPrep : private ExprVisitor {
     ExprVisitor::VisitExpr_(call);
     // function to be lazily invoked
     auto flazy = [this, call]() {
-      static const auto& fprep = Op::GetAttr<FForwardPrep>("FScaleAxisForwardPrep");
+      static const auto& fprep = Op::GetAttrMap<FForwardPrep>("FScaleAxisForwardPrep");
       // find the message send to this node.
       auto it = message_.find(call);
       Message out_message;
@@ -625,7 +625,7 @@ class BackwardPrep : private ExprVisitor {
   // Visit the expression.
   void VisitExpr_(const CallNode* call) {
     ExprVisitor::VisitExpr_(call);
-    static const auto& fprep = Op::GetAttr<FBackwardPrep>("FScaleAxisBackwardPrep");
+    static const auto& fprep = Op::GetAttrMap<FBackwardPrep>("FScaleAxisBackwardPrep");
     auto f = fprep.get(call->op, nullptr);
     if (f == nullptr) return;
     auto rit = ref_counter_.find(call);
@@ -727,7 +727,7 @@ class BackwardTransformer : public ObjectRef {
 };
 
 Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) {
-  static const auto& ftransform = Op::GetAttr<FBackwardTransform>("FScaleAxisBackwardTransform");
+  static const auto& ftransform = Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
   auto f = ftransform.get(call_node->op, nullptr);
   if (f != nullptr) {
     const Call call = GetRef<Call>(call_node);
index 226b338..d872116 100644 (file)
@@ -53,7 +53,7 @@ class TempRealizer : private MixedModeMutator {
 
 class ForwardRewriter : private MixedModeMutator {
  public:
-  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
+  ForwardRewriter(const OpAttrMap<FForwardRewrite>* rewrite_map,
                   std::function<ObjectRef(const Call&)> fcontext,
                   std::function<Expr(const Expr&)> fmulti_ref_trigger)
       : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {}
@@ -73,7 +73,7 @@ class ForwardRewriter : private MixedModeMutator {
 
  private:
   // The rewrite rule.
-  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
+  const OpAttrMap<FForwardRewrite>* rewrite_map_{nullptr};
   const FForwardRewrite* rewrite_func_{nullptr};
   // The context.const
   std::function<ObjectRef(const Call&)> fcontext_{nullptr};
@@ -175,7 +175,7 @@ class ForwardRewriter : private MixedModeMutator {
 Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name,
                     std::function<ObjectRef(const Call&)> fcontext,
                     std::function<Expr(const Expr&)> fmulti_ref_trigger) {
-  auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
+  auto rewrite_map = Op::GetAttrMap<FForwardRewrite>(rewrite_map_name);
   return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
 }
 
index 054244d..01f1eee 100644 (file)
@@ -226,7 +226,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   void VisitExpr_(const CallNode* call) final {
     CHECK(graph_.node_map.count(call));
     Node* node = graph_.node_map.at(call);
-    static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     // Now we set the pattern of this call.
     //
     // If we see a call mentioning an operator we should mark it with its
@@ -824,7 +824,7 @@ class FuseMutator : private ExprMutator {
   // Transform calls.
   Expr VisitExpr_(const CallNode* call) {
     if (call->op.as<OpNode>()) {
-      static auto fnoncomputational = Op::GetAttr<TNonComputational>("TNonComputational");
+      static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
 
       if (fnoncomputational.get(Downcast<Op>(call->op), false)) {
         return ExprMutator::VisitExpr_(call);
index 67c62f3..afe5568 100644 (file)
@@ -130,7 +130,7 @@ struct ADFunction : ADValueNode {
 };
 
 struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
-  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+  const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
   std::vector<std::function<void(LetList* ll)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
   std::unordered_map<Var, ADValue, ObjectHash, ObjectEqual> env;
@@ -354,7 +354,7 @@ struct ReverseAD : ExprMutator {
 
   Var bp;
   std::shared_ptr<ADVarMap> ad_vars;
-  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+  const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
 
   explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
 
@@ -456,7 +456,7 @@ struct ReverseAD : ExprMutator {
 
 bool MissingGrad(const Expr& e) {
   struct MGVisitor : ExprVisitor {
-    const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+    const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
     std::unordered_set<std::string> op_names;
 
     void VisitExpr_(const OpNode* op) final {
index e4df647..7ced51d 100644 (file)
@@ -208,7 +208,7 @@ inline Array<Array<Layout>> BinaryBroadcastLayout(const Attrs& attrs,
 static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
     const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
     const Array<tvm::relay::Type>& old_in_types) {
-  static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
+  static auto finfer_layout = Op::GetAttrMap<FInferCorrectLayout>("FInferCorrectLayout");
   if (!call->op.as<OpNode>()) {
     return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
   }
index 25919b4..c1f037f 100644 (file)
@@ -44,13 +44,13 @@ class Legalizer : public ExprRewriter {
     // Get the new_call node without any changes to current call node.
     Call new_call = Downcast<Call>(post);
 
-    // Check if the string is registered in the OpRegistry.
-    if (!Op::HasAttr(legalize_map_attr_name_)) {
+    // Check if the string is registered.
+    if (!Op::HasAttrMap(legalize_map_attr_name_)) {
       return post;
     }
 
     // Collect the registered legalize function.
-    auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
+    auto fop_legalize = Op::GetAttrMap<FTVMLegalize>(legalize_map_attr_name_);
     auto call_op = call_node->op;
     if (call_op.as<OpNode>()) {
       Op op = Downcast<Op>(call_node->op);
index a27cb79..3e27b87 100644 (file)
@@ -511,7 +511,7 @@ PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object<PStaticNode>(
 enum struct MatchStatus { Match, NoMatch, Unknown };
 
 bool StatefulOp(const Expr& e) {
-  static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
+  static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
   struct StatefulOpVisitor : ExprVisitor {
     bool stateful = false;
     void VisitExpr_(const OpNode* op) {
index d7ce0c0..3a2b2d9 100644 (file)
@@ -59,7 +59,7 @@ TVM_REGISTER_GLOBAL("test.strategy")
 TVM_REGISTER_GLOBAL("relay.backend.lower_call")
     .set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
                        const Target& target) {
-      static auto fstrategy = Op::GetAttr<relay::FTVMStrategy>("FTVMStrategy");
+      static auto fstrategy = Op::GetAttrMap<relay::FTVMStrategy>("FTVMStrategy");
       Op op = Downcast<Op>(call->op);
       auto out_type = call->checked_type();
       OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target);