From b8f37eeac7611e1be03adc26ee0005ddd6cdf87f Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 1 Sep 2020 11:06:16 -0700 Subject: [PATCH] [BYOC][JSON] Support input nodes with multiple entries (#6368) * Support input nodes with multiple data entries * Rename input_var_idx_ to input_var_eid_ --- src/runtime/contrib/json/json_runtime.h | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 92830e6..9eb7fcd 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -146,12 +146,12 @@ class JSONRuntimeBase : public ModuleNode { * \param args The packed args. */ void SetInputOutputBuffers(const TVMArgs& args) { - CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size()) + CHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) << "Found mismatch in the number of provided data entryies and required."; for (size_t i = 0; i < static_cast(args.size()); i++) { - auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0) - : EntryID(outputs_[i - input_var_idx_.size()]); + auto eid = i < input_var_eid_.size() ? input_var_eid_[i] + : EntryID(outputs_[i - input_var_eid_.size()]); CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) << "Expect NDArray or DLTensor as inputs"; @@ -183,7 +183,10 @@ class JSONRuntimeBase : public ModuleNode { uint32_t nid = input_nodes_[i]; std::string name = nodes_[nid].name_; if (nodes_[nid].op_type_ == "input") { - input_var_idx_.push_back(nid); + CHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size()); + for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { + input_var_eid_.push_back(EntryID(nid, j)); + } } else { CHECK_EQ(nodes_[nid].op_type_, "const"); auto pos = std::find(std::begin(const_names_), std::end(const_names_), name); @@ -261,8 +264,8 @@ class JSONRuntimeBase : public ModuleNode { std::vector outputs_; /*! \brief Data of that entry. */ std::vector data_entry_; - /*! \brief Map the input name to node index. */ - std::vector input_var_idx_; + /*! \brief Map the input name to entry id. */ + std::vector input_var_eid_; /*! \brief input const node index. */ std::vector const_idx_; /*! \brief Indicate if the engine has been initialized. */ -- 2.7.4