1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 namespace IE = InferenceEngine;
9 using LayersMap = std::unordered_map<std::string, IE::CNNLayerPtr>;
10 using DataMap = std::unordered_map<std::string, IE::DataPtr>;
11 using InputsSet = std::unordered_set<IE::InputInfo::Ptr>;
16 NetBuilder() = default;
18 NetBuilder(const NetBuilder&) = delete;
20 template<typename... Args>
21 NetBuilder& data(Args&& ... args) {
22 auto newData = std::make_shared<IE::Data>(std::forward<Args>(args)...);
23 assert(!IE::contains(_data, newData->getName()));
24 _data[newData->getName()] = newData;
28 template<typename T, typename... Args>
29 NetBuilder& layer(Args&& ... args) {
30 auto newLayer = std::make_shared<T>(std::forward<Args>(args)...);
31 assert(!IE::contains(_layers, newLayer->name));
32 _layers[newLayer->name] = std::static_pointer_cast<IE::CNNLayer>(newLayer);
36 const LayersMap& getLayersMap() const {
40 const DataMap& getDataMap() const {
44 NetBuilder& linkDataTo(const std::string& dataName,
45 const std::string& nextlayerName) {
46 assert(IE::contains(_layers, nextlayerName));
47 assert(IE::contains(_data, dataName));
49 auto nextlayer = _layers[nextlayerName];
50 auto data = _data[dataName];
52 nextlayer->insData.push_back(data);
53 data->getInputTo().insert({nextlayerName, nextlayer});
57 NetBuilder& linkToData(const std::string& prevlayerName,
58 const std::string& dataName) {
59 assert(IE::contains(_layers, prevlayerName));
60 assert(IE::contains(_data, dataName));
62 auto prevlayer = _layers[prevlayerName];
63 auto data = _data[dataName];
64 assert(nullptr == data->getCreatorLayer().lock());
66 prevlayer->outData.push_back(data);
67 data->getCreatorLayer() = prevlayer;
71 NetBuilder& linkLayers(const std::string& prevlayerName,
72 const std::string& nextlayerName,
73 const std::string& dataName) {
74 linkToData(prevlayerName, dataName);
75 linkDataTo(dataName, nextlayerName);
79 NetBuilder& linkData(const std::string& prevDataName,
80 const std::string& nextDataName,
81 const std::string& layerName) {
82 linkDataTo(prevDataName, layerName);
83 linkToData(layerName, nextDataName);
87 template<typename... Args>
88 NetBuilder& addInput(const std::string& dataName, Args&& ... args) {
89 assert(!dataName.empty());
90 assert(IE::contains(_data, dataName));
91 auto input = std::make_shared<IE::InputInfo>(
92 std::forward<Args>(args)...);
93 input->setInputData(_data[dataName]);
94 _inputs.insert(std::move(input));
98 IE::details::CNNNetworkImplPtr finalize() {
99 auto net = std::make_shared<IE::details::CNNNetworkImpl>();
101 for (auto&& it: _data) {
102 auto& data = it.second;
103 net->getData(it.first) = data;
104 if (nullptr == data->getCreatorLayer().lock()) {
105 auto input = std::make_shared<IE::InputInfo>();
106 input->setInputData(data);
107 net->setInputInfo(input);
110 for (auto&& it: _layers) {
111 net->addLayer(it.second);
113 for (auto& i : _inputs) {
114 net->setInputInfo(std::move(i));
117 net->resolveOutput();