NNInterpreter interpreter;
// Check ops
- const auto& inputs = g->getInputs();
- assert(inputs.size() == 1 && "Interpreter doesn't support networks with multiple input nodes");
+ auto inputs = g->getInputs();
- auto input_node = inputs[0];
- auto input_data = loadInput(input_node->getOutputShape(0));
- interpreter.setInput(input_node->getName(), input_data);
+ auto input_data = loadInput(inputs);
+ for (auto inp: input_data) {
+ interpreter.setInput(inp.first, inp.second);
+ }
g->accept(&interpreter);
for (auto out_node : g->getOutputs()) {
return nullptr;
}
-TensorVariant InterpreterPass::loadInput(const Shape& shape) {
- auto f = fopen(cli::interInputData.c_str(), "rb");
- assert(f && "Cannot open file");
-
- int is_error = fseek(f, 0L, SEEK_END);
- assert(!is_error);
-
- auto len = ftell(f);
- assert(len != -1);
-
- auto data_size = static_cast<size_t>(shape.numElements() * sizeof(float));
-
- // Check size
- if (static_cast<size_t>(len) != data_size) {
- std::stringstream info;
- info << "Wrong input file size <" << cli::interInputData << "> = "
- << len << ". Should be :" << data_size;
+// Return pure file name w/o path and extension
+static std::string getFileName(std::string path) {
+ size_t sep = path.find_last_of("/");
+ if (sep != std::string::npos)
+ path = path.substr(sep + 1, path.size() - sep - 1);
+ size_t dot = path.find_last_of(".");
+ // TODO: There could be node names with symbols invalid in filename: if yes we should fix it.
+ if (dot != std::string::npos)
+ return path.substr(0, dot);
+ return path;
+}
- throw PassException(info.str());
+// Return input operation index with the given name
+static int getInputOp(std::string name, const std::vector<ops::InputOp*>& ops) {
+ for (unsigned ndx = 0; ndx < ops.size(); ndx++) {
+ if (ops[ndx]->getName() == name)
+ return ndx;
}
+ throw PassException("Input file name (without extension)"
+ " should be equal to model input node name");
+}
- rewind(f);
-
- std::unique_ptr<char[]> data(new char[data_size]);
- auto rlen = fread(data.get(), data_size, 1, f);
- assert(rlen == 1);
- (void) rlen;
-
- is_error = fclose(f);
- assert(is_error != EOF && "Can not close file!");
- (void) is_error;
-
- return TensorVariant(DTYPE::FLOAT32, shape, data.get());
+std::unordered_map<std::string, TensorVariant>
+InterpreterPass::loadInput(const std::vector<ops::InputOp*>& ops) {
+ std::unordered_map<std::string, TensorVariant> result;
+ for (unsigned i = 0; i < ops.size(); i++) {
+ // We assume that file name is equal to input node name
+ auto fname = getFileName(cli::interInputData[i]);
+ auto op_ndx = getInputOp(fname, ops);
+ auto f = fopen(cli::interInputData[i].c_str(), "rb");
+ assert (f);
+ int is_error = fseek(f, 0L, SEEK_END);
+ assert(!is_error);
+
+ auto len = ftell(f);
+ assert(len != -1);
+
+ auto shape = ops[op_ndx]->getOutputShape(0);
+ auto data_size = static_cast<size_t>(shape.numElements() * sizeof(float));
+
+ // Check size
+ if (static_cast<size_t>(len) != data_size) {
+ std::stringstream info;
+ info << "Wrong input file size <" << cli::interInputData[i] << "> = "
+ << len << ". Should be :" << data_size;
+
+ throw PassException(info.str());
+ }
+
+ rewind(f);
+
+ std::unique_ptr<char[]> data(new char[data_size]);
+ auto rlen = fread(data.get(), data_size, 1, f);
+ assert(rlen == 1);
+ (void)rlen;
+
+ is_error = fclose(f);
+ assert(is_error != EOF && "Can not close file!");
+ (void)is_error;
+
+ result.emplace(fname, TensorVariant(DTYPE::FLOAT32, shape, data.get()));
+ }
+ return result;
}
InterpreterPass::~InterpreterPass() {