Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_cnn_net_reader_impl.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "ie_icnn_net_reader.h"
8 #include "cnn_network_impl.hpp"
9 #include "parsers.h"
10 #include <memory>
11 #include <string>
12 #include <map>
13
14 namespace pugi {
15 class xml_node;
16
17 class xml_document;
18 }  // namespace pugi
19
20 namespace InferenceEngine {
21 namespace details {
22
23 struct FormatParserCreator {
24     using Ptr = std::shared_ptr<FormatParserCreator>;
25     virtual std::shared_ptr<IFormatParser> create(int version) = 0;
26 };
27
28 struct V2FormatParserCreator : public FormatParserCreator {
29     std::shared_ptr<IFormatParser> create(int version) override;
30 };
31
32 class CNNNetReaderImpl : public ICNNNetReader {
33 public:
34     static std::string NameFromFilePath(const char *filepath);
35
36     explicit CNNNetReaderImpl(const FormatParserCreator::Ptr& _parserCreator);
37
38     StatusCode ReadNetwork(const char *filepath, ResponseDesc *resp) noexcept override;
39
40     StatusCode ReadNetwork(const void *model, size_t size, ResponseDesc *resp)noexcept override;
41
42     StatusCode SetWeights(const TBlob<uint8_t>::Ptr &weights, ResponseDesc *resp) noexcept override;
43
44     StatusCode ReadWeights(const char *filepath, ResponseDesc *resp) noexcept override;
45
46     ICNNNetwork *getNetwork(ResponseDesc *resp) noexcept override {
47         return network.get();
48     }
49
50
51     bool isParseSuccess(ResponseDesc *resp) noexcept override {
52         return parseSuccess;
53     }
54
55
56     StatusCode getDescription(ResponseDesc *desc) noexcept override {
57         return DescriptionBuffer(OK, desc) << description;
58     }
59
60
61     StatusCode getName(char *name, size_t len, ResponseDesc *resp) noexcept override {
62         strncpy(name, this->name.c_str(), len - 1);
63         if (len) name[len-1] = '\0';  // strncpy is not doing this, so output might be not null-terminated
64         return OK;
65     }
66
67     int getVersion(ResponseDesc * resp) noexcept override {
68         return _version;
69     }
70
71     void Release() noexcept override {
72         delete this;
73     }
74
75 private:
76     std::shared_ptr<InferenceEngine::details::IFormatParser> _parser;
77
78     static int GetFileVersion(pugi::xml_node &root);
79
80     StatusCode ReadNetwork(pugi::xml_document &xmlDoc);
81
82     std::string description;
83     std::string name;
84     InferenceEngine::details::CNNNetworkImplPtr network;
85     bool parseSuccess;
86     int _version;
87     FormatParserCreator::Ptr parserCreator;
88 };
89 }  // namespace details
90 }  // namespace InferenceEngine