[BYOC][JSON] Support input nodes with multiple entries (#6368)
authorTrevor Morris <trevmorr@amazon.com>
Tue, 1 Sep 2020 18:06:16 +0000 (11:06 -0700)
committerGitHub <noreply@github.com>
Tue, 1 Sep 2020 18:06:16 +0000 (11:06 -0700)
* Support input nodes with multiple data entries

* Rename input_var_idx_ to input_var_eid_

src/runtime/contrib/json/json_runtime.h

index 92830e6..9eb7fcd 100644 (file)
@@ -146,12 +146,12 @@ class JSONRuntimeBase : public ModuleNode {
    * \param args The packed args.
    */
   void SetInputOutputBuffers(const TVMArgs& args) {
-    CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size())
+    CHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
         << "Found mismatch in the number of provided data entryies and required.";
 
     for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
-      auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0)
-                                           : EntryID(outputs_[i - input_var_idx_.size()]);
+      auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
+                                           : EntryID(outputs_[i - input_var_eid_.size()]);
       CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
           << "Expect NDArray or DLTensor as inputs";
 
@@ -183,7 +183,10 @@ class JSONRuntimeBase : public ModuleNode {
       uint32_t nid = input_nodes_[i];
       std::string name = nodes_[nid].name_;
       if (nodes_[nid].op_type_ == "input") {
-        input_var_idx_.push_back(nid);
+        CHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size());
+        for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
+          input_var_eid_.push_back(EntryID(nid, j));
+        }
       } else {
         CHECK_EQ(nodes_[nid].op_type_, "const");
         auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
@@ -261,8 +264,8 @@ class JSONRuntimeBase : public ModuleNode {
   std::vector<JSONGraphNodeEntry> outputs_;
   /*! \brief Data of that entry. */
   std::vector<const DLTensor*> data_entry_;
-  /*! \brief Map the input name to node index. */
-  std::vector<uint32_t> input_var_idx_;
+  /*! \brief Map the input name to entry id. */
+  std::vector<uint32_t> input_var_eid_;
   /*! \brief input const node index. */
   std::vector<uint32_t> const_idx_;
   /*! \brief Indicate if the engine has been initialized. */