Release 18.02
[platform/upstream/armnn.git] / src / armnn / Network.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/TensorFwd.hpp>
9 #include <armnn/Types.hpp>
10
11 #include <armnn/INetwork.hpp>
12
13 #include <string>
14 #include <vector>
15 #include <memory>
16
17 #include "Layer.hpp"
18
19 namespace armnn
20 {
21 class Graph;
22
23 /// Private implementation of INetwork
24 class Network final : public INetwork
25 {
26 public:
27     Network();
28     ~Network();
29
30     const Graph& GetGraph() const { return *m_Graph; }
31
32     Status PrintGraph() override;
33
34     IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
35
36     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
37         const ConstTensor& weights,
38         const char* name = nullptr) override;
39
40     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
41         const ConstTensor& weights,
42         const ConstTensor& biases,
43         const char* name = nullptr) override;
44
45     IConnectableLayer* AddDepthwiseConvolution2dLayer(
46         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
47         const ConstTensor&                      weights,
48         const char*                             name = nullptr) override;
49
50     IConnectableLayer* AddDepthwiseConvolution2dLayer(
51         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
52         const ConstTensor&                      weights,
53         const ConstTensor&                      biases,
54         const char*                             name = nullptr) override;
55
56     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
57         const ConstTensor& weights,
58         const char* name = nullptr) override;
59
60     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
61         const ConstTensor& weights,
62         const ConstTensor& biases,
63         const char* name = nullptr) override;
64
65     IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
66                                        const char* name = nullptr) override;
67
68     IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
69         const char* name = nullptr) override;
70
71     IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
72         const char* name = nullptr) override;
73
74     IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
75         const char* name = nullptr) override;
76
77     IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
78         const char* name = nullptr) override;
79
80     IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
81         const char* name = nullptr) override;
82
83     IConnectableLayer* AddMergerLayer(const OriginsDescriptor& mergerDescriptor,
84         const char* name = nullptr) override;
85
86     IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
87
88     IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
89
90     IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
91                                                   const ConstTensor&                  mean,
92                                                   const ConstTensor&                  variance,
93                                                   const ConstTensor&                  beta,
94                                                   const ConstTensor&                  gamma,
95                                                   const char*                         name = nullptr) override;
96
97     IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
98                                               const char* name = nullptr) override;
99
100     IConnectableLayer* AddL2NormalizationLayer(const char* name = nullptr) override;
101
102     IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
103
104     IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
105                                        const char* name = nullptr) override;
106
107     IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
108
109     IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
110
111 private:
112     IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
113         const ConstTensor& weights,
114         const ConstTensor* biases,
115         const char* name);
116
117     IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
118         const ConstTensor& weights,
119         const ConstTensor* biases,
120         const char* name);
121
122     IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
123         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
124         const ConstTensor& weights,
125         const ConstTensor* biases,
126         const char* name);
127
128     std::unique_ptr<Graph> m_Graph;
129 };
130
131 class OptimizedNetwork final : public IOptimizedNetwork
132 {
133 public:
134     OptimizedNetwork(std::unique_ptr<Graph> graph);
135     ~OptimizedNetwork();
136
137     Status PrintGraph() override;
138
139     Graph& GetGraph() { return *m_Graph; }
140
141 private:
142     std::unique_ptr<Graph> m_Graph;
143 };
144
145 } // namespace armnn