Improve error messages in graph tuner, graph runtime, and module loader. (#6148)
authorTristan Konolige <tristan.konolige@gmail.com>
Wed, 29 Jul 2020 03:04:19 +0000 (20:04 -0700)
committerGitHub <noreply@github.com>
Wed, 29 Jul 2020 03:04:19 +0000 (20:04 -0700)
* Raise error if no operators are found in GraphTuner

* Raise error if key cannot be found in graph runtime inputs

* Detailed error message when module loader is not found

python/tvm/autotvm/graph_tuner/base_graph_tuner.py
python/tvm/contrib/graph_runtime.py
src/runtime/library_module.cc
src/runtime/stackvm/stackvm_module.cc

index 1cc4f39..76f92be 100644 (file)
@@ -152,6 +152,9 @@ class BaseGraphTuner(object):
 
         self._graph = graph
         self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
+        if len(self._in_nodes_dict) == 0:
+            raise RuntimeError("Could not find any input nodes with whose "
+                               "operator is one of %s" % self._target_ops)
         self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
         self._fetch_cfg()
         self._opt_out_op = OPT_OUT_OP
index 326eccb..22d0a8b 100644 (file)
@@ -152,7 +152,10 @@ class GraphModule(object):
            Additional arguments
         """
         if key is not None:
-            self._get_input(key).copyfrom(value)
+            v = self._get_input(key)
+            if v is None:
+                raise RuntimeError("Could not find '%s' in graph's inputs" % key)
+            v.copyfrom(value)
 
         if params:
             # upload big arrays first to avoid memory issue in rpc mode
index b12a9d1..651e19c 100644 (file)
@@ -133,9 +133,24 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
       CHECK(stream->Read(&import_tree_row_ptr));
       CHECK(stream->Read(&import_tree_child_indices));
     } else {
-      std::string fkey = "runtime.module.loadbinary_" + tkey;
+      std::string loadkey = "runtime.module.loadbinary_";
+      std::string fkey = loadkey + tkey;
       const PackedFunc* f = Registry::Get(fkey);
-      CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented.";
+      if (f == nullptr) {
+        std::string loaders = "";
+        for (auto name : Registry::ListNames()) {
+          if (name.rfind(loadkey, 0) == 0) {
+            if (loaders.size() > 0) {
+              loaders += ", ";
+            }
+            loaders += name.substr(loadkey.size());
+          }
+        }
+        CHECK(f != nullptr)
+            << "Binary was created using " << tkey
+            << " but a loader of that name is not registered. Available loaders are " << loaders
+            << ". Perhaps you need to recompile with this runtime enabled.";
+      }
       Module m = (*f)(static_cast<void*>(stream));
       modules.emplace_back(m);
     }
index 9e1f1f5..6c9af1c 100644 (file)
@@ -101,9 +101,24 @@ class StackVMModuleNode : public runtime::ModuleNode {
     for (uint64_t i = 0; i < num_imports; ++i) {
       std::string tkey;
       CHECK(strm->Read(&tkey));
-      std::string fkey = "runtime.module.loadbinary_" + tkey;
+      std::string loadkey = "runtime.module.loadbinary_";
+      std::string fkey = loadkey + tkey;
       const PackedFunc* f = Registry::Get(fkey);
-      CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented.";
+      if (f == nullptr) {
+        std::string loaders = "";
+        for (auto name : Registry::ListNames()) {
+          if (name.rfind(loadkey, 0) == 0) {
+            if (loaders.size() > 0) {
+              loaders += ", ";
+            }
+            loaders += name.substr(loadkey.size());
+          }
+        }
+        CHECK(f != nullptr)
+            << "Binary was created using " << tkey
+            << " but a loader of that name is not registered. Available loaders are " << loaders
+            << ". Perhaps you need to recompile with this runtime enabled.";
+      }
       Module m = (*f)(static_cast<void*>(strm));
       n->imports_.emplace_back(std::move(m));
     }