* \param args The packed args.
*/
void SetInputOutputBuffers(const TVMArgs& args) {
- CHECK_EQ(args.size(), input_var_idx_.size() + outputs_.size())
+ CHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size())
<< "Found mismatch in the number of provided data entryies and required.";
for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
- auto eid = i < input_var_idx_.size() ? EntryID(input_var_idx_[i], 0)
- : EntryID(outputs_[i - input_var_idx_.size()]);
+ auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
+ : EntryID(outputs_[i - input_var_eid_.size()]);
CHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
<< "Expect NDArray or DLTensor as inputs";
uint32_t nid = input_nodes_[i];
std::string name = nodes_[nid].name_;
if (nodes_[nid].op_type_ == "input") {
- input_var_idx_.push_back(nid);
+ CHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size());
+ for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) {
+ input_var_eid_.push_back(EntryID(nid, j));
+ }
} else {
CHECK_EQ(nodes_[nid].op_type_, "const");
auto pos = std::find(std::begin(const_names_), std::end(const_names_), name);
std::vector<JSONGraphNodeEntry> outputs_;
/*! \brief Data of that entry. */
std::vector<const DLTensor*> data_entry_;
- /*! \brief Map the input name to node index. */
- std::vector<uint32_t> input_var_idx_;
+ /*! \brief Map the input name to entry id. */
+ std::vector<uint32_t> input_var_eid_;
/*! \brief input const node index. */
std::vector<uint32_t> const_idx_;
/*! \brief Indicate if the engine has been initialized. */