Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / transform / transform_network.hpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_parameter.hpp>
8 #include <ie_builders.hpp>
9 #include <string>
10 #include <vector>
11 #include <memory>
12 #include <map>
13
14 namespace InferenceEngine {
15 namespace Transform {
16
17 class Connection;
18 class Layer;
19
20 class INFERENCE_ENGINE_API_CLASS(Port) {
21 public:
22     Port(Builder::Network& network, PortInfo port, bool isInput);
23     PortData::Ptr getData() const;
24     const std::map<std::string, Parameter>& getParameters() const;
25     Layer getLayer() const;
26     Connection getConnection() const;
27     void connect(const Port& port);
28     void disconnect();
29     const SizeVector& shape() const;
30     PortInfo getPortInfo() const;
31     bool operator==(const Port& rObj) const;
32     bool operator!=(const Port& rObj) const;
33
34 private:
35     Builder::Network& network;
36     PortInfo port;
37     bool input;
38
39     friend class Connection;
40 };
41
42 class INFERENCE_ENGINE_API_CLASS(Layer) {
43 public:
44     Layer(Builder::Network& network, idx_t id);
45     Port getInPort() const;
46     Port getInPort(idx_t idx) const;
47     std::vector<Port> getInPorts() const;
48     Port getOutPort() const;
49     Port getOutPort(idx_t idx) const;
50     std::vector<Port> getOutPorts() const;
51
52     void setParameter(const std::string& key, const Parameter& value);
53     Parameter& getParameter(const std::string& value) const;
54
55     idx_t getId() const;
56     std::string getName() const;
57     std::string getType() const;
58     operator Builder::Layer::Ptr() const;
59
60 private:
61     Builder::Network& network;
62     idx_t layerId;
63
64     Builder::Layer::Ptr getLayer() const;
65 };
66
67 class INFERENCE_ENGINE_API_CLASS(Connection) {
68 public:
69     explicit Connection(const Port& port);
70     Connection(Builder::Network& network, const InferenceEngine::Connection& connection);
71     Connection(Builder::Network& network, const PortInfo& inPort, const PortInfo& outPort);
72     Connection(Builder::Network& network, const PortInfo& inPort, const std::vector<PortInfo>& outPorts);
73
74     Port getSource() const;
75     void setSource(const Port& port);
76     Port getDestination() const;
77     Port getDestination(idx_t idx);
78     std::vector<Port> getDestinations() const;
79     void addDestination(const Port& port);
80     void setDestination(const Port& port);
81     void setDestinations(const std::vector<Port>& ports);
82     void remove();
83
84 private:
85     Builder::Network& network;
86     PortInfo inPort;
87     std::vector<PortInfo> outPorts;
88
89     bool inPortExist() const;
90 };
91
92 class INFERENCE_ENGINE_API_CLASS(Network) {
93 public:
94     explicit Network(Builder::Network& network): network(network) {}
95     virtual ~Network() = default;
96
97     Layer addLayer(const Builder::Layer& layer);
98     void removeLayer(const Layer& layer);
99     Layer getLayer(const std::string& name) const;
100     Layer getLayer(idx_t id) const;
101
102     Builder::Network& getBuilderNetwork() const;
103
104     Connection connect(const Layer& src, const Layer& dst);
105     Connection connect(const Port& src, const Port& dst);
106     void disconnect(const Layer& src, const Layer& dst);
107     void disconnect(const Port& src, const Port& dst);
108     Connection getConnection(const Layer& src, const Layer& dst) const;
109     Connection getConnection(const Port& src, const Port& dst) const;
110
111 private:
112     Builder::Network& network;
113 };
114
115 }  // namespace Transform
116 }  // namespace InferenceEngine