[NODE] General serialzation of leaf objects into bytes. (#5299)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 10 Apr 2020 05:04:08 +0000 (22:04 -0700)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 05:04:08 +0000 (22:04 -0700)
This PR refactors the serialization mechanism to support general
serialization of leaf objects into bytes.

The new feature superceded the original GetGlobalKey feature for singletons.
Added serialization support for runtime::String.

include/tvm/node/reflection.h
python/tvm/ir/json_compact.py
src/ir/env_func.cc
src/ir/op.cc
src/ir/span.cc
src/node/container.cc
src/node/reflection.cc
src/node/serialization.cc
tests/python/relay/test_json_compact.py
tests/python/unittest/test_node_reflection.py

index 18dfa12..9ed87df 100644 (file)
@@ -98,17 +98,17 @@ class ReflectionVTable {
   typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
   /*!
    * \brief creator function.
-   * \param global_key Key that identifies a global single object.
-   *        If this is not empty then FGlobalKey must be defined for the object.
+   * \param repr_bytes Repr bytes to create the object.
+   *        If this is not empty then FReprBytes must be defined for the object.
    * \return The created function.
    */
-  typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key);
+  typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
   /*!
-   * \brief Global key function, only needed by global objects.
+   * \brief Function to get a byte representation that can be used to recover the object.
    * \param node The node pointer.
-   * \return node The global key to the node.
+   * \return bytes The bytes that can be used to recover the object.
    */
-  typedef std::string (*FGlobalKey)(const Object* self);
+  typedef std::string (*FReprBytes)(const Object* self);
   /*!
    * \brief Dispatch the VisitAttrs function.
    * \param self The pointer to the object.
@@ -116,11 +116,13 @@ class ReflectionVTable {
    */
   inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
   /*!
-   * \brief Get global key of the object, if any.
+   * \brief Get repr bytes if any.
    * \param self The pointer to the object.
-   * \return the global key if object has one, otherwise return empty string.
+   * \param repr_bytes The output repr bytes, can be null, in which case the function
+   *                   simply queries if the ReprBytes function exists for the type.
+   * \return Whether repr bytes exists
    */
-  inline std::string GetGlobalKey(Object* self) const;
+  inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
   /*!
    * \brief Dispatch the SEqualReduce function.
    * \param self The pointer to the object.
@@ -141,10 +143,10 @@ class ReflectionVTable {
    *        by type_key and global key.
    *
    * \param type_key The type key of the object.
-   * \param global_key A global key that can be used to uniquely identify the object if any.
+   * \param repr_bytes Bytes representation of the object if any.
    */
   TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
-                                             const std::string& global_key = "") const;
+                                             const std::string& repr_bytes = "") const;
   /*!
    * \brief Get an field object by the attr name.
    * \param self The pointer to the object.
@@ -176,8 +178,8 @@ class ReflectionVTable {
   std::vector<FSHashReduce> fshash_reduce_;
   /*! \brief Creation function. */
   std::vector<FCreate> fcreate_;
-  /*! \brief Global key function. */
-  std::vector<FGlobalKey> fglobal_key_;
+  /*! \brief ReprBytes function. */
+  std::vector<FReprBytes> frepr_bytes_;
 };
 
 /*! \brief Registry of a reflection table. */
@@ -196,13 +198,13 @@ class ReflectionVTable::Registry {
     return *this;
   }
   /*!
-   * \brief Set global_key function.
-   * \param f The creator function.
+   * \brief Set bytes repr function.
+   * \param f The ReprBytes function.
    * \return rference to self.
    */
-  Registry& set_global_key(FGlobalKey f) {  // NOLINT(*)
-    CHECK_LT(type_index_, parent_->fglobal_key_.size());
-    parent_->fglobal_key_[type_index_] = f;
+  Registry& set_repr_bytes(FReprBytes f) {  // NOLINT(*)
+    CHECK_LT(type_index_, parent_->frepr_bytes_.size());
+    parent_->frepr_bytes_[type_index_] = f;
     return *this;
   }
 
@@ -365,7 +367,7 @@ ReflectionVTable::Register() {
   if (tindex >= fvisit_attrs_.size()) {
     fvisit_attrs_.resize(tindex + 1, nullptr);
     fcreate_.resize(tindex + 1, nullptr);
-    fglobal_key_.resize(tindex + 1, nullptr);
+    frepr_bytes_.resize(tindex + 1, nullptr);
     fsequal_reduce_.resize(tindex + 1, nullptr);
     fshash_reduce_.resize(tindex + 1, nullptr);
   }
@@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const {
   fvisit_attrs_[tindex](self, visitor);
 }
 
-inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
+inline bool ReflectionVTable::GetReprBytes(const Object* self,
+                                           std::string* repr_bytes) const {
   uint32_t tindex = self->type_index();
-  if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
-    return fglobal_key_[tindex](self);
+  if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
+    if (repr_bytes != nullptr) {
+      *repr_bytes = frepr_bytes_[tindex](self);
+    }
+    return true;
   } else {
-    return std::string();
+    return false;
   }
 }
 
index aa43df5..e091cd1 100644 (file)
@@ -79,8 +79,16 @@ def create_updater_06_to_07():
             return item
         return _convert
 
+    def _update_global_key(item, _):
+        item["repr_str"] = item["global_key"]
+        del item["global_key"]
+        return item
+
     node_map = {
         # Base IR
+        "SourceName": _update_global_key,
+        "EnvFunc": _update_global_key,
+        "relay.Op": _update_global_key,
         "relay.TypeVar": _ftype_var,
         "relay.GlobalTypeVar": _ftype_var,
         "relay.Type": _rename("Type"),
index 3e85c5f..4d3ed30 100644 (file)
@@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")
 
 TVM_REGISTER_NODE_TYPE(EnvFuncNode)
 .set_creator(CreateEnvNode)
-.set_global_key([](const Object* n) -> std::string {
+.set_repr_bytes([](const Object* n) -> std::string {
     return static_cast<const EnvFuncNode*>(n)->name;
   });
 
index 54374eb..6a50240 100644 (file)
@@ -223,7 +223,7 @@ ObjectPtr<Object> CreateOp(const std::string& name) {
 
 TVM_REGISTER_NODE_TYPE(OpNode)
 .set_creator(CreateOp)
-.set_global_key([](const Object* n) {
+.set_repr_bytes([](const Object* n) {
     return static_cast<const OpNode*>(n)->name;
   });
 
index d03903c..f84353d 100644 (file)
@@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(SourceNameNode)
 .set_creator(GetSourceNameNode)
-.set_global_key([](const Object* n) {
+.set_repr_bytes([](const Object* n) {
     return static_cast<const SourceNameNode*>(n)->name;
   });
 
index 8fff151..e7e4979 100644 (file)
@@ -48,7 +48,21 @@ struct StringObjTrait {
   }
 };
 
-TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+struct RefToObjectPtr : public ObjectRef {
+  static ObjectPtr<Object> Get(const ObjectRef& ref) {
+    return GetDataPtr<Object>(ref);
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
+.set_creator([](const std::string& bytes) {
+  return RefToObjectPtr::Get(runtime::String(bytes));
+})
+.set_repr_bytes([](const Object* n) -> std::string {
+  return GetRef<runtime::String>(
+      static_cast<const runtime::StringObj*>(n)).operator std::string();
+});
+
 
 struct ADTObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
index 824874f..08a914f 100644 (file)
@@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() {
 
 ObjectPtr<Object>
 ReflectionVTable::CreateInitObject(const std::string& type_key,
-                                   const std::string& global_key) const {
+                                   const std::string& repr_bytes) const {
   uint32_t tindex = Object::TypeKey2Index(type_key);
   if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
     LOG(FATAL) << "TypeError: " << type_key
                << " is not registered via TVM_REGISTER_NODE_TYPE";
   }
-  return fcreate_[tindex](global_key);
+  return fcreate_[tindex](repr_bytes);
 }
 
 class NodeAttrSetter : public AttrVisitor {
index 11c9e8f..ee6072d 100644 (file)
@@ -32,6 +32,7 @@
 #include <tvm/ir/attrs.h>
 
 #include <string>
+#include <cctype>
 #include <map>
 
 #include "../support/base64.h"
@@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) {
   return DataType(runtime::String2DLDataType(s));
 }
 
+inline std::string Base64Decode(std::string s) {
+  dmlc::MemoryStringStream mstrm(&s);
+  support::Base64InStream b64strm(&mstrm);
+  std::string output;
+  b64strm.InitPosition();
+  dmlc::Stream* strm = &b64strm;
+  strm->Read(&output);
+  return output;
+}
+
+inline std::string Base64Encode(std::string s) {
+  std::string blob;
+  dmlc::MemoryStringStream mstrm(&blob);
+  support::Base64OutStream b64strm(&mstrm);
+  dmlc::Stream* strm = &b64strm;
+  strm->Write(s);
+  b64strm.Finish();
+  return blob;
+}
+
 // indexer to index all the nodes
 class NodeIndexer : public AttrVisitor {
  public:
@@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor {
         MakeIndex(const_cast<Object*>(kv.second.get()));
       }
     } else {
-      reflection_->VisitAttrs(node, this);
+      // if the node already have repr bytes, no need to visit Attrs.
+      if (!reflection_->GetReprBytes(node, nullptr)) {
+        reflection_->VisitAttrs(node, this);
+      }
     }
   }
 };
@@ -115,8 +139,8 @@ using AttrMap = std::map<std::string, std::string>;
 struct JSONNode {
   /*! \brief The type of key of the object. */
   std::string type_key;
-  /*! \brief The global key for global object. */
-  std::string global_key;
+  /*! \brief The str repr representation. */
+  std::string repr_bytes;
   /*! \brief the attributes */
   AttrMap attrs;
   /*! \brief keys of a map. */
@@ -127,8 +151,15 @@ struct JSONNode {
   void Save(dmlc::JSONWriter *writer) const {
     writer->BeginObject();
     writer->WriteObjectKeyValue("type_key", type_key);
-    if (global_key.size() != 0) {
-      writer->WriteObjectKeyValue("global_key", global_key);
+    if (repr_bytes.size() != 0) {
+      // choose to use str representation or base64, based on whether
+      // the byte representation is printable.
+      if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
+                      [](char ch) { return std::isprint(ch); })) {
+        writer->WriteObjectKeyValue("repr_str", repr_bytes);
+      } else {
+        writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
+      }
     }
     if (attrs.size() != 0) {
       writer->WriteObjectKeyValue("attrs", attrs);
@@ -145,15 +176,24 @@ struct JSONNode {
   void Load(dmlc::JSONReader *reader) {
     attrs.clear();
     data.clear();
-    global_key.clear();
+    repr_bytes.clear();
     type_key.clear();
+    std::string repr_b64, repr_str;
     dmlc::JSONObjectReadHelper helper;
     helper.DeclareOptionalField("type_key", &type_key);
-    helper.DeclareOptionalField("global_key", &global_key);
+    helper.DeclareOptionalField("repr_b64", &repr_b64);
+    helper.DeclareOptionalField("repr_str", &repr_str);
     helper.DeclareOptionalField("attrs", &attrs);
     helper.DeclareOptionalField("keys", &keys);
     helper.DeclareOptionalField("data", &data);
     helper.ReadAllFields(reader);
+
+    if (repr_str.size() != 0) {
+      CHECK_EQ(repr_b64.size(), 0U);
+      repr_bytes = std::move(repr_str);
+    } else if (repr_b64.size() != 0) {
+      repr_bytes = Base64Decode(repr_b64);
+    }
   }
 };
 
@@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor {
       return;
     }
     node_->type_key = node->GetTypeKey();
-    node_->global_key = reflection_->GetGlobalKey(node);
-    // No need to recursively visit fields of global singleton
-    // They are registered via the environment.
-    if (node_->global_key.length() != 0) return;
+    // do not need to print additional things once we have repr bytes.
+    if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;
 
     // populates the fields.
     node_->attrs.clear();
@@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) {
   for (const JSONNode& jnode : jgraph.nodes) {
     if (jnode.type_key.length() != 0) {
       ObjectPtr<Object> node =
-          reflection->CreateInitObject(jnode.type_key, jnode.global_key);
+          reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
       nodes.emplace_back(node);
     } else {
       nodes.emplace_back(ObjectPtr<Object>());
@@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) {
 
   for (size_t i = 0; i < nodes.size(); ++i) {
     setter.node_ = &jgraph.nodes[i];
-    // do not need to recover content of global singleton object
-    // they are registered via the environment
-    if (setter.node_->global_key.length() == 0) {
+    // Skip the nodes that has an repr bytes representation.
+    // NOTE: the second condition is used to guard the case
+    // where the repr bytes itself is an empty string "".
+    if (setter.node_->repr_bytes.length() == 0 &&
+        nodes[i] != nullptr &&
+        !reflection->GetReprBytes(nodes[i].get(), nullptr)) {
       setter.Set(nodes[i].get());
     }
   }
index 54812be..16d02d2 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 
 import tvm
+from tvm import relay
 from tvm import te
 import json
 
@@ -108,6 +109,22 @@ def test_global_var():
     assert isinstance(tvar, tvm.ir.GlobalVar)
 
 
+def test_op():
+    nodes = [
+        {"type_key": ""},
+        {"type_key": "relay.Op",
+         "global_key": "nn.conv2d"}
+    ]
+    data = {
+        "root" : 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.6.0"},
+        "b64ndarrays": [],
+    }
+    op = tvm.ir.load_json(json.dumps(data))
+    assert op == relay.op.get("nn.conv2d")
+
+
 def test_tir_var():
     nodes = [
         {"type_key": ""},
@@ -132,6 +149,7 @@ def test_tir_var():
 
 
 if __name__ == "__main__":
+    test_op()
     test_type_var()
     test_incomplete_type()
     test_func_tuple_type()
index f2848ff..9751922 100644 (file)
@@ -89,7 +89,20 @@ def test_env_func():
     assert x.func(10) == 11
 
 
+def test_string():
+    # non printable str, need to store by b64
+    s1 = tvm.runtime.String("xy\x01z")
+    s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
+    tvm.ir.assert_structural_equal(s1, s2)
+
+    # printable str, need to store by repr_str
+    s1 = tvm.runtime.String("xyz")
+    s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
+    tvm.ir.assert_structural_equal(s1, s2)
+
+
 if __name__ == "__main__":
+    test_string()
     test_env_func()
     test_make_node()
     test_make_smap()