Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / transform / transform_network.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <transform/transform_network.hpp>
6 #include <limits>
7 #include <string>
8 #include <vector>
9 #include <memory>
10 #include <map>
11
12 using namespace InferenceEngine;
13
14 Transform::Port::Port(Builder::Network& network, PortInfo port, bool isInput)
15     : network(network), port(port), input(isInput) {
16     const auto& layer = network.getLayer(port.layerId());
17     if (isInput) {
18         if (layer->getInputPorts().size() < port.portId())
19             THROW_IE_EXCEPTION << "Cannot find input port "
20                                << port.portId() << " in layer "
21                                << layer->getName();
22     } else {
23         if (layer->getOutputPorts().size() < port.portId())
24             THROW_IE_EXCEPTION << "Cannot find output port "
25                                << port.portId() << " in layer "
26                                << layer->getName();
27     }
28 }
29
30 PortData::Ptr Transform::Port::getData() const {
31     return input ?
32            network.getLayer(port.layerId())->getInputPorts()[port.portId()].getData() :
33            network.getLayer(port.layerId())->getOutputPorts()[port.portId()].getData();
34 }
35
36 const std::map<std::string, Parameter> &Transform::Port::getParameters() const {
37     return input ?
38            network.getLayer(port.layerId())->getInputPorts()[port.portId()].getParameters() :
39            network.getLayer(port.layerId())->getOutputPorts()[port.portId()].getParameters();
40 }
41
42 Transform::Layer Transform::Port::getLayer() const {
43     return Transform::Network(network).getLayer(getPortInfo().layerId());
44 }
45
46 Transform::Connection Transform::Port::getConnection() const {
47     return Connection(*this);
48 }
49
50 void Transform::Port::connect(const Port& port) {
51     if (this->input)
52         this->getConnection().setSource(port);
53     else
54         this->getConnection().addDestination(port);
55 }
56
57 void Transform::Port::disconnect() {
58     getConnection().remove();
59 }
60
61 const SizeVector& Transform::Port::shape() const {
62     return this->getData()->getData()->getTensorDesc().getDims();
63 }
64
65 PortInfo Transform::Port::getPortInfo() const {
66     return port;
67 }
68
69 bool Transform::Port::operator==(const Port& rObj) const {
70     return &network == &rObj.network &&
71            port == rObj.port &&
72            input == rObj.input;
73 }
74
75 bool Transform::Port::operator!=(const Port& rObj) const {
76     return !(*this == rObj);
77 }
78
79
80 Transform::Layer::Layer(Builder::Network& network, idx_t id)
81     : network(network), layerId(id) {}
82
83 idx_t Transform::Layer::getId() const {
84     return layerId;
85 }
86
87 std::string Transform::Layer::getName() const {
88     return getLayer()->getName();
89 }
90
91 std::string Transform::Layer::getType() const {
92     return getLayer()->getType();
93 }
94
95 Builder::Layer::Ptr Transform::Layer::getLayer() const {
96     return network.getLayer(layerId);
97 }
98
99 Transform::Layer::operator Builder::Layer::Ptr() const {
100     return getLayer();
101 }
102
103 Transform::Port Transform::Layer::getInPort() const {
104     if (getLayer()->getInputPorts().size() != 1)
105         THROW_IE_EXCEPTION << "Layer " << getName()
106                            << " has more than 1 input port.";
107     return Transform::Port(network, {layerId, 0}, true);
108 }
109
110 Transform::Port Transform::Layer::getInPort(idx_t idx) const {
111     if (getLayer()->getInputPorts().size() <= idx)
112         THROW_IE_EXCEPTION << "Layer " << getName()
113                            << " has less than " << idx << " input port(s).";
114     return Transform::Port(network, {layerId, idx}, true);
115 }
116
117 std::vector<Transform::Port> Transform::Layer::getInPorts() const {
118     std::vector<Transform::Port> ports;
119     for (size_t i = 0; i < getLayer()->getInputPorts().size(); i++) {
120         ports.push_back({network, {layerId, i}, true});
121     }
122     return ports;
123 }
124
125 Transform::Port Transform::Layer::getOutPort() const {
126     if (getLayer()->getOutputPorts().size() != 1)
127         THROW_IE_EXCEPTION << "Layer " << getName()
128                            << " has more than 1 output port.";
129     return Transform::Port(network, {layerId, 0}, false);
130 }
131
132 Transform::Port Transform::Layer::getOutPort(idx_t idx) const {
133     if (getLayer()->getOutputPorts().size() <= idx)
134         THROW_IE_EXCEPTION << "Layer " << getName()
135                            << " has less than " << idx << " output port(s).";
136     return Transform::Port(network, {layerId, idx}, false);
137 }
138
139 std::vector<Transform::Port> Transform::Layer::getOutPorts() const {
140     std::vector<Transform::Port> ports;
141     for (size_t i = 0; i < getLayer()->getInputPorts().size(); i++) {
142         ports.push_back({network, {layerId, i}, false});
143     }
144     return ports;
145 }
146
147 void Transform::Layer::setParameter(const std::string& key, const Parameter& value) {
148     auto& params = getLayer()->getParameters();
149     params[key] = value;
150 }
151
152 Parameter& Transform::Layer::getParameter(const std::string& key) const {
153     auto& params = getLayer()->getParameters();
154     if (params.find(key) == params.end())
155         THROW_IE_EXCEPTION << "Layer " << getName() << " has no parameter " << key;
156     return params[key];
157 }
158
159 Transform::Connection::Connection(const Transform::Port& port)
160     : network(port.network), inPort({(std::numeric_limits<idx_t>::max)(), (std::numeric_limits<idx_t>::max)()}) {
161     if (port.input) {
162         outPorts = {port.getPortInfo()};
163         for (const auto& connection : network.getLayerConnections(port.getPortInfo().layerId())) {
164             if (connection.to() == port.getPortInfo()) {
165                 inPort = connection.from();
166                 break;
167             }
168         }
169     } else {
170         inPort = port.getPortInfo();
171         for (const auto& connection : network.getLayerConnections(port.getPortInfo().layerId())) {
172             if (connection.from() == port.getPortInfo()) {
173                 outPorts.emplace_back(connection.to());
174             }
175         }
176     }
177 }
178 Transform::Connection::Connection(Builder::Network& network, const InferenceEngine::Connection& connection)
179     : Connection(network, connection.from(), connection.to()) {}
180 Transform::Connection::Connection(Builder::Network& network, const PortInfo& inPort, const PortInfo& outPort)
181     : Connection(network, inPort, std::vector<PortInfo>({outPort})) {}
182 Transform::Connection::Connection(Builder::Network& network, const PortInfo& inPort, const std::vector<PortInfo>& outPorts)
183     : network(network), inPort(inPort), outPorts(outPorts) {}
184
185 Transform::Port Transform::Connection::getSource() const {
186     if (!inPortExist())
187         THROW_IE_EXCEPTION << "Connection doesn't have source port!";
188     return Port(network, inPort, false);
189 }
190
191 void Transform::Connection::setSource(const Transform::Port &port) {
192     if (inPortExist()) {
193         // disconnect old port
194         for (const auto& outPort : outPorts) {
195             network.disconnect({inPort, outPort});
196         }
197     }
198     inPort = port.getPortInfo();
199     for (const auto& outPort : outPorts) {
200         network.connect(inPort, outPort);
201     }
202 }
203
204 Transform::Port Transform::Connection::getDestination() const {
205     if (outPorts.size() != 1)
206         THROW_IE_EXCEPTION << "Connection has more than 1 output.";
207     return Transform::Port(network, outPorts[0], true);
208 }
209
210 Transform::Port Transform::Connection::getDestination(idx_t idx) {
211     if (outPorts.size() <= idx)
212         THROW_IE_EXCEPTION << "Connection has less than "
213                            << idx << " input port(s).";
214     return Transform::Port(network, outPorts[idx], true);
215 }
216
217 std::vector<Transform::Port> Transform::Connection::getDestinations() const {
218     std::vector<Transform::Port> ports;
219     for (const auto& port : outPorts) {
220         ports.emplace_back(network, port, true);
221     }
222     return ports;
223 }
224
225 void Transform::Connection::addDestination(const Transform::Port &port) {
226     for (const auto& outPort : outPorts) {
227         if (outPort == port.getPortInfo()) {
228             THROW_IE_EXCEPTION << "Cannot connect twice with one port!";
229         }
230     }
231     outPorts.emplace_back(port.getPortInfo());
232     if (!inPortExist())
233         return;
234     network.connect(inPort, outPorts[outPorts.size() - 1]);
235 }
236
237 void Transform::Connection::setDestination(const Transform::Port &port) {
238     if (outPorts.size() > 1) {
239         THROW_IE_EXCEPTION << "Cannot set destination for connection which has more than 1 consumer."
240                            << "Please use addDestination or setDestinations methods!";
241     }
242
243     if (!outPorts.empty()) {
244         if (inPortExist())
245             network.disconnect({inPort, outPorts[0]});
246         outPorts.clear();
247     }
248     addDestination(port);
249 }
250
251 void Transform::Connection::setDestinations(const std::vector<Transform::Port> &ports) {
252     if (!outPorts.empty() && outPorts.size() != ports.size())
253         THROW_IE_EXCEPTION << "Cannot change number of output connections!";
254
255     if (inPortExist()) {
256         for (const auto &port : outPorts) {
257             network.disconnect({inPort, port});
258         }
259     }
260     outPorts.clear();
261     for (const auto &port : ports) {
262         addDestination(port);
263     }
264 }
265
266 void Transform::Connection::remove() {
267     if (!inPortExist())
268         return;
269     for (const auto& port : outPorts) {
270         network.disconnect({inPort, port});
271     }
272 }
273
274 bool Transform::Connection::inPortExist() const {
275     static PortInfo uninitPort((std::numeric_limits<idx_t>::max)(), (std::numeric_limits<idx_t>::max)());
276     return inPort != uninitPort;
277 }
278
279 Transform::Layer Transform::Network::addLayer(const Builder::Layer &layer) {
280     idx_t layerId = network.addLayer(layer);
281     return Transform::Layer(network, layerId);
282 }
283
284 void Transform::Network::removeLayer(const Transform::Layer &layer) {
285     for (const auto& connection : network.getLayerConnections(layer.getId()))
286         network.disconnect(connection);
287     network.removeLayer(layer.getId());
288 }
289
290 Transform::Layer Transform::Network::getLayer(const std::string &name) const {
291     for (const auto& layer : network) {
292         if (layer->getName() == name)
293             return Transform::Layer(network, layer->getId());
294     }
295     THROW_IE_EXCEPTION << "Layer with name: " << name << " was not found!";
296 }
297
298 Transform::Layer Transform::Network::getLayer(idx_t id) const {
299     for (const auto& layer : network) {
300         if (layer->getId() == id)
301             return Transform::Layer(network, layer->getId());
302     }
303     THROW_IE_EXCEPTION << "Layer with id: " << id << " was not found!";
304 }
305
306 Transform::Connection Transform::Network::connect(const Transform::Layer &src,
307         const Transform::Layer &dst) {
308     Port srcPort = src.getOutPort();
309     Port dstPort = dst.getInPort();
310
311     network.connect(srcPort.getPortInfo(), dstPort.getPortInfo());
312     return Connection(network, srcPort.getPortInfo(), dstPort.getPortInfo());
313 }
314
315 Transform::Connection Transform::Network::connect(const Transform::Port &src,
316         const Transform::Port &dst) {
317     network.connect(src.getPortInfo(), dst.getPortInfo());
318     return Connection(network, src.getPortInfo(), dst.getPortInfo());
319 }
320
321 void Transform::Network::disconnect(const Transform::Layer &src, const Transform::Layer &dst) {
322     getConnection(src, dst).remove();
323 }
324
325 void Transform::Network::disconnect(const Transform::Port &src, const Transform::Port &dst) {
326     getConnection(src, dst).remove();
327 }
328
329 Builder::Network& Transform::Network::getBuilderNetwork() const {
330     return network;
331 }
332
333 Transform::Connection Transform::Network::getConnection(const Transform::Layer &src,
334         const Transform::Layer &dst) const {
335     Port srcPort = src.getOutPort();
336     Port dstPort = dst.getInPort();
337
338     for (const auto& connection : network.getConnections()) {
339         if (connection.from() == srcPort.getPortInfo() && connection.to() == dstPort.getPortInfo())
340             return Connection(network, srcPort.getPortInfo(), dstPort.getPortInfo());
341     }
342     THROW_IE_EXCEPTION << "Connection " << src.getName() << " -> " << dst.getName() << " was not found!";
343 }
344
345 Transform::Connection Transform::Network::getConnection(const Transform::Port &src,
346         const Transform::Port &dst) const {
347     for (const auto& connection : network.getConnections()) {
348         if (connection.from() == src.getPortInfo() && connection.to() == dst.getPortInfo())
349             return Connection(network, src.getPortInfo(), dst.getPortInfo());
350     }
351     THROW_IE_EXCEPTION << "Connection " << getLayer(src.getPortInfo().layerId()).getName()
352         << " -> " << getLayer(dst.getPortInfo().layerId()).getName() << " was not found!";
353 }