[Object][FFI] Introduce runtime::String::CanConvertFrom (#5718)
authorJunru Shao <junrushao1994@gmail.com>
Wed, 3 Jun 2020 15:01:55 +0000 (08:01 -0700)
committerGitHub <noreply@github.com>
Wed, 3 Jun 2020 15:01:55 +0000 (08:01 -0700)
* [Object][FFI] Introduce runtime::String::CanConvertFrom

* Update container.h

include/tvm/runtime/container.h
src/ir/attrs.cc
src/node/container.cc
src/relay/ir/dataflow_matcher.cc
src/runtime/graph/debug/graph_runtime_debug.cc
src/runtime/graph/graph_runtime.cc

index 8b71081..6bc6fbf 100644 (file)
@@ -1302,6 +1302,15 @@ class String : public ObjectRef {
   operator std::string() const { return std::string{get()->data, size()}; }
 
   /*!
+   * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
+   * \param val The value to be checked
+   * \return A boolean indicating if val can be converted to String
+   */
+  static bool CanConvertFrom(const TVMArgValue& val) {
+    return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
+  }
+
+  /*!
    * \brief Hash the binary bytes
    * \param data The data pointer
    * \param size The size of the bytes.
index 18b17d3..af46439 100644 (file)
@@ -37,7 +37,7 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un
     runtime::TVMArgValue val = args[i + 1];
     if (val.IsObjectRef<ObjectRef>()) {
       dict.Set(key, val.operator ObjectRef());
-    } else if (val.type_code() == kTVMStr) {
+    } else if (String::CanConvertFrom(val)) {
       dict.Set(key, val.operator String());
     } else {
       dict.Set(key, val.operator PrimExpr());
index f8bad00..f7b9dd3 100644 (file)
@@ -292,29 +292,16 @@ TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
 
 TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
   CHECK_EQ(args.size() % 2, 0);
-  if (args.size() != 0 && args[0].type_code() == kTVMStr) {
-    MapNode::ContainerType data;
-    for (int i = 0; i < args.num_args; i += 2) {
-      CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str";
-      CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of the map to be object";
-      data.emplace(
-          std::make_pair(String(args[i].operator std::string()), args[i + 1].operator ObjectRef()));
-    }
-    auto node = make_object<MapNode>();
-    node->data = std::move(data);
-    *ret = Map<ObjectRef, ObjectRef>(node);
-  } else {
-    // Container node.
-    MapNode::ContainerType data;
-    for (int i = 0; i < args.num_args; i += 2) {
-      CHECK(args[i].IsObjectRef<ObjectRef>()) << "key of map need to be object";
-      CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of map to be object";
-      data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef()));
-    }
-    auto node = make_object<MapNode>();
-    node->data = std::move(data);
-    *ret = Map<ObjectRef, ObjectRef>(node);
+  MapNode::ContainerType data;
+  for (int i = 0; i < args.num_args; i += 2) {
+    ObjectRef k =
+        String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
+    ObjectRef v = args[i + 1];
+    data.emplace(std::move(k), std::move(v));
   }
+  auto node = make_object<MapNode>();
+  node->data = std::move(data);
+  *ret = Map<ObjectRef, ObjectRef>(node);
 });
 
 TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -331,15 +318,10 @@ TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* re
   CHECK(ptr->IsInstance<MapNode>());
 
   auto* n = static_cast<const MapNode*>(ptr);
-  if (args[1].type_code() == kTVMStr) {
-    auto it = n->data.find(String(args[1].operator std::string()));
-    CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
-    *ret = (*it).second;
-  } else {
-    auto it = n->data.find(args[1].operator ObjectRef());
-    CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
-    *ret = (*it).second;
-  }
+  auto it = n->data.find(String::CanConvertFrom(args[1]) ? args[1].operator String()
+                                                         : args[1].operator ObjectRef());
+  CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
+  *ret = (*it).second;
 });
 
 TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -347,11 +329,9 @@ TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret)
   Object* ptr = static_cast<Object*>(args[0].value().v_handle);
   CHECK(ptr->IsInstance<MapNode>());
   const MapNode* n = static_cast<const MapNode*>(ptr);
-  if (args[1].type_code() == kTVMStr) {
-    *ret = static_cast<int64_t>(n->data.count(String(args[1].operator std::string())));
-  } else {
-    *ret = static_cast<int64_t>(n->data.count(args[1].operator ObjectRef()));
-  }
+  int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String()
+                                                              : args[1].operator ObjectRef());
+  *ret = cnt;
 });
 
 TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
index eb305c9..e9543e3 100644 (file)
@@ -121,6 +121,15 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
         return val->data == rhs.operator std::string();
       }
       break;
+    case kTVMObjectHandle:
+      if (rhs.IsObjectRef<String>()) {
+        if (auto* val = lhs.as<tir::StringImmNode>()) {
+          return rhs.operator String() == val->value;
+        } else if (auto* val = lhs.as<StringObj>()) {
+          return rhs.operator String() == val->data;
+        }
+      }
+      break;
     default:
       CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
   }
index 9f206fd..5439be9 100644 (file)
@@ -20,6 +20,7 @@
 /*!
  * \file graph_runtime_debug.cc
  */
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
@@ -173,7 +174,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name,
     });
   } else if (name == "debug_get_output") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      if (args[0].type_code() == kTVMStr) {
+      if (String::CanConvertFrom(args[0])) {
         this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
       } else {
         this->DebugGetNodeOutput(args[0], args[1]);
index 8f7f988..59bfb68 100644 (file)
@@ -390,8 +390,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
   // Return member functions during query.
   if (name == "set_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      if (args[0].type_code() == kTVMStr) {
-        int in_idx = this->GetInputIndex(args[0]);
+      if (String::CanConvertFrom(args[0])) {
+        int in_idx = this->GetInputIndex(args[0].operator String());
         if (in_idx >= 0) this->SetInput(in_idx, args[1]);
       } else {
         this->SetInput(args[0], args[1]);
@@ -399,8 +399,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
     });
   } else if (name == "set_input_zero_copy") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      if (args[0].type_code() == kTVMStr) {
-        int in_idx = this->GetInputIndex(args[0]);
+      if (String::CanConvertFrom(args[0])) {
+        int in_idx = this->GetInputIndex(args[0].operator String());
         if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
       } else {
         this->SetInputZeroCopy(args[0], args[1]);
@@ -417,11 +417,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
   } else if (name == "get_input") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       int in_idx = 0;
-      if (args[0].type_code() == kTVMStr) {
-        in_idx = this->GetInputIndex(args[0]);
-      } else if (args[0].IsObjectRef<runtime::String>()) {
-        auto str = args[0].AsObjectRef<runtime::String>();
-        in_idx = this->GetInputIndex(str);
+      if (String::CanConvertFrom(args[0])) {
+        in_idx = this->GetInputIndex(args[0].operator String());
       } else {
         in_idx = args[0];
       }