operator std::string() const { return std::string{get()->data, size()}; }
/*!
+ * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
+ * \param val The value to be checked
+ * \return A boolean indicating if val can be converted to String
+ */
+ static bool CanConvertFrom(const TVMArgValue& val) {
+ return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
+ }
+
+ /*!
* \brief Hash the binary bytes
* \param data The data pointer
* \param size The size of the bytes.
runtime::TVMArgValue val = args[i + 1];
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
- } else if (val.type_code() == kTVMStr) {
+ } else if (String::CanConvertFrom(val)) {
dict.Set(key, val.operator String());
} else {
dict.Set(key, val.operator PrimExpr());
TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
- if (args.size() != 0 && args[0].type_code() == kTVMStr) {
- MapNode::ContainerType data;
- for (int i = 0; i < args.num_args; i += 2) {
- CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str";
- CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of the map to be object";
- data.emplace(
- std::make_pair(String(args[i].operator std::string()), args[i + 1].operator ObjectRef()));
- }
- auto node = make_object<MapNode>();
- node->data = std::move(data);
- *ret = Map<ObjectRef, ObjectRef>(node);
- } else {
- // Container node.
- MapNode::ContainerType data;
- for (int i = 0; i < args.num_args; i += 2) {
- CHECK(args[i].IsObjectRef<ObjectRef>()) << "key of map need to be object";
- CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of map to be object";
- data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef()));
- }
- auto node = make_object<MapNode>();
- node->data = std::move(data);
- *ret = Map<ObjectRef, ObjectRef>(node);
+ MapNode::ContainerType data;
+ for (int i = 0; i < args.num_args; i += 2) {
+ ObjectRef k =
+ String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
+ ObjectRef v = args[i + 1];
+ data.emplace(std::move(k), std::move(v));
}
+ auto node = make_object<MapNode>();
+ node->data = std::move(data);
+ *ret = Map<ObjectRef, ObjectRef>(node);
});
TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(ptr->IsInstance<MapNode>());
auto* n = static_cast<const MapNode*>(ptr);
- if (args[1].type_code() == kTVMStr) {
- auto it = n->data.find(String(args[1].operator std::string()));
- CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
- *ret = (*it).second;
- } else {
- auto it = n->data.find(args[1].operator ObjectRef());
- CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
- *ret = (*it).second;
- }
+ auto it = n->data.find(String::CanConvertFrom(args[1]) ? args[1].operator String()
+ : args[1].operator ObjectRef());
+ CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
+ *ret = (*it).second;
});
TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<MapNode>());
const MapNode* n = static_cast<const MapNode*>(ptr);
- if (args[1].type_code() == kTVMStr) {
- *ret = static_cast<int64_t>(n->data.count(String(args[1].operator std::string())));
- } else {
- *ret = static_cast<int64_t>(n->data.count(args[1].operator ObjectRef()));
- }
+ int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String()
+ : args[1].operator ObjectRef());
+ *ret = cnt;
});
TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
return val->data == rhs.operator std::string();
}
break;
+ case kTVMObjectHandle:
+ if (rhs.IsObjectRef<String>()) {
+ if (auto* val = lhs.as<tir::StringImmNode>()) {
+ return rhs.operator String() == val->value;
+ } else if (auto* val = lhs.as<StringObj>()) {
+ return rhs.operator String() == val->data;
+ }
+ }
+ break;
default:
CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
}
/*!
* \file graph_runtime_debug.cc
*/
+#include <tvm/runtime/container.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
});
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- if (args[0].type_code() == kTVMStr) {
+ if (String::CanConvertFrom(args[0])) {
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
} else {
this->DebugGetNodeOutput(args[0], args[1]);
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- if (args[0].type_code() == kTVMStr) {
- int in_idx = this->GetInputIndex(args[0]);
+ if (String::CanConvertFrom(args[0])) {
+ int in_idx = this->GetInputIndex(args[0].operator String());
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
});
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- if (args[0].type_code() == kTVMStr) {
- int in_idx = this->GetInputIndex(args[0]);
+ if (String::CanConvertFrom(args[0])) {
+ int in_idx = this->GetInputIndex(args[0].operator String());
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
- if (args[0].type_code() == kTVMStr) {
- in_idx = this->GetInputIndex(args[0]);
- } else if (args[0].IsObjectRef<runtime::String>()) {
- auto str = args[0].AsObjectRef<runtime::String>();
- in_idx = this->GetInputIndex(str);
+ if (String::CanConvertFrom(args[0])) {
+ in_idx = this->GetInputIndex(args[0].operator String());
} else {
in_idx = args[0];
}