// Copyright (c) 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <functional>
#include <future>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <gflags/gflags.h>

#include "inference_engine.hpp"
#include "precision_utils.h"

#include "vpu_tools_common.hpp"
#include "vpu/vpu_plugin_config.hpp"
#include "vpu/private_plugin_config.hpp"
#include "samples/common.hpp"

static constexpr char help_message[] = "Print a help(this) message.";
static constexpr char model_message[] = "Path to xml model.";
static constexpr char inputs_dir_message[] = "Path to folder with images, only bitmap(.bmp) supported. Default: \".\".";
static constexpr char config_message[] = "Path to the configuration file. Default value: \"config\".";
static constexpr char iterations_message[] = "Specifies number of iterations. Default value: 16.";
static constexpr char plugin_message[] = "Specifies plugin. Supported values: myriad, hddl.\n"
                                         "\t \t \tDefault value: \"myriad\".";
static constexpr char report_message[] = "Specifies report type. Supported values: per_layer, per_stage.\n"
                                         "\t \t \tOverrides value in configuration file if provided. Default value: \"per_stage\".";

DEFINE_bool(h, false, help_message);
DEFINE_string(model, "", model_message);
DEFINE_string(inputs_dir, ".", inputs_dir_message);
DEFINE_string(config, "", config_message);
DEFINE_int32(iterations, 16, iterations_message);
DEFINE_string(plugin, "myriad", plugin_message);
DEFINE_string(report, "", report_message);
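
// Example invocation (hypothetical paths shown):
//   ./vpu_profile -model model.xml -inputs_dir ./inputs -plugin myriad -iterations 16 -report per_stage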

static void showUsage() {
    std::cout << std::endl;
    std::cout << "vpu_profile [OPTIONS]" << std::endl;
    std::cout << "[OPTIONS]:" << std::endl;
    std::cout << "\t-h          \t         \t" << help_message << std::endl;
    std::cout << "\t-model      \t <value> \t" << model_message << std::endl;
    std::cout << "\t-inputs_dir \t <value> \t" << inputs_dir_message << std::endl;
    std::cout << "\t-config     \t <value> \t" << config_message << std::endl;
    std::cout << "\t-iterations \t <value> \t" << iterations_message << std::endl;
    std::cout << "\t-plugin     \t <value> \t" << plugin_message << std::endl;
    std::cout << "\t-report     \t <value> \t" << report_message << std::endl;
    std::cout << std::endl;
}
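
// Parses command-line flags; returns false when only help was requested,
// throws std::invalid_argument on a missing model or unknown arguments.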
static bool parseCommandLine(int *argc, char ***argv) {
    gflags::ParseCommandLineNonHelpFlags(argc, argv, true);

    if (FLAGS_h) {
        showUsage();
        return false;
    }

    if (FLAGS_model.empty()) {
        throw std::invalid_argument("Path to model xml file is required");
    }

    if (1 < *argc) {
        std::stringstream message;
        message << "Unknown arguments: ";
        for (auto arg = 1; arg < *argc; arg++) {
            message << (*argv)[arg] << " ";
        }
        throw std::invalid_argument(message.str());
    }

    return true;
}
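
// Builds the plugin configuration: values parsed from the optional config
// file, plus profiling defaults and the -report override.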
static std::map<std::string, std::string> configure(const std::string& confFileName, const std::string& report) {
    auto config = parseConfig(confFileName);

    /* Since the user can specify a config file, these defaults could probably be avoided */
    config[VPU_CONFIG_KEY(LOG_LEVEL)] = CONFIG_VALUE(LOG_WARNING);
    config[CONFIG_KEY(LOG_LEVEL)] = CONFIG_VALUE(LOG_WARNING);
    config[VPU_CONFIG_KEY(PRINT_RECEIVE_TENSOR_TIME)] = CONFIG_VALUE(YES);

    /* An explicit -report value wins; otherwise keep the config file value, if any */
    if (report == "per_layer") {
        config[VPU_CONFIG_KEY(PERF_REPORT_MODE)] = VPU_CONFIG_VALUE(PER_LAYER);
    } else if (report == "per_stage") {
        config[VPU_CONFIG_KEY(PERF_REPORT_MODE)] = VPU_CONFIG_VALUE(PER_STAGE);
    } else if (config.find(VPU_CONFIG_KEY(PERF_REPORT_MODE)) == config.end()) {
        config[VPU_CONFIG_KEY(PERF_REPORT_MODE)] = VPU_CONFIG_VALUE(PER_LAYER);
    }

    return config;
}
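
// Treats a blob as an image if its layout is NCHW with exactly 3 channels.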
template<typename T>
static bool isImage(const T& blob) {
    auto descriptor = blob->getTensorDesc();
    if (descriptor.getLayout() != InferenceEngine::NCHW) {
        return false;
    }

    auto channels = descriptor.getDims()[1];
    return channels == 3;
}
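
// Fills every input blob of a request: image-like inputs from .bmp files,
// all other inputs from raw .bin tensors, cycling through the file lists.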
static void loadInputs(std::size_t requestIdx, const std::vector<std::string>& images,
                       const std::vector<std::string>& binaries, InferenceEngine::InferRequest& request,
                       InferenceEngine::CNNNetwork& network) {
    for (auto &&input : network.getInputsInfo()) {
        auto blob = request.GetBlob(input.first);

        if (isImage(blob)) {
            loadImage(images[requestIdx % images.size()], blob);
        } else {
            loadBinaryTensor(binaries[requestIdx % binaries.size()], blob);
        }
    }
}
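
// Normalizes the -plugin value: uppercase with all whitespace removed.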
static std::string process_user_input(const std::string &src) {
    std::string name = src;
    std::transform(name.begin(), name.end(), name.begin(), ::toupper);
    name.erase(std::remove_if(name.begin(), name.end(), ::isspace), name.end());

    return name;
}
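
// Maps a plugin name to the number of infer requests kept in flight.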
static std::size_t getNumberRequests(const std::string &plugin) {
    /* Per-plugin request-pool sizes; the exact values were elided upstream,
       so 4 and 8 here are assumptions, not the original constants */
    static const std::unordered_map<std::string, std::size_t> supported_plugins = {
        { "MYRIAD", 4 },
        { "HDDL",   8 },
    };

    auto num_requests = supported_plugins.find(plugin);
    if (num_requests == supported_plugins.end()) {
        throw std::invalid_argument("Unknown plugin " + plugin);
    }

    return num_requests->second;
}

int main(int argc, char* argv[]) {
    try {
        std::cout << "Inference Engine: " << InferenceEngine::GetInferenceEngineVersion() << std::endl;

        if (!parseCommandLine(&argc, &argv)) {
            return EXIT_SUCCESS;
        }

        auto network = readNetwork(FLAGS_model);
        setPrecisions(network);

        auto user_plugin = process_user_input(FLAGS_plugin);

        InferenceEngine::Core ie;
        auto executableNetwork = ie.LoadNetwork(network, user_plugin, configure(FLAGS_config, FLAGS_report));

        auto num_requests = getNumberRequests(user_plugin);

        auto images = extractFilesByExtension(FLAGS_inputs_dir, "bmp", 1);
        auto hasImageInput = [](const InferenceEngine::CNNNetwork &network) {
            auto inputs = network.getInputsInfo();
            auto isImageInput = [](const InferenceEngine::InputsDataMap::value_type &input) {
                return isImage(input.second);
            };
            return std::any_of(inputs.begin(), inputs.end(), isImageInput);
        };

        if (hasImageInput(network) && images.empty()) {
            throw std::invalid_argument(FLAGS_inputs_dir + " does not contain images for network");
        }

        auto binaries = extractFilesByExtension(FLAGS_inputs_dir, "bin", 1);
        auto hasBinaryInput = [](const InferenceEngine::CNNNetwork &network) {
            auto inputs = network.getInputsInfo();
            auto isBinaryInput = [](const InferenceEngine::InputsDataMap::value_type &input) {
                return !isImage(input.second);
            };
            return std::any_of(inputs.begin(), inputs.end(), isBinaryInput);
        };

        if (hasBinaryInput(network) && binaries.empty()) {
            throw std::invalid_argument(FLAGS_inputs_dir + " does not contain binaries for network");
        }
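
        // Run an asynchronous pipeline with num_requests requests in flight;
        // performance counters are taken from the request that executes
        // iteration 2 * num_requests + FLAGS_iterations.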
        std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> performance;

        std::atomic<std::size_t> iteration{0};
        std::promise<void> done;
        bool needStartAsync{true};
        std::size_t profiledIteration = 2 * num_requests + FLAGS_iterations;

        std::vector<InferenceEngine::InferRequest> requests(num_requests);
        std::vector<std::size_t> current_iterations(num_requests);

        using callback_t = std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>;
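
        // Each completion callback either records the counters and signals
        // completion (on the profiled iteration), or re-submits the request
        // with the next iteration index.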
        for (std::size_t request = 0; request < num_requests; ++request) {
            requests[request] = executableNetwork.CreateInferRequest();
            current_iterations[request] = 0;

            loadInputs(request, images, binaries, requests[request], network);

            callback_t callback =
                [request, profiledIteration, &done, &needStartAsync, &performance, &iteration, &current_iterations]
                (InferenceEngine::InferRequest inferRequest, InferenceEngine::StatusCode code) {
                if (code != InferenceEngine::StatusCode::OK) {
                    THROW_IE_EXCEPTION << "Infer request failed with code " << code;
                }

                auto current_iteration = current_iterations[request];
                if (current_iteration == profiledIteration) {
                    performance = inferRequest.GetPerformanceCounts();
                    needStartAsync = false;
                    done.set_value();
                }

                if (needStartAsync) {
                    current_iterations[request] = iteration++;
                    inferRequest.StartAsync();
                }
            };

            requests[request].SetCompletionCallback<callback_t>(callback);
        }
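
        // Kick off all requests, then block until the profiled iteration completes.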
        auto doneFuture = done.get_future();

        for (std::size_t request = 0; request < num_requests; ++request) {
            current_iterations[request] = iteration++;
            requests[request].StartAsync();
        }

        doneFuture.wait();
        printPerformanceCounts(performance, FLAGS_report);
    } catch (const std::exception &error) {
        std::cerr << error.what() << std::endl;
        return EXIT_FAILURE;
    } catch (...) {
        std::cerr << "Unknown/internal exception happened." << std::endl;
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}