[Object] Restore the StrMap behavior in JSON/SHash/SEqual (#5719)
authorJunru Shao <junrushao1994@gmail.com>
Wed, 3 Jun 2020 20:34:44 +0000 (13:34 -0700)
committerGitHub <noreply@github.com>
Wed, 3 Jun 2020 20:34:44 +0000 (13:34 -0700)
include/tvm/node/container.h
python/tvm/ir/json_compact.py
src/node/container.cc
src/node/serialization.cc
tests/python/relay/test_json_compact.py

index 1a7a8df..a3cfdaf 100644 (file)
@@ -50,6 +50,7 @@ using runtime::ObjectRef;
 using runtime::String;
 using runtime::StringObj;
 
+/*! \brief String-aware ObjectRef hash functor */
 struct ObjectHash {
   size_t operator()(const ObjectRef& a) const {
     if (const auto* str = a.as<StringObj>()) {
@@ -59,6 +60,7 @@ struct ObjectHash {
   }
 };
 
+/*! \brief String-aware ObjectRef equal functor */
 struct ObjectEqual {
   bool operator()(const ObjectRef& a, const ObjectRef& b) const {
     if (a.same_as(b)) {
@@ -96,8 +98,7 @@ class MapNode : public Object {
  * \tparam V The value NodeRef type.
  */
 template <typename K, typename V,
-          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value ||
-                                             std::is_base_of<std::string, K>::value>::type,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
           typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
 class Map : public ObjectRef {
  public:
index 6fc24c0..2facc79 100644 (file)
@@ -129,6 +129,7 @@ def create_updater_06_to_07():
         "relay.PassContext": _rename("transform.PassContext"),
         "relay.ModulePass": _rename("transform.ModulePass"),
         "relay.Sequential": _rename("transform.Sequential"),
+        "StrMap": _rename("Map"),
         # TIR
         "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
         "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
index f7b9dd3..bdebb7f 100644 (file)
@@ -247,40 +247,51 @@ struct MapNodeTrait {
   }
 
   static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
-    if (key->data.empty()) {
-      hash_reduce(uint64_t(0));
-      return;
-    }
-    if (key->data.begin()->first->IsInstance<StringObj>()) {
+    bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const auto& v) {
+      return v.first->template IsInstance<StringObj>();
+    });
+    if (is_str_map) {
       SHashReduceForSMap(key, hash_reduce);
     } else {
       SHashReduceForOMap(key, hash_reduce);
     }
   }
 
+  static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
+    for (const auto& kv : lhs->data) {
+      // Only allow equal checking if the keys are already mapped
+      // This resolves common use cases where we want to store
+      // Map<Var, Value> where Var is defined in the function
+      // parameters.
+      ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
+      if (!rhs_key.defined()) return false;
+      auto it = rhs->data.find(rhs_key);
+      if (it == rhs->data.end()) return false;
+      if (!equal(kv.second, it->second)) return false;
+    }
+    return true;
+  }
+
+  static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
+    for (const auto& kv : lhs->data) {
+      auto it = rhs->data.find(kv.first);
+      if (it == rhs->data.end()) return false;
+      if (!equal(kv.second, it->second)) return false;
+    }
+    return true;
+  }
+
   static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
     if (rhs->data.size() != lhs->data.size()) return false;
     if (rhs->data.size() == 0) return true;
-    if (lhs->data.begin()->first->IsInstance<StringObj>()) {
-      for (const auto& kv : lhs->data) {
-        auto it = rhs->data.find(kv.first);
-        if (it == rhs->data.end()) return false;
-        if (!equal(kv.second, it->second)) return false;
-      }
-    } else {
-      for (const auto& kv : lhs->data) {
-        // Only allow equal checking if the keys are already mapped
-        // This resolves common use cases where we want to store
-        // Map<Var, Value> where Var is defined in the function
-        // parameters.
-        ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
-        if (!rhs_key.defined()) return false;
-        auto it = rhs->data.find(rhs_key);
-        if (it == rhs->data.end()) return false;
-        if (!equal(kv.second, it->second)) return false;
-      }
+    bool ls = std::all_of(lhs->data.begin(), lhs->data.end(),
+                          [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
+    bool rs = std::all_of(rhs->data.begin(), rhs->data.end(),
+                          [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
+    if (ls != rs) {
+      return false;
     }
-    return true;
+    return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal);
   }
 };
 
index 9845a6f..3866533 100644 (file)
@@ -110,11 +110,18 @@ class NodeIndexer : public AttrVisitor {
       }
     } else if (node->IsInstance<MapNode>()) {
       MapNode* n = static_cast<MapNode*>(node);
-      for (const auto& kv : n->data) {
-        if (!kv.first->IsInstance<StringObj>()) {
+      bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) {
+        return v.first->template IsInstance<StringObj>();
+      });
+      if (is_str_map) {
+        for (const auto& kv : n->data) {
+          MakeIndex(const_cast<Object*>(kv.second.get()));
+        }
+      } else {
+        for (const auto& kv : n->data) {
           MakeIndex(const_cast<Object*>(kv.first.get()));
+          MakeIndex(const_cast<Object*>(kv.second.get()));
         }
-        MakeIndex(const_cast<Object*>(kv.second.get()));
       }
     } else {
       // if the node already have repr bytes, no need to visit Attrs.
@@ -246,13 +253,19 @@ class JSONAttrGetter : public AttrVisitor {
       }
     } else if (node->IsInstance<MapNode>()) {
       MapNode* n = static_cast<MapNode*>(node);
-      for (const auto& kv : n->data) {
-        if (const auto* str = kv.first.as<StringObj>()) {
-          node_->keys.push_back(std::string(str->data, str->size));
-        } else {
+      bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) {
+        return v.first->template IsInstance<StringObj>();
+      });
+      if (is_str_map) {
+        for (const auto& kv : n->data) {
+          node_->keys.push_back(Downcast<String>(kv.first));
+          node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
+        }
+      } else {
+        for (const auto& kv : n->data) {
           node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get())));
+          node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
         }
-        node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
       }
     } else {
       // recursively index normal object.
index c961f99..00d41f0 100644 (file)
@@ -186,6 +186,34 @@ def test_tir_var():
     assert y.name == "y"
 
 
+def test_str_map():
+    nodes = [
+        {'type_key': ''},
+        {'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]},
+        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
+        {'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}},
+        {'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}},
+        {'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}},
+        {'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7', 'type_annotation': '0'}},
+        {'type_key': 'runtime.String', 'repr_str': 'x'},
+        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}},
+        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
+        {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}}
+    ]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    x = tvm.ir.load_json(json.dumps(data))
+    assert(isinstance(x, tvm.ir.container.Map))
+    assert(len(x) == 2)
+    assert('x' in x)
+    assert('z' in x)
+    assert(bool(x['z'] == 2))
+
+
 if __name__ == "__main__":
     test_op()
     test_type_var()
@@ -194,3 +222,4 @@ if __name__ == "__main__":
     test_func_tuple_type()
     test_global_var()
     test_tir_var()
+    test_str_map()