4516c0a8f96d718c7485ed947176469ff60e1468
[platform/upstream/armnn.git] / src / armnn / Network.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/TensorFwd.hpp>
11 #include <armnn/Types.hpp>
12
13 #include <armnn/INetwork.hpp>
14
15 #include <string>
16 #include <vector>
17 #include <map>
18 #include <memory>
19
20 #include "Layer.hpp"
21
22 namespace armnn
23 {
24 class Graph;
25
26 /// Private implementation of INetwork.
27 class Network final : public INetwork
28 {
29 public:
30     Network();
31     ~Network();
32
33     const Graph& GetGraph() const { return *m_Graph; }
34
35     Status PrintGraph() override;
36
37     IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
38
39     IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
40                                               const char* name = nullptr) override;
41
42     IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
43                                       const char* name = nullptr) override;
44
45     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
46                                              const ConstTensor& weights,
47                                              const Optional<ConstTensor>& biases,
48                                              const char* name = nullptr) override;
49
50     ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
51     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
52                                              const ConstTensor& weights,
53                                              const char* name = nullptr) override;
54
55     ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
56     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
57                                              const ConstTensor& weights,
58                                              const ConstTensor& biases,
59                                              const char* name = nullptr) override;
60
61     IConnectableLayer* AddDepthwiseConvolution2dLayer(
62         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
63         const ConstTensor& weights,
64         const Optional<ConstTensor>& biases,
65         const char* name = nullptr) override;
66
67     ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
68     IConnectableLayer* AddDepthwiseConvolution2dLayer(
69         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
70         const ConstTensor& weights,
71         const char* name = nullptr) override;
72
73     ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
74     IConnectableLayer* AddDepthwiseConvolution2dLayer(
75         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
76         const ConstTensor& weights,
77         const ConstTensor& biases,
78         const char* name = nullptr) override;
79
80     IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
81
82     IConnectableLayer* AddDetectionPostProcessLayer(
83         const DetectionPostProcessDescriptor& descriptor,
84         const ConstTensor& anchors,
85         const char* name = nullptr) override;
86
87     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
88                                               const ConstTensor& weights,
89                                               const Optional<ConstTensor>& biases,
90                                               const char* name = nullptr) override;
91
92     ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
93     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
94                                               const ConstTensor& weights,
95                                               const char* name = nullptr) override;
96
97     ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
98     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
99                                               const ConstTensor& weights,
100                                               const ConstTensor& biases,
101                                               const char* name = nullptr) override;
102
103     IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
104
105     IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
106                                        const char* name = nullptr) override;
107
108     IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
109         const char* name = nullptr) override;
110
111     IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
112         const char* name = nullptr) override;
113
114     IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
115         const char* name = nullptr) override;
116
117     IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
118         const char* name = nullptr) override;
119
120     IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
121         const char* name = nullptr) override;
122
123     ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
124     IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
125                                       const char* name = nullptr) override;
126
127     IConnectableLayer* AddAbsLayer(const char* name = nullptr) override;
128
129     IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
130
131     IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
132
133     IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
134                                                   const ConstTensor&                  mean,
135                                                   const ConstTensor&                  variance,
136                                                   const ConstTensor&                  beta,
137                                                   const ConstTensor&                  gamma,
138                                                   const char*                         name = nullptr) override;
139
140     ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
141     IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
142                                               const char* name = nullptr) override;
143
144     IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
145                                       const char* name = nullptr) override;
146
147     IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
148                                                const char* name = nullptr) override;
149
150     IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
151
152     IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
153                                        const char* name = nullptr) override;
154
155     IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
156                                               const char* name = nullptr) override;
157
158     IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
159                                             const char* name = nullptr) override;
160
161     IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
162
163     IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
164
165     IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
166                                     const LstmInputParams& params,
167                                     const char* name = nullptr) override;
168
169     IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
170
171     IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
172
173     IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
174
175     IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
176
177     IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
178
179     IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
180
181     IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
182                                             const char* name = nullptr) override;
183
184     IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
185
186     IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
187
188     IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
189
190     IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
191
192     IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
193
194     IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
195
196     IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
197
198     IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
199                                                       const ConstTensor& weights,
200                                                       const Optional<ConstTensor>& biases,
201                                                       const char* name = nullptr) override;
202
203     IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
204                                      const char* name = nullptr) override;
205
206     IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
207                                              const char* name = nullptr) override;
208
209     void Accept(ILayerVisitor& visitor) const override;
210
211 private:
212     IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
213                                                   const ConstTensor& weights,
214                                                   const Optional<ConstTensor>& biases,
215                                                   const char* name);
216
217     IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
218                                                  const ConstTensor& weights,
219                                                  const Optional<ConstTensor>& biases,
220                                                  const char* name);
221
222     IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
223         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
224         const ConstTensor& weights,
225         const Optional<ConstTensor>& biases,
226         const char* name);
227
228     std::unique_ptr<Graph> m_Graph;
229 };
230
231 class OptimizedNetwork final : public IOptimizedNetwork
232 {
233 public:
234     OptimizedNetwork(std::unique_ptr<Graph> graph);
235     ~OptimizedNetwork();
236
237     Status PrintGraph() override;
238     Status SerializeToDot(std::ostream& stream) const override;
239
240     Graph& GetGraph() { return *m_Graph; }
241
242 private:
243     std::unique_ptr<Graph> m_Graph;
244 };
245
246
247
/// Outcome flags of an optimization step.
///
/// m_Warning and m_Error are independent flags; a default-constructed result has both
/// cleared. The query helpers below spell out the common combinations, so callers do
/// not have to re-derive them from the raw flags.
struct OptimizationResult
{
    bool m_Warning;
    bool m_Error;

    /// Builds a result with the given flag values.
    OptimizationResult(bool warning, bool error)
        : m_Warning(warning)
        , m_Error(error)
    {}

    /// Builds a "no warning, no error" result (same state as before this change).
    OptimizationResult()
        : OptimizationResult(false, false)
    {}

    /// True when neither a warning nor an error was recorded.
    bool IsOk() const { return !m_Warning && !m_Error; }

    /// True when a warning, but no error, was recorded.
    bool IsWarningOnly() const { return m_Warning && !m_Error; }

    /// True when an error was recorded, regardless of the warning flag.
    bool IsError() const { return m_Error; }
};
258
259 using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
260
261 BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
262                                     struct BackendSettings& backendSettings);
263
264 OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
265                                               BackendsMap& backends,
266                                               TensorHandleFactoryRegistry& registry,
267                                               Optional<std::vector<std::string>&> errMessages);
268
269 } // namespace armnn