Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_cnn_net_reader_impl.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <string>
6 #include <fstream>
7 #include <sstream>
8 #include <memory>
9 #include <map>
10
11 #include "debug.h"
12 #include "parsers.h"
13 #include <ie_cnn_net_reader_impl.h>
14 #include "ie_format_parser.h"
15 #include <file_utils.h>
16 #include <ie_plugin.hpp>
17 #include "xml_parse_utils.h"
18
19 using namespace std;
20 using namespace InferenceEngine;
21 using namespace InferenceEngine::details;
22
23 std::string CNNNetReaderImpl::NameFromFilePath(const char* filepath) {
24     string modelName = filepath;
25     auto slashPos = modelName.rfind('/');
26     slashPos = slashPos == std::string::npos ? 0 : slashPos + 1;
27     auto dotPos = modelName.rfind('.');
28     if (dotPos != std::string::npos) {
29         modelName = modelName.substr(slashPos, dotPos - slashPos);
30     } else {
31         modelName = modelName.substr(slashPos);
32     }
33     return modelName;
34 }
35
36 CNNNetReaderImpl::CNNNetReaderImpl(const FormatParserCreator::Ptr& _creator)
37         : parseSuccess(false), _version(0), parserCreator(_creator) {}
38
39 StatusCode CNNNetReaderImpl::SetWeights(const TBlob<uint8_t>::Ptr& weights, ResponseDesc* desc)  noexcept {
40     if (!_parser) {
41         return DescriptionBuffer(desc) << "network must be read first";
42     }
43     try {
44         _parser->SetWeights(weights);
45     }
46     catch (const InferenceEngineException& iee) {
47         return DescriptionBuffer(desc) << iee.what();
48     }
49
50     return OK;
51 }
52
53 int CNNNetReaderImpl::GetFileVersion(pugi::xml_node& root) {
54     return XMLParseUtils::GetIntAttr(root, "version", 0);
55 }
56
57 StatusCode CNNNetReaderImpl::ReadNetwork(const void* model, size_t size, ResponseDesc* resp) noexcept {
58     if (network) {
59         return DescriptionBuffer(NETWORK_NOT_READ, resp) << "Network has been read already, use new reader instance to read new network.";
60     }
61
62     pugi::xml_document xmlDoc;
63     pugi::xml_parse_result res = xmlDoc.load_buffer(model, size);
64     if (res.status != pugi::status_ok) {
65         return DescriptionBuffer(resp) << res.description() << "at offset " << res.offset;
66     }
67     StatusCode ret = ReadNetwork(xmlDoc);
68     if (ret != OK) {
69         return DescriptionBuffer(resp) << "Error reading network: " << description;
70     }
71     return OK;
72 }
73
74 StatusCode CNNNetReaderImpl::ReadWeights(const char* filepath, ResponseDesc* resp) noexcept {
75     int64_t fileSize = FileUtils::fileSize(filepath);
76
77     if (fileSize < 0)
78         return DescriptionBuffer(resp) << "filesize for: " << filepath << " - " << fileSize
79                                        << "<0. Please, check weights file existence.";
80
81     if (network.get() == nullptr) {
82         return DescriptionBuffer(resp) << "network is empty";
83     }
84
85     size_t ulFileSize = static_cast<size_t>(fileSize);
86
87     TBlob<uint8_t>::Ptr weightsPtr(new TBlob<uint8_t>(Precision::U8, C, {ulFileSize}));
88     weightsPtr->allocate();
89     try {
90         FileUtils::readAllFile(filepath, weightsPtr->buffer(), ulFileSize);
91     }
92     catch (const InferenceEngineException& iee) {
93         return DescriptionBuffer(resp) << iee.what();
94     }
95
96     return SetWeights(weightsPtr, resp);
97 }
98
99 StatusCode CNNNetReaderImpl::ReadNetwork(const char* filepath, ResponseDesc* resp) noexcept {
100     if (network) {
101         return DescriptionBuffer(NETWORK_NOT_READ, resp) << "Network has been read already, use new reader instance to read new network.";
102     }
103
104     pugi::xml_document xmlDoc;
105     pugi::xml_parse_result res = xmlDoc.load_file(filepath);
106     if (res.status != pugi::status_ok) {
107         std::ifstream t(filepath);
108         std::string str((std::istreambuf_iterator<char>(t)),
109                         std::istreambuf_iterator<char>());
110
111         int line = 1;
112         int pos = 0;
113         for (auto token : str) {
114             if (token == '\n') {
115                 line++;
116                 pos = 0;
117             } else {
118                 pos++;
119             }
120             if (pos >= res.offset) {
121                 break;
122             }
123         }
124
125         return DescriptionBuffer(resp) << "Error loading xmlfile: " << filepath << ", " << res.description()
126                                        << " at line: " << line << " pos: " << pos;
127     }
128     StatusCode ret = ReadNetwork(xmlDoc);
129     if (ret != OK) {
130         return DescriptionBuffer(resp) << "Error reading network: " << description;
131     }
132     return OK;
133 }
134
135 StatusCode CNNNetReaderImpl::ReadNetwork(pugi::xml_document& xmlDoc) {
136     description.clear();
137
138     try {
139         // check which version it is...
140         pugi::xml_node root = xmlDoc.document_element();
141
142         _version = GetFileVersion(root);
143         if (_version < 1) THROW_IE_EXCEPTION << "deprecated IR version: " << _version;
144         if (_version > 5) THROW_IE_EXCEPTION << "cannot parse future versions: " << _version;
145         _parser = parserCreator->create(_version);
146         network = _parser->Parse(root);
147         name = network->getName();
148         network->validate(_version);
149         parseSuccess = true;
150     } catch (const std::string& err) {
151         description = err;
152         parseSuccess = false;
153         return GENERAL_ERROR;
154     } catch (const InferenceEngineException& e) {
155         description = e.what();
156         parseSuccess = false;
157         return GENERAL_ERROR;
158     } catch (const std::exception& e) {
159         description = e.what();
160         parseSuccess = false;
161         return GENERAL_ERROR;
162     } catch (...) {
163         description = "Unknown exception thrown";
164         parseSuccess = false;
165         return UNEXPECTED;
166     }
167
168     return OK;
169 }
170
171 std::shared_ptr<IFormatParser> V2FormatParserCreator::create(int version) {
172     return std::make_shared<FormatParser>(version);
173 }
174
175 INFERENCE_ENGINE_API(InferenceEngine::ICNNNetReader*) InferenceEngine::CreateCNNNetReader() noexcept {
176     return new CNNNetReaderImpl(std::make_shared<V2FormatParserCreator>());
177 }