typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
/*!
* \brief creator function.
- * \param global_key Key that identifies a global single object.
- * If this is not empty then FGlobalKey must be defined for the object.
+ * \param repr_bytes Repr bytes to create the object.
+ * If this is not empty then FReprBytes must be defined for the object.
* \return The created function.
*/
- typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key);
+ typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
/*!
- * \brief Global key function, only needed by global objects.
+ * \brief Function to get a byte representation that can be used to recover the object.
* \param node The node pointer.
- * \return node The global key to the node.
+ * \return bytes The bytes that can be used to recover the object.
*/
- typedef std::string (*FGlobalKey)(const Object* self);
+ typedef std::string (*FReprBytes)(const Object* self);
/*!
* \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object.
*/
inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
/*!
- * \brief Get global key of the object, if any.
+ * \brief Get repr bytes if any.
* \param self The pointer to the object.
- * \return the global key if object has one, otherwise return empty string.
+ * \param repr_bytes The output repr bytes, can be null, in which case the function
+ * simply queries if the ReprBytes function exists for the type.
+ * \return Whether repr bytes exists
*/
- inline std::string GetGlobalKey(Object* self) const;
+ inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
/*!
* \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object.
* by type_key and global key.
*
* \param type_key The type key of the object.
- * \param global_key A global key that can be used to uniquely identify the object if any.
+ * \param repr_bytes Bytes representation of the object if any.
*/
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
- const std::string& global_key = "") const;
+ const std::string& repr_bytes = "") const;
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
std::vector<FSHashReduce> fshash_reduce_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
- /*! \brief Global key function. */
- std::vector<FGlobalKey> fglobal_key_;
+ /*! \brief ReprBytes function. */
+ std::vector<FReprBytes> frepr_bytes_;
};
/*! \brief Registry of a reflection table. */
return *this;
}
/*!
- * \brief Set global_key function.
- * \param f The creator function.
+ * \brief Set bytes repr function.
+ * \param f The ReprBytes function.
* \return rference to self.
*/
- Registry& set_global_key(FGlobalKey f) { // NOLINT(*)
- CHECK_LT(type_index_, parent_->fglobal_key_.size());
- parent_->fglobal_key_[type_index_] = f;
+ Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*)
+ CHECK_LT(type_index_, parent_->frepr_bytes_.size());
+ parent_->frepr_bytes_[type_index_] = f;
return *this;
}
if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
- fglobal_key_.resize(tindex + 1, nullptr);
+ frepr_bytes_.resize(tindex + 1, nullptr);
fsequal_reduce_.resize(tindex + 1, nullptr);
fshash_reduce_.resize(tindex + 1, nullptr);
}
fvisit_attrs_[tindex](self, visitor);
}
-inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
+inline bool ReflectionVTable::GetReprBytes(const Object* self,
+ std::string* repr_bytes) const {
uint32_t tindex = self->type_index();
- if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
- return fglobal_key_[tindex](self);
+ if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
+ if (repr_bytes != nullptr) {
+ *repr_bytes = frepr_bytes_[tindex](self);
+ }
+ return true;
} else {
- return std::string();
+ return false;
}
}
return item
return _convert
+ def _update_global_key(item, _):
+ item["repr_str"] = item["global_key"]
+ del item["global_key"]
+ return item
+
node_map = {
# Base IR
+ "SourceName": _update_global_key,
+ "EnvFunc": _update_global_key,
+ "relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode)
-.set_global_key([](const Object* n) -> std::string {
+.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const EnvFuncNode*>(n)->name;
});
TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp)
-.set_global_key([](const Object* n) {
+.set_repr_bytes([](const Object* n) {
return static_cast<const OpNode*>(n)->name;
});
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
-.set_global_key([](const Object* n) {
+.set_repr_bytes([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
}
};
-TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+struct RefToObjectPtr : public ObjectRef {
+ static ObjectPtr<Object> Get(const ObjectRef& ref) {
+ return GetDataPtr<Object>(ref);
+ }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
+.set_creator([](const std::string& bytes) {
+ return RefToObjectPtr::Get(runtime::String(bytes));
+})
+.set_repr_bytes([](const Object* n) -> std::string {
+ return GetRef<runtime::String>(
+ static_cast<const runtime::StringObj*>(n)).operator std::string();
+});
+
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
- const std::string& global_key) const {
+ const std::string& repr_bytes) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
- return fcreate_[tindex](global_key);
+ return fcreate_[tindex](repr_bytes);
}
class NodeAttrSetter : public AttrVisitor {
#include <tvm/ir/attrs.h>
#include <string>
+#include <cctype>
#include <map>
#include "../support/base64.h"
return DataType(runtime::String2DLDataType(s));
}
+inline std::string Base64Decode(std::string s) {
+ dmlc::MemoryStringStream mstrm(&s);
+ support::Base64InStream b64strm(&mstrm);
+ std::string output;
+ b64strm.InitPosition();
+ dmlc::Stream* strm = &b64strm;
+ strm->Read(&output);
+ return output;
+}
+
+inline std::string Base64Encode(std::string s) {
+ std::string blob;
+ dmlc::MemoryStringStream mstrm(&blob);
+ support::Base64OutStream b64strm(&mstrm);
+ dmlc::Stream* strm = &b64strm;
+ strm->Write(s);
+ b64strm.Finish();
+ return blob;
+}
+
// indexer to index all the nodes
class NodeIndexer : public AttrVisitor {
public:
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
- reflection_->VisitAttrs(node, this);
+ // if the node already have repr bytes, no need to visit Attrs.
+ if (!reflection_->GetReprBytes(node, nullptr)) {
+ reflection_->VisitAttrs(node, this);
+ }
}
}
};
struct JSONNode {
/*! \brief The type of key of the object. */
std::string type_key;
- /*! \brief The global key for global object. */
- std::string global_key;
+ /*! \brief The str repr representation. */
+ std::string repr_bytes;
/*! \brief the attributes */
AttrMap attrs;
/*! \brief keys of a map. */
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key);
- if (global_key.size() != 0) {
- writer->WriteObjectKeyValue("global_key", global_key);
+ if (repr_bytes.size() != 0) {
+ // choose to use str representation or base64, based on whether
+ // the byte representation is printable.
+ if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
+ [](char ch) { return std::isprint(ch); })) {
+ writer->WriteObjectKeyValue("repr_str", repr_bytes);
+ } else {
+ writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
+ }
}
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
void Load(dmlc::JSONReader *reader) {
attrs.clear();
data.clear();
- global_key.clear();
+ repr_bytes.clear();
type_key.clear();
+ std::string repr_b64, repr_str;
dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key);
- helper.DeclareOptionalField("global_key", &global_key);
+ helper.DeclareOptionalField("repr_b64", &repr_b64);
+ helper.DeclareOptionalField("repr_str", &repr_str);
helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader);
+
+ if (repr_str.size() != 0) {
+ CHECK_EQ(repr_b64.size(), 0U);
+ repr_bytes = std::move(repr_str);
+ } else if (repr_b64.size() != 0) {
+ repr_bytes = Base64Decode(repr_b64);
+ }
}
};
return;
}
node_->type_key = node->GetTypeKey();
- node_->global_key = reflection_->GetGlobalKey(node);
- // No need to recursively visit fields of global singleton
- // They are registered via the environment.
- if (node_->global_key.length() != 0) return;
+ // do not need to print additional things once we have repr bytes.
+ if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;
// populates the fields.
node_->attrs.clear();
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
ObjectPtr<Object> node =
- reflection->CreateInitObject(jnode.type_key, jnode.global_key);
+ reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
nodes.emplace_back(node);
} else {
nodes.emplace_back(ObjectPtr<Object>());
for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
- // do not need to recover content of global singleton object
- // they are registered via the environment
- if (setter.node_->global_key.length() == 0) {
+ // Skip the nodes that has an repr bytes representation.
+ // NOTE: the second condition is used to guard the case
+ // where the repr bytes itself is an empty string "".
+ if (setter.node_->repr_bytes.length() == 0 &&
+ nodes[i] != nullptr &&
+ !reflection->GetReprBytes(nodes[i].get(), nullptr)) {
setter.Set(nodes[i].get());
}
}
# under the License.
import tvm
+from tvm import relay
from tvm import te
import json
assert isinstance(tvar, tvm.ir.GlobalVar)
+def test_op():
+ nodes = [
+ {"type_key": ""},
+ {"type_key": "relay.Op",
+ "global_key": "nn.conv2d"}
+ ]
+ data = {
+ "root" : 1,
+ "nodes": nodes,
+ "attrs": {"tvm_version": "0.6.0"},
+ "b64ndarrays": [],
+ }
+ op = tvm.ir.load_json(json.dumps(data))
+ assert op == relay.op.get("nn.conv2d")
+
+
def test_tir_var():
nodes = [
{"type_key": ""},
if __name__ == "__main__":
+ test_op()
test_type_var()
test_incomplete_type()
test_func_tuple_type()
assert x.func(10) == 11
+def test_string():
+ # non printable str, need to store by b64
+ s1 = tvm.runtime.String("xy\x01z")
+ s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
+ tvm.ir.assert_structural_equal(s1, s2)
+
+ # printable str, need to store by repr_str
+ s1 = tvm.runtime.String("xyz")
+ s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
+ tvm.ir.assert_structural_equal(s1, s2)
+
+
if __name__ == "__main__":
+ test_string()
test_env_func()
test_make_node()
test_make_smap()