Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / helpers / xml_helper.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6 #include <string>
7 #include <list>
8 #include <sstream>
9 #include <memory>
10 #include <map>
11 #include <vector>
12 #include "inference_engine/parsers.h"
13 #include "pugixml.hpp"
14 #include <fstream>
15 #include <stdio.h>
16 #include "cpp/ie_cnn_network.h"
17 #include <gtest/gtest.h>
18 #include "ie_icnn_network_stats.hpp"
19
20 namespace testing {
21         class XMLHelper {
22     public:
23         XMLHelper(InferenceEngine::details::IFormatParser* p) {
24             parser.reset(p);
25             _doc.reset(new pugi::xml_document());
26             _root.reset(new pugi::xml_node());
27         }
28         void loadContent(const std::string &fileContent) {
29             auto res = _doc->load_string(fileContent.c_str());
30             EXPECT_EQ(pugi::status_ok, res.status) << res.description() << " at offset " << res.offset;
31             *_root = _doc->document_element();
32         }
33
34         void loadFile(const std::string &filename) {
35             auto res = _doc->load_file(filename.c_str());
36             EXPECT_EQ(pugi::status_ok, res.status) << res.description() << " at offset " << res.offset;
37             *_root = _doc->document_element();
38         }
39
40         void parse() {
41             parser->Parse(*_root);
42         }
43
44         InferenceEngine::details::CNNNetworkImplPtr parseWithReturningNetwork() {
45             return parser->Parse(*_root);
46         }
47
48         void setWeights(const InferenceEngine::TBlob<uint8_t>::Ptr &weights) {
49             parser->SetWeights(weights);
50         }
51
52         std::string readFileContent(const std::string & filePath) {
53             const auto openFlags = std::ios_base::ate | std::ios_base::binary;
54             std::ifstream fp (getXmlPath(filePath), openFlags);
55             EXPECT_TRUE(fp.is_open());
56
57             std::streamsize size = fp.tellg();
58             EXPECT_GE( size , 1) << "file is empty: " << filePath;
59
60             std::string str;
61
62             str.reserve((size_t)size);
63             fp.seekg(0, std::ios::beg);
64
65             str.assign((std::istreambuf_iterator<char>(fp)),
66                        std::istreambuf_iterator<char>());
67             return str;
68         }
69
70     private:
71         std::string getXmlPath(const std::string & filePath){
72             std::string xmlPath = filePath;
73             const auto openFlags = std::ios_base::ate | std::ios_base::binary;
74             std::ifstream fp (xmlPath, openFlags);
75             //TODO: Dueto multi directory build systems, and single directory build system
76             //, it is usualy a problem to deal with relative paths.
77             if (!fp.is_open()) {
78                 fp.open(getParentDir(xmlPath), openFlags);
79                 EXPECT_TRUE(fp.is_open())
80                 << "cannot open file " << xmlPath <<" or " << getParentDir(xmlPath);
81                 fp.close();
82                 xmlPath = getParentDir(xmlPath);
83             }
84             return xmlPath;
85         }
86
87         const char kPathSeparator =
88 #if defined _WIN32 || defined __CYGWIN__
89         '\\';
90 #else
91         '/';
92 #endif
93         const std::string parentDir = std::string("..") + kPathSeparator;
94         std::string getParentDir(std::string currentFile) const {
95             return parentDir + currentFile;
96         }
97         std::unique_ptr<InferenceEngine::details::IFormatParser> parser;
98         std::vector<std::string> _classes;
99         std::unique_ptr<pugi::xml_node> _root;
100         std::unique_ptr<pugi::xml_document> _doc;
101         };
102
103 inline InferenceEngine::NetworkStatsMap loadStatisticFromFile(const std::string& xmlPath) {
104     auto splitParseCommas = [&](const std::string& s) ->std::vector<float> {
105         std::vector<float> res;
106         std::stringstream ss(s);
107
108         float val;
109
110         while (ss >> val) {
111             res.push_back(val);
112
113             if (ss.peek() == ',')
114                 ss.ignore();
115         }
116
117         return res;
118     };
119
120     InferenceEngine::NetworkStatsMap newNetNodesStats;
121
122     pugi::xml_document doc;
123
124     pugi::xml_parse_result pr = doc.load_file(xmlPath.c_str());
125
126
127     if (!pr) {
128         THROW_IE_EXCEPTION << "Can't load stat file " << xmlPath;
129     }
130
131     auto stats = doc.child("stats");
132     auto layers = stats.child("layers");
133
134     InferenceEngine::NetworkNodeStatsPtr nodeStats;
135     size_t offset;
136     size_t size;
137     size_t count;
138
139     for (auto layer : layers.children("layer")) {
140         nodeStats = InferenceEngine::NetworkNodeStatsPtr(new InferenceEngine::NetworkNodeStats());
141
142         std::string name = layer.child("name").text().get();
143
144         newNetNodesStats[name] = nodeStats;
145
146         nodeStats->_minOutputs = splitParseCommas(layer.child("min").text().get());
147         nodeStats->_maxOutputs = splitParseCommas(layer.child("max").text().get());
148     }
149
150     return newNetNodesStats;
151 }
152
153 }