Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / helpers / test_models_path.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 //
6 // Created by user on 19.10.18.
7 //
8
9 #include "test_model_repo.hpp"
10 #include "test_model_path.hpp"
11
12 #ifndef _WIN32
13 # include <libgen.h>
14 # include <dirent.h>
15 #else
16 # include <os/windows/w_dirent.h>
17 #endif
18
19 #include <vector>
20 #include <iostream>
21 #include <gtest/gtest.h>
22 #include <fstream>
23
24 #ifndef _WIN32
25 static std::string getDirname (std::string filePath) {
26     std::vector<char> input(filePath.begin(), filePath.end());
27     input.push_back(0);
28     return dirname(&*input.begin());
29 }
30 #else
31 static std::string getDirname (std::string filePath) {
32         char dirname[_MAX_DIR];
33         _splitpath(filePath.c_str(), nullptr, dirname, nullptr, nullptr);
34         return dirname;
35     }
36 #endif
37
38 const char* getModelPathNonFatal() noexcept {
39 #ifdef MODELS_PATH
40     const char* models_path = std::getenv("MODELS_PATH");
41
42     if (models_path == nullptr && MODELS_PATH == nullptr) {
43         return nullptr;
44     }
45
46     if (models_path == nullptr) {
47         return MODELS_PATH;
48     }
49
50     return models_path;
51 #else
52     return nullptr;
53 #endif
54 }
55
56
57 static std::string get_models_path() {
58     const char* models_path = getModelPathNonFatal();
59
60     if (nullptr == models_path) {
61         ::testing::AssertionFailure() << "MODELS_PATH not defined";
62     }
63
64     return std::string(models_path);
65 }
66
67 static bool exist(const std::string& name) {
68     std::ifstream file(name);
69     if(!file)            // If the file was not found, then file is 0, i.e. !file=1 or true.
70         return false;    // The file was not found.
71     else                 // If the file was found, then file is non-0.
72         return true;     // The file was found.
73 }
74
75 static std::vector<std::string> getModelsDirs() {
76     auto repo_list = get_model_repo();
77     int last_delimiter = 0;
78     std::vector<std::string> folders;
79     for(;;) {
80         auto folderDelimiter = repo_list.find(':', last_delimiter);
81         if (folderDelimiter == std::string::npos) {
82             break;
83         }
84         auto nextDelimiter = repo_list.find(';', last_delimiter);
85         folders.push_back(repo_list.substr(last_delimiter, folderDelimiter - last_delimiter));
86
87         if (nextDelimiter == std::string::npos) {
88             break;
89         }
90
91         last_delimiter = nextDelimiter + 1;
92     }
93     return folders;
94 }
95
96 ModelsPath::operator std::string() const {
97
98     std::vector<std::string> absModelsPath;
99     for (auto & path  : getModelsDirs()) {
100         absModelsPath.push_back(get_models_path() + kPathSeparator + "src" + kPathSeparator + path + _rel_path.str());
101         if (exist(absModelsPath.back())) {
102             return absModelsPath.back();
103         }
104         //checking models for precision encoded in folder name
105         auto dirname = getDirname(absModelsPath.back());
106         std::vector<std::pair<std::string, std::string>> stdprecisions = {
107             {"_fp32", "FP32"},
108             {"_q78", "_Q78"},
109             {"_fp16", "FP16"},
110             {"_i16", "I16"}
111         };
112
113         auto filename = absModelsPath.back().substr(dirname.size() + 1);
114
115         for (auto &precision : stdprecisions) {
116             auto havePrecision = filename.find(precision.first);
117             if (havePrecision == std::string::npos) continue;
118
119             auto newName = filename.replace(havePrecision, precision.first.size(), "");
120             newName = dirname + kPathSeparator + precision.second + kPathSeparator + newName;
121
122             if (exist(newName)) {
123                 return newName;
124             }
125         }
126     }
127
128     // checking dirname
129     auto getModelsDirname = [](std::string path) -> std::string {
130         std::string dir = getDirname(path);
131
132         struct stat sb;
133         if (stat(dir.c_str(), &sb) != 0 || !S_ISDIR(sb.st_mode)) {
134             return "";
135         }
136         return dir;
137     };
138
139     for (auto & path : absModelsPath) {
140         std::string publicDir = getModelsDirname(path);
141
142         if (!publicDir.empty()) {
143             return path;
144         }
145     }
146     std::stringstream errorMsg;
147     errorMsg<< "path to model invalid, models found at: \n";
148
149     for (auto & path : absModelsPath) {
150         errorMsg << path <<"\n";
151     }
152     errorMsg << "also searched by parent directory names: \n";
153     for (auto & path : absModelsPath) {
154         errorMsg << getDirname(path) << "\n";
155     }
156
157     std::cout << errorMsg.str();
158     ::testing::AssertionFailure() << errorMsg.str();
159
160     // doesn't matter what to return here
161     return "";
162 }