// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <transform/transform_network.hpp>

#include <algorithm>
#include <limits>
#include <map>
#include <string>
#include <vector>

using namespace InferenceEngine;
14 Transform::Port::Port(Builder::Network& network, PortInfo port, bool isInput)
15 : network(network), port(port), input(isInput) {
16 const auto& layer = network.getLayer(port.layerId());
18 if (layer->getInputPorts().size() < port.portId())
19 THROW_IE_EXCEPTION << "Cannot find input port "
20 << port.portId() << " in layer "
23 if (layer->getOutputPorts().size() < port.portId())
24 THROW_IE_EXCEPTION << "Cannot find output port "
25 << port.portId() << " in layer "
30 PortData::Ptr Transform::Port::getData() const {
32 network.getLayer(port.layerId())->getInputPorts()[port.portId()].getData() :
33 network.getLayer(port.layerId())->getOutputPorts()[port.portId()].getData();
36 const std::map<std::string, Parameter> &Transform::Port::getParameters() const {
38 network.getLayer(port.layerId())->getInputPorts()[port.portId()].getParameters() :
39 network.getLayer(port.layerId())->getOutputPorts()[port.portId()].getParameters();
42 Transform::Layer Transform::Port::getLayer() const {
43 return Transform::Network(network).getLayer(getPortInfo().layerId());
46 Transform::Connection Transform::Port::getConnection() const {
47 return Connection(*this);
50 void Transform::Port::connect(const Port& port) {
52 this->getConnection().setSource(port);
54 this->getConnection().addDestination(port);
57 void Transform::Port::disconnect() {
58 getConnection().remove();
61 const SizeVector& Transform::Port::shape() const {
62 return this->getData()->getData()->getTensorDesc().getDims();
65 PortInfo Transform::Port::getPortInfo() const {
69 bool Transform::Port::operator==(const Port& rObj) const {
70 return &network == &rObj.network &&
75 bool Transform::Port::operator!=(const Port& rObj) const {
76 return !(*this == rObj);
80 Transform::Layer::Layer(Builder::Network& network, idx_t id)
81 : network(network), layerId(id) {}
83 idx_t Transform::Layer::getId() const {
87 std::string Transform::Layer::getName() const {
88 return getLayer()->getName();
91 std::string Transform::Layer::getType() const {
92 return getLayer()->getType();
95 Builder::Layer::Ptr Transform::Layer::getLayer() const {
96 return network.getLayer(layerId);
99 Transform::Layer::operator Builder::Layer::Ptr() const {
103 Transform::Port Transform::Layer::getInPort() const {
104 if (getLayer()->getInputPorts().size() != 1)
105 THROW_IE_EXCEPTION << "Layer " << getName()
106 << " has more than 1 input port.";
107 return Transform::Port(network, {layerId, 0}, true);
110 Transform::Port Transform::Layer::getInPort(idx_t idx) const {
111 if (getLayer()->getInputPorts().size() <= idx)
112 THROW_IE_EXCEPTION << "Layer " << getName()
113 << " has less than " << idx << " input port(s).";
114 return Transform::Port(network, {layerId, idx}, true);
117 std::vector<Transform::Port> Transform::Layer::getInPorts() const {
118 std::vector<Transform::Port> ports;
119 for (size_t i = 0; i < getLayer()->getInputPorts().size(); i++) {
120 ports.push_back({network, {layerId, i}, true});
125 Transform::Port Transform::Layer::getOutPort() const {
126 if (getLayer()->getOutputPorts().size() != 1)
127 THROW_IE_EXCEPTION << "Layer " << getName()
128 << " has more than 1 output port.";
129 return Transform::Port(network, {layerId, 0}, false);
132 Transform::Port Transform::Layer::getOutPort(idx_t idx) const {
133 if (getLayer()->getOutputPorts().size() <= idx)
134 THROW_IE_EXCEPTION << "Layer " << getName()
135 << " has less than " << idx << " output port(s).";
136 return Transform::Port(network, {layerId, idx}, false);
139 std::vector<Transform::Port> Transform::Layer::getOutPorts() const {
140 std::vector<Transform::Port> ports;
141 for (size_t i = 0; i < getLayer()->getInputPorts().size(); i++) {
142 ports.push_back({network, {layerId, i}, false});
147 void Transform::Layer::setParameter(const std::string& key, const Parameter& value) {
148 auto& params = getLayer()->getParameters();
152 Parameter& Transform::Layer::getParameter(const std::string& key) const {
153 auto& params = getLayer()->getParameters();
154 if (params.find(key) == params.end())
155 THROW_IE_EXCEPTION << "Layer " << getName() << " has no parameter " << key;
159 Transform::Connection::Connection(const Transform::Port& port)
160 : network(port.network), inPort({(std::numeric_limits<idx_t>::max)(), (std::numeric_limits<idx_t>::max)()}) {
162 outPorts = {port.getPortInfo()};
163 for (const auto& connection : network.getLayerConnections(port.getPortInfo().layerId())) {
164 if (connection.to() == port.getPortInfo()) {
165 inPort = connection.from();
170 inPort = port.getPortInfo();
171 for (const auto& connection : network.getLayerConnections(port.getPortInfo().layerId())) {
172 if (connection.from() == port.getPortInfo()) {
173 outPorts.emplace_back(connection.to());
178 Transform::Connection::Connection(Builder::Network& network, const InferenceEngine::Connection& connection)
179 : Connection(network, connection.from(), connection.to()) {}
180 Transform::Connection::Connection(Builder::Network& network, const PortInfo& inPort, const PortInfo& outPort)
181 : Connection(network, inPort, std::vector<PortInfo>({outPort})) {}
182 Transform::Connection::Connection(Builder::Network& network, const PortInfo& inPort, const std::vector<PortInfo>& outPorts)
183 : network(network), inPort(inPort), outPorts(outPorts) {}
185 Transform::Port Transform::Connection::getSource() const {
187 THROW_IE_EXCEPTION << "Connection doesn't have source port!";
188 return Port(network, inPort, false);
191 void Transform::Connection::setSource(const Transform::Port &port) {
193 // disconnect old port
194 for (const auto& outPort : outPorts) {
195 network.disconnect({inPort, outPort});
198 inPort = port.getPortInfo();
199 for (const auto& outPort : outPorts) {
200 network.connect(inPort, outPort);
204 Transform::Port Transform::Connection::getDestination() const {
205 if (outPorts.size() != 1)
206 THROW_IE_EXCEPTION << "Connection has more than 1 output.";
207 return Transform::Port(network, outPorts[0], true);
210 Transform::Port Transform::Connection::getDestination(idx_t idx) {
211 if (outPorts.size() <= idx)
212 THROW_IE_EXCEPTION << "Connection has less than "
213 << idx << " input port(s).";
214 return Transform::Port(network, outPorts[idx], true);
217 std::vector<Transform::Port> Transform::Connection::getDestinations() const {
218 std::vector<Transform::Port> ports;
219 for (const auto& port : outPorts) {
220 ports.emplace_back(network, port, true);
225 void Transform::Connection::addDestination(const Transform::Port &port) {
226 for (const auto& outPort : outPorts) {
227 if (outPort == port.getPortInfo()) {
228 THROW_IE_EXCEPTION << "Cannot connect twice with one port!";
231 outPorts.emplace_back(port.getPortInfo());
234 network.connect(inPort, outPorts[outPorts.size() - 1]);
237 void Transform::Connection::setDestination(const Transform::Port &port) {
238 if (outPorts.size() > 1) {
239 THROW_IE_EXCEPTION << "Cannot set destination for connection which has more than 1 consumer."
240 << "Please use addDestination or setDestinations methods!";
243 if (!outPorts.empty()) {
245 network.disconnect({inPort, outPorts[0]});
248 addDestination(port);
251 void Transform::Connection::setDestinations(const std::vector<Transform::Port> &ports) {
252 if (!outPorts.empty() && outPorts.size() != ports.size())
253 THROW_IE_EXCEPTION << "Cannot change number of output connections!";
256 for (const auto &port : outPorts) {
257 network.disconnect({inPort, port});
261 for (const auto &port : ports) {
262 addDestination(port);
266 void Transform::Connection::remove() {
269 for (const auto& port : outPorts) {
270 network.disconnect({inPort, port});
274 bool Transform::Connection::inPortExist() const {
275 static PortInfo uninitPort((std::numeric_limits<idx_t>::max)(), (std::numeric_limits<idx_t>::max)());
276 return inPort != uninitPort;
279 Transform::Layer Transform::Network::addLayer(const Builder::Layer &layer) {
280 idx_t layerId = network.addLayer(layer);
281 return Transform::Layer(network, layerId);
284 void Transform::Network::removeLayer(const Transform::Layer &layer) {
285 for (const auto& connection : network.getLayerConnections(layer.getId()))
286 network.disconnect(connection);
287 network.removeLayer(layer.getId());
290 Transform::Layer Transform::Network::getLayer(const std::string &name) const {
291 for (const auto& layer : network) {
292 if (layer->getName() == name)
293 return Transform::Layer(network, layer->getId());
295 THROW_IE_EXCEPTION << "Layer with name: " << name << " was not found!";
298 Transform::Layer Transform::Network::getLayer(idx_t id) const {
299 for (const auto& layer : network) {
300 if (layer->getId() == id)
301 return Transform::Layer(network, layer->getId());
303 THROW_IE_EXCEPTION << "Layer with id: " << id << " was not found!";
306 Transform::Connection Transform::Network::connect(const Transform::Layer &src,
307 const Transform::Layer &dst) {
308 Port srcPort = src.getOutPort();
309 Port dstPort = dst.getInPort();
311 network.connect(srcPort.getPortInfo(), dstPort.getPortInfo());
312 return Connection(network, srcPort.getPortInfo(), dstPort.getPortInfo());
315 Transform::Connection Transform::Network::connect(const Transform::Port &src,
316 const Transform::Port &dst) {
317 network.connect(src.getPortInfo(), dst.getPortInfo());
318 return Connection(network, src.getPortInfo(), dst.getPortInfo());
321 void Transform::Network::disconnect(const Transform::Layer &src, const Transform::Layer &dst) {
322 getConnection(src, dst).remove();
325 void Transform::Network::disconnect(const Transform::Port &src, const Transform::Port &dst) {
326 getConnection(src, dst).remove();
329 Builder::Network& Transform::Network::getBuilderNetwork() const {
333 Transform::Connection Transform::Network::getConnection(const Transform::Layer &src,
334 const Transform::Layer &dst) const {
335 Port srcPort = src.getOutPort();
336 Port dstPort = dst.getInPort();
338 for (const auto& connection : network.getConnections()) {
339 if (connection.from() == srcPort.getPortInfo() && connection.to() == dstPort.getPortInfo())
340 return Connection(network, srcPort.getPortInfo(), dstPort.getPortInfo());
342 THROW_IE_EXCEPTION << "Connection " << src.getName() << " -> " << dst.getName() << " was not found!";
345 Transform::Connection Transform::Network::getConnection(const Transform::Port &src,
346 const Transform::Port &dst) const {
347 for (const auto& connection : network.getConnections()) {
348 if (connection.from() == src.getPortInfo() && connection.to() == dst.getPortInfo())
349 return Connection(network, src.getPortInfo(), dst.getPortInfo());
351 THROW_IE_EXCEPTION << "Connection " << getLayer(src.getPortInfo().layerId()).getName()
352 << " -> " << getLayer(dst.getPortInfo().layerId()).getName() << " was not found!";