Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / inference_engine_tests / util_test.hpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

namespace IE = InferenceEngine;
7
8 class NetBuilder {
9     using LayersMap = std::unordered_map<std::string, IE::CNNLayerPtr>;
10     using DataMap = std::unordered_map<std::string, IE::DataPtr>;
11     using InputsSet = std::unordered_set<IE::InputInfo::Ptr>;
12     LayersMap _layers;
13     DataMap _data;
14     InputsSet _inputs;
15 public:
16     NetBuilder() = default;
17
18     NetBuilder(const NetBuilder&) = delete;
19
20     template<typename... Args>
21     NetBuilder& data(Args&& ... args) {
22         auto newData = std::make_shared<IE::Data>(std::forward<Args>(args)...);
23         assert(!IE::contains(_data, newData->getName()));
24         _data[newData->getName()] = newData;
25         return *this;
26     }
27
28     template<typename T, typename... Args>
29     NetBuilder& layer(Args&& ... args) {
30         auto newLayer = std::make_shared<T>(std::forward<Args>(args)...);
31         assert(!IE::contains(_layers, newLayer->name));
32         _layers[newLayer->name] = std::static_pointer_cast<IE::CNNLayer>(newLayer);
33         return *this;
34     }
35
36     const LayersMap& getLayersMap() const {
37         return _layers;
38     }
39
40     const DataMap& getDataMap() const {
41         return _data;
42     }
43
44     NetBuilder& linkDataTo(const std::string& dataName,
45                            const std::string& nextlayerName) {
46         assert(IE::contains(_layers, nextlayerName));
47         assert(IE::contains(_data, dataName));
48
49         auto nextlayer = _layers[nextlayerName];
50         auto data = _data[dataName];
51
52         nextlayer->insData.push_back(data);
53         data->getInputTo().insert({nextlayerName, nextlayer});
54         return *this;
55     }
56
57     NetBuilder& linkToData(const std::string& prevlayerName,
58                            const std::string& dataName) {
59         assert(IE::contains(_layers, prevlayerName));
60         assert(IE::contains(_data, dataName));
61
62         auto prevlayer = _layers[prevlayerName];
63         auto data = _data[dataName];
64         assert(nullptr == data->getCreatorLayer().lock());
65
66         prevlayer->outData.push_back(data);
67         data->getCreatorLayer() = prevlayer;
68         return *this;
69     }
70
71     NetBuilder& linkLayers(const std::string& prevlayerName,
72                            const std::string& nextlayerName,
73                            const std::string& dataName) {
74         linkToData(prevlayerName, dataName);
75         linkDataTo(dataName, nextlayerName);
76         return *this;
77     }
78
79     NetBuilder& linkData(const std::string& prevDataName,
80                          const std::string& nextDataName,
81                          const std::string& layerName) {
82         linkDataTo(prevDataName, layerName);
83         linkToData(layerName, nextDataName);
84         return *this;
85     }
86
87     template<typename... Args>
88     NetBuilder& addInput(const std::string& dataName, Args&& ... args) {
89         assert(!dataName.empty());
90         assert(IE::contains(_data, dataName));
91         auto input = std::make_shared<IE::InputInfo>(
92                 std::forward<Args>(args)...);
93         input->setInputData(_data[dataName]);
94         _inputs.insert(std::move(input));
95         return *this;
96     }
97
98     IE::details::CNNNetworkImplPtr finalize() {
99         auto net = std::make_shared<IE::details::CNNNetworkImpl>();
100
101         for (auto&& it: _data) {
102             auto& data = it.second;
103             net->getData(it.first) = data;
104             if (nullptr == data->getCreatorLayer().lock()) {
105                 auto input = std::make_shared<IE::InputInfo>();
106                 input->setInputData(data);
107                 net->setInputInfo(input);
108             }
109         }
110         for (auto&& it: _layers) {
111             net->addLayer(it.second);
112         }
113         for (auto& i : _inputs) {
114             net->setInputInfo(std::move(i));
115         }
116
117         net->resolveOutput();
118
119         return net;
120     }
121 };