7ec4219f4a49751a80df0375a2d4d20d50327014
[platform/core/ml/nnfw.git] / onert-micro / eval-driver / Driver.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <luci_interpreter/Interpreter.h>
18
19 #include <stdexcept>
20 #include <cstdlib>
21 #include <fstream>
22 #include <vector>
23 #include <string>
24 #include <iostream>
25
26 namespace
27 {
28
29 using DataBuffer = std::vector<char>;
30
31 void readDataFromFile(const std::string &filename, char *data, size_t data_size)
32 {
33   std::ifstream fs(filename, std::ifstream::binary);
34   if (fs.fail())
35     throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
36   if (fs.read(data, data_size).fail())
37     throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n");
38 }
39
40 void writeDataToFile(const std::string &filename, const char *data, size_t data_size)
41 {
42   std::ofstream fs(filename, std::ofstream::binary);
43   if (fs.fail())
44     throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
45   if (fs.write(data, data_size).fail())
46   {
47     throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n");
48   }
49 }
50
51 } // namespace
52
53 /*
54  * @brief EvalDriver main
55  *
56  *        Driver for testing luci-inerpreter
57  *
58  */
59 int entry(int argc, char **argv)
60 {
61   if (argc != 5)
62   {
63     std::cerr
64       << "Usage: " << argv[0]
65       << " <path/to/circle/model> <num_inputs> <path/to/input/prefix> <path/to/output/file>\n";
66     return EXIT_FAILURE;
67   }
68
69   const char *filename = argv[1];
70   const int32_t num_inputs = atoi(argv[2]);
71   const char *input_prefix = argv[3];
72   const char *output_file = argv[4];
73
74   std::ifstream file(filename, std::ios::binary | std::ios::in);
75   if (!file.good())
76   {
77     std::string errmsg = "Failed to open file";
78     throw std::runtime_error(errmsg.c_str());
79   }
80
81   file.seekg(0, std::ios::end);
82   auto fileSize = file.tellg();
83   file.seekg(0, std::ios::beg);
84
85   // reserve capacity
86   DataBuffer model_data(fileSize);
87
88   // read the data
89   file.read(model_data.data(), fileSize);
90   if (file.fail())
91   {
92     std::string errmsg = "Failed to read file";
93     throw std::runtime_error(errmsg.c_str());
94   }
95
96   // Create interpreter.
97   luci_interpreter::Interpreter interpreter(model_data.data());
98
99   // Set input.
100   // Data for n'th input is read from ${input_prefix}n
101   // (ex: Add.circle.input0, Add.circle.input1 ..)
102   int num_inference = 1;
103   for (int j = 0; j < num_inference; ++j)
104   {
105     for (int32_t i = 0; i < num_inputs; i++)
106     {
107       auto input_data = reinterpret_cast<char *>(interpreter.allocateInputTensor(i));
108       readDataFromFile(std::string(input_prefix) + std::to_string(i), input_data,
109                        interpreter.getInputDataSizeByIndex(i));
110     }
111
112     // Do inference.
113     interpreter.interpret();
114   }
115
116   // Get output.
117   int num_outputs = 1;
118   for (int i = 0; i < num_outputs; i++)
119   {
120     auto data = interpreter.readOutputTensor(i);
121
122     // Output data is written in ${output_file}
123     // (ex: Add.circle.output0)
124     writeDataToFile(std::string(output_file) + std::to_string(i), reinterpret_cast<char *>(data),
125                     interpreter.getOutputDataSizeByIndex(i));
126   }
127   return EXIT_SUCCESS;
128 }
129
130 int entry(int argc, char **argv);
131
132 #ifdef NDEBUG
133 int main(int argc, char **argv)
134 {
135   try
136   {
137     return entry(argc, argv);
138   }
139   catch (const std::exception &e)
140   {
141     std::cerr << "ERROR: " << e.what() << std::endl;
142   }
143
144   return 255;
145 }
146 #else  // NDEBUG
147 int main(int argc, char **argv)
148 {
149   // NOTE main does not catch internal exceptions for debug build to make it easy to
150   //      check the stacktrace with a debugger
151   return entry(argc, argv);
152 }
153 #endif // !NDEBUG