1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
13 #include <ie_cnn_net_reader_impl.h>
14 #include "ie_format_parser.h"
15 #include <file_utils.h>
16 #include <ie_plugin.hpp>
17 #include "xml_parse_utils.h"
20 using namespace InferenceEngine;
21 using namespace InferenceEngine::details;
// Derives a model name from `filepath`: drops everything up to and including
// the last '/' and, when the path contains a '.', the last '.' and everything
// after it (so "dir/model.xml" yields "model").
// NOTE(review): only '/' is recognised as a directory separator, so
// backslash-separated Windows paths keep their directory prefix.
// NOTE(review): `filepath` must not be null; constructing std::string from a
// null pointer is undefined behaviour.
// NOTE(review): if the last '.' precedes the last '/' (e.g. "v1.0/model"),
// `dotPos - slashPos` wraps around as size_t; substr() clamps the count, so
// the basename is still produced, but only because of that clamping.
23 std::string CNNNetReaderImpl::NameFromFilePath(const char* filepath) {
24 string modelName = filepath;
25 auto slashPos = modelName.rfind('/');
// Start of the basename: one past the last '/', or 0 when there is none.
26 slashPos = slashPos == std::string::npos ? 0 : slashPos + 1;
27 auto dotPos = modelName.rfind('.');
28 if (dotPos != std::string::npos) {
29 modelName = modelName.substr(slashPos, dotPos - slashPos);
31 modelName = modelName.substr(slashPos);
36 CNNNetReaderImpl::CNNNetReaderImpl(const FormatParserCreator::Ptr& _creator)
37 : parseSuccess(false), _version(0), parserCreator(_creator) {}
// Passes `weights` to the format parser created by ReadNetwork().
// Returns an error status with "network must be read first" when no network
// has been read yet (presumably the guard checks `_parser`; confirm). An
// InferenceEngineException thrown by the parser becomes an error status
// carrying the exception message.
// NOTE(review): the function is noexcept; any exception type that escapes
// without being caught here would call std::terminate. Confirm that only
// InferenceEngineException can be thrown by _parser->SetWeights().
39 StatusCode CNNNetReaderImpl::SetWeights(const TBlob<uint8_t>::Ptr& weights, ResponseDesc* desc) noexcept {
41 return DescriptionBuffer(desc) << "network must be read first";
44 _parser->SetWeights(weights);
46 catch (const InferenceEngineException& iee) {
47 return DescriptionBuffer(desc) << iee.what();
// Returns the integer "version" attribute of the IR root element.
// 0 is passed to GetIntAttr as the fallback value, presumably returned when
// the attribute is absent. ReadNetwork(pugi::xml_document&) rejects any
// version outside 1..5.
53 int CNNNetReaderImpl::GetFileVersion(pugi::xml_node& root) {
54 return XMLParseUtils::GetIntAttr(root, "version", 0);
// Reads a network from an in-memory IR XML buffer of `size` bytes.
// - If a network has already been read (per the message below), the call is
//   rejected with NETWORK_NOT_READ; a new reader instance is required.
// - XML syntax errors are reported with pugixml's description and the byte
//   offset at which parsing failed.
// - Otherwise the document is handed to ReadNetwork(pugi::xml_document&).
//   When that does not succeed (presumably ret != OK), its `description` is
//   reported as "Error reading network: ...".
57 StatusCode CNNNetReaderImpl::ReadNetwork(const void* model, size_t size, ResponseDesc* resp) noexcept {
59 return DescriptionBuffer(NETWORK_NOT_READ, resp) << "Network has been read already, use new reader instance to read new network.";
62 pugi::xml_document xmlDoc;
63 pugi::xml_parse_result res = xmlDoc.load_buffer(model, size);
64 if (res.status != pugi::status_ok) {
// NOTE(review): there is no separator between the description and
// "at offset", so the message reads e.g. "No document element foundat offset 0".
// Consider giving the literal a leading space.
65 return DescriptionBuffer(resp) << res.description() << "at offset " << res.offset;
67 StatusCode ret = ReadNetwork(xmlDoc);
69 return DescriptionBuffer(resp) << "Error reading network: " << description;
// Loads the weights file at `filepath` in one piece and forwards it to
// SetWeights().
// - A negative size from FileUtils::fileSize() is reported as an error that
//   asks the user to check that the weights file exists.
// - Fails with "network is empty" if no network has been read yet.
// - Otherwise allocates a U8 blob of exactly the file size, fills it with
//   FileUtils::readAllFile(), and turns an InferenceEngineException raised
//   by the read into an error status carrying its message.
// NOTE(review): the whole weights file is held in memory at once.
74 StatusCode CNNNetReaderImpl::ReadWeights(const char* filepath, ResponseDesc* resp) noexcept {
75 int64_t fileSize = FileUtils::fileSize(filepath);
78 return DescriptionBuffer(resp) << "filesize for: " << filepath << " - " << fileSize
79 << "<0. Please, check weights file existence.";
81 if (network.get() == nullptr) {
82 return DescriptionBuffer(resp) << "network is empty";
// Safe to convert: the negative case has already been reported above.
85 size_t ulFileSize = static_cast<size_t>(fileSize);
87 TBlob<uint8_t>::Ptr weightsPtr(new TBlob<uint8_t>(Precision::U8, C, {ulFileSize}));
88 weightsPtr->allocate();
90 FileUtils::readAllFile(filepath, weightsPtr->buffer(), ulFileSize);
92 catch (const InferenceEngineException& iee) {
93 return DescriptionBuffer(resp) << iee.what();
96 return SetWeights(weightsPtr, resp);
// Reads a network from an IR XML file on disk.
// Mirrors the in-memory overload: a second read is rejected with
// NETWORK_NOT_READ, and a successfully loaded document is handed to
// ReadNetwork(pugi::xml_document&). Its `description` is reported on failure.
// When pugixml cannot load the file, the file is read again into a string so
// that the failing offset (res.offset) can be reported as a line/position pair.
// NOTE(review): if the file could not be opened at all, `str` is empty and
// `line`/`pos` keep their initial values in the message.
99 StatusCode CNNNetReaderImpl::ReadNetwork(const char* filepath, ResponseDesc* resp) noexcept {
101 return DescriptionBuffer(NETWORK_NOT_READ, resp) << "Network has been read already, use new reader instance to read new network.";
104 pugi::xml_document xmlDoc;
105 pugi::xml_parse_result res = xmlDoc.load_file(filepath);
106 if (res.status != pugi::status_ok) {
107 std::ifstream t(filepath);
108 std::string str((std::istreambuf_iterator<char>(t)),
109 std::istreambuf_iterator<char>());
// Scan the file characters to translate res.offset into (line, pos).
113 for (auto token : str) {
120 if (pos >= res.offset) {
125 return DescriptionBuffer(resp) << "Error loading xmlfile: " << filepath << ", " << res.description()
126 << " at line: " << line << " pos: " << pos;
128 StatusCode ret = ReadNetwork(xmlDoc);
130 return DescriptionBuffer(resp) << "Error reading network: " << description;
// Shared parsing step behind both public ReadNetwork() overloads.
// Steps:
// 1. Read the IR version from the root element and accept only 1..5.
// 2. Create a version-specific parser through `parserCreator`.
// 3. Parse the root into `network` and record the network name.
// 4. Validate the network against that version.
// Failures are reported through the return value:
// - The std::string, InferenceEngineException and std::exception handlers
//   clear parseSuccess and return GENERAL_ERROR.
// - The last two of those handlers copy e.what() into `description`.
// - A final handler sets `description` to "Unknown exception thrown" and
//   clears parseSuccess.
// The version checks raise THROW_IE_EXCEPTION, presumably caught by the
// InferenceEngineException handler below.
135 StatusCode CNNNetReaderImpl::ReadNetwork(pugi::xml_document& xmlDoc) {
139 // Determine the IR version from the root element before choosing a parser.
140 pugi::xml_node root = xmlDoc.document_element();
142 _version = GetFileVersion(root);
// Only IR versions 1 through 5 are supported.
143 if (_version < 1) THROW_IE_EXCEPTION << "deprecated IR version: " << _version;
144 if (_version > 5) THROW_IE_EXCEPTION << "cannot parse future versions: " << _version;
145 _parser = parserCreator->create(_version);
146 network = _parser->Parse(root);
147 name = network->getName();
148 network->validate(_version);
150 } catch (const std::string& err) {
152 parseSuccess = false;
153 return GENERAL_ERROR;
154 } catch (const InferenceEngineException& e) {
155 description = e.what();
156 parseSuccess = false;
157 return GENERAL_ERROR;
158 } catch (const std::exception& e) {
159 description = e.what();
160 parseSuccess = false;
161 return GENERAL_ERROR;
163 description = "Unknown exception thrown";
164 parseSuccess = false;
// Creates the format parser for IR `version`. The same FormatParser class
// handles every version; the version number is passed to its constructor.
171 std::shared_ptr<IFormatParser> V2FormatParserCreator::create(int version) {
172 return std::make_shared<FormatParser>(version);
// Exported entry point that creates a CNN network reader backed by the
// V2FormatParserCreator factory.
// NOTE(review): returns a raw pointer obtained from `new`. The caller
// presumably owns it and releases it through the ICNNNetReader interface;
// confirm.
175 INFERENCE_ENGINE_API(InferenceEngine::ICNNNetReader*) InferenceEngine::CreateCNNNetReader() noexcept {
176 return new CNNNetReaderImpl(std::make_shared<V2FormatParserCreator>());