2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include <luci_interpreter/Interpreter.h>

#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
// Owning byte buffer; used to hold the raw contents of the model file loaded from disk.
using DataBuffer = std::vector<char>;
/**
 * @brief Read exactly @p data_size bytes from the binary file @p filename into @p data.
 *
 * @param filename  path of the file to read
 * @param data      destination buffer; the caller must provide room for @p data_size bytes
 * @param data_size number of bytes to read
 * @throws std::runtime_error if the file cannot be opened or holds fewer than
 *         @p data_size bytes
 */
void readDataFromFile(const std::string &filename, char *data, size_t data_size)
{
  std::ifstream fs(filename, std::ifstream::binary);
  if (!fs.is_open())
    throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
  // istream::read() sets failbit when fewer bytes than requested are extracted,
  // so a truncated input file is reported here as well.
  if (fs.read(data, static_cast<std::streamsize>(data_size)).fail())
    throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n");
}
/**
 * @brief Write @p data_size bytes from @p data to the binary file @p filename,
 *        truncating any existing content.
 *
 * @param filename  path of the file to create/overwrite
 * @param data      source buffer holding at least @p data_size bytes
 * @param data_size number of bytes to write
 * @throws std::runtime_error if the file cannot be opened or the data cannot be written
 */
void writeDataToFile(const std::string &filename, const char *data, size_t data_size)
{
  std::ofstream fs(filename, std::ofstream::binary);
  if (!fs.is_open())
    throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
  fs.write(data, static_cast<std::streamsize>(data_size));
  // write() may only fill the stream buffer. Closing explicitly flushes it and sets failbit
  // on error, so late write failures are reported instead of being dropped by the destructor.
  fs.close();
  if (fs.fail())
  {
    throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n");
  }
}
54 * @brief EvalDriver main
56 * Driver for testing luci-inerpreter
59 int entry(int argc, char **argv)
64 << "Usage: " << argv[0]
65 << " <path/to/circle/model> <num_inputs> <path/to/input/prefix> <path/to/output/file>\n";
69 const char *filename = argv[1];
70 const int32_t num_inputs = atoi(argv[2]);
71 const char *input_prefix = argv[3];
72 const char *output_file = argv[4];
74 std::ifstream file(filename, std::ios::binary | std::ios::in);
77 std::string errmsg = "Failed to open file";
78 throw std::runtime_error(errmsg.c_str());
81 file.seekg(0, std::ios::end);
82 auto fileSize = file.tellg();
83 file.seekg(0, std::ios::beg);
86 DataBuffer model_data(fileSize);
89 file.read(model_data.data(), fileSize);
92 std::string errmsg = "Failed to read file";
93 throw std::runtime_error(errmsg.c_str());
96 // Create interpreter.
97 luci_interpreter::Interpreter interpreter(model_data.data(), true);
100 // Data for n'th input is read from ${input_prefix}n
101 // (ex: Add.circle.input0, Add.circle.input1 ..)
102 int num_inference = 1;
103 for (int j = 0; j < num_inference; ++j)
105 for (int32_t i = 0; i < num_inputs; i++)
107 auto input_data = reinterpret_cast<char *>(interpreter.allocateInputTensor(i));
108 readDataFromFile(std::string(input_prefix) + std::to_string(i), input_data,
109 interpreter.getInputDataSizeByIndex(i));
113 interpreter.interpret();
118 for (int i = 0; i < num_outputs; i++)
120 auto data = interpreter.readOutputTensor(i);
122 // Output data is written in ${output_file}
123 // (ex: Add.circle.output0)
124 writeDataToFile(std::string(output_file) + std::to_string(i), reinterpret_cast<char *>(data),
125 interpreter.getOutputDataSizeByIndex(i));
int entry(int argc, char **argv);

#ifdef NDEBUG
// Release build: turn any escaping exception into an error message and a failure exit code.
int main(int argc, char **argv)
{
  try
  {
    return entry(argc, argv);
  }
  catch (const std::exception &e)
  {
    std::cerr << "ERROR: " << e.what() << std::endl;
  }
  catch (...)
  {
    std::cerr << "ERROR: unknown exception" << std::endl;
  }
  // Without an explicit return, control would fall off the end of main and report success (0).
  return EXIT_FAILURE;
}
#else  // NDEBUG
int main(int argc, char **argv)
{
  // NOTE main does not catch internal exceptions for debug build to make it easy to
  // check the stacktrace with a debugger
  return entry(argc, argv);
}
#endif // NDEBUG