7fc5b651d01209282d9cc0b6a290cc4cf154423a
[platform/upstream/armnn.git] / src / armnn / Network.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/TensorFwd.hpp>
10 #include <armnn/Types.hpp>
11
12 #include <armnn/INetwork.hpp>
13
14 #include <string>
15 #include <vector>
16 #include <map>
17 #include <memory>
18
19 #include "Layer.hpp"
20
21 namespace armnn
22 {
23 class Graph;
24
25 /// Private implementation of INetwork.
26 class Network final : public INetwork
27 {
28 public:
29     Network();
30     ~Network();
31
32     const Graph& GetGraph() const { return *m_Graph; }
33
34     Status PrintGraph() override;
35
36     IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
37
38     IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
39                                               const char* name = nullptr) override;
40
41     IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
42                                       const char* name = nullptr) override;
43
44     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
45                                              const ConstTensor& weights,
46                                              const Optional<ConstTensor>& biases,
47                                              const char* name = nullptr) override;
48
49     ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
50     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
51                                              const ConstTensor& weights,
52                                              const char* name = nullptr) override;
53
54     ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
55     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
56                                              const ConstTensor& weights,
57                                              const ConstTensor& biases,
58                                              const char* name = nullptr) override;
59
60     IConnectableLayer* AddDepthwiseConvolution2dLayer(
61         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
62         const ConstTensor& weights,
63         const Optional<ConstTensor>& biases,
64         const char* name = nullptr) override;
65
66     ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
67     IConnectableLayer* AddDepthwiseConvolution2dLayer(
68         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
69         const ConstTensor& weights,
70         const char* name = nullptr) override;
71
72     ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
73     IConnectableLayer* AddDepthwiseConvolution2dLayer(
74         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
75         const ConstTensor& weights,
76         const ConstTensor& biases,
77         const char* name = nullptr) override;
78
79     IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
80
81     IConnectableLayer* AddDetectionPostProcessLayer(
82         const DetectionPostProcessDescriptor& descriptor,
83         const ConstTensor& anchors,
84         const char* name = nullptr) override;
85
86     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
87                                               const ConstTensor& weights,
88                                               const Optional<ConstTensor>& biases,
89                                               const char* name = nullptr) override;
90
91     ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
92     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
93                                               const ConstTensor& weights,
94                                               const char* name = nullptr) override;
95
96     ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
97     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
98                                               const ConstTensor& weights,
99                                               const ConstTensor& biases,
100                                               const char* name = nullptr) override;
101
102     IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
103
104     IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
105                                        const char* name = nullptr) override;
106
107     IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
108         const char* name = nullptr) override;
109
110     IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
111         const char* name = nullptr) override;
112
113     IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
114         const char* name = nullptr) override;
115
116     IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
117         const char* name = nullptr) override;
118
119     IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
120         const char* name = nullptr) override;
121
122     ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
123     IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
124                                       const char* name = nullptr) override;
125
126     IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
127
128     IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
129
130     IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
131                                                   const ConstTensor&                  mean,
132                                                   const ConstTensor&                  variance,
133                                                   const ConstTensor&                  beta,
134                                                   const ConstTensor&                  gamma,
135                                                   const char*                         name = nullptr) override;
136
137     ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
138     IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
139                                               const char* name = nullptr) override;
140
141     IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
142                                       const char* name = nullptr) override;
143
144     IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
145                                                const char* name = nullptr) override;
146
147     IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
148
149     IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
150                                        const char* name = nullptr) override;
151
152     IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
153                                               const char* name = nullptr) override;
154
155     IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
156                                             const char* name = nullptr) override;
157
158     IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
159
160     IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
161
162     IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
163                                     const LstmInputParams& params,
164                                     const char* name = nullptr) override;
165
166     IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
167
168     IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
169
170     IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
171
172     IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
173
174     IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
175
176     IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
177
178     IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
179                                             const char* name = nullptr) override;
180
181     IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
182
183     IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
184
185     IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
186
187     IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
188
189     IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
190
191     IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
192
193     IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
194
195     IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
196                                                       const ConstTensor& weights,
197                                                       const Optional<ConstTensor>& biases,
198                                                       const char* name = nullptr) override;
199
200     void Accept(ILayerVisitor& visitor) const override;
201
202 private:
203     IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
204                                                   const ConstTensor& weights,
205                                                   const Optional<ConstTensor>& biases,
206                                                   const char* name);
207
208     IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
209                                                  const ConstTensor& weights,
210                                                  const Optional<ConstTensor>& biases,
211                                                  const char* name);
212
213     IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
214         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
215         const ConstTensor& weights,
216         const Optional<ConstTensor>& biases,
217         const char* name);
218
219     std::unique_ptr<Graph> m_Graph;
220 };
221
222 class OptimizedNetwork final : public IOptimizedNetwork
223 {
224 public:
225     OptimizedNetwork(std::unique_ptr<Graph> graph);
226     ~OptimizedNetwork();
227
228     Status PrintGraph() override;
229     Status SerializeToDot(std::ostream& stream) const override;
230
231     Graph& GetGraph() { return *m_Graph; }
232
233 private:
234     std::unique_ptr<Graph> m_Graph;
235 };
236
237
238
/// Outcome flags of an optimization step (e.g. the value returned by
/// SelectTensorHandleStrategy). A default-constructed result has neither flag set.
struct OptimizationResult
{
    bool m_Warning = false; ///< Warning flag; starts out false.
    bool m_Error   = false; ///< Error flag; starts out false.
};
249
250 using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
251
252 BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
253                                     struct BackendSettings& backendSettings);
254
255 OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
256                                               BackendsMap& backends,
257                                               TensorHandleFactoryRegistry& registry,
258                                               Optional<std::vector<std::string>&> errMessages);
259
260 } // namespace armnn