IVGCVSW-3722 Add front end support for ArgMinMax
[platform/upstream/armnn.git] / src / armnn / Network.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/DescriptorsFwd.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/TensorFwd.hpp>
11 #include <armnn/Types.hpp>
12
13 #include <armnn/INetwork.hpp>
14
15 #include <string>
16 #include <vector>
17 #include <map>
18 #include <memory>
19
20 #include "Layer.hpp"
21
22 namespace armnn
23 {
24 class Graph;
25
26 /// Private implementation of INetwork.
27 class Network final : public INetwork
28 {
29 public:
30     Network();
31     ~Network();
32
33     const Graph& GetGraph() const { return *m_Graph; }
34
35     Status PrintGraph() override;
36
37     IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
38
39     IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
40                                          const char* name = nullptr) override;
41
42     IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
43                                               const char* name = nullptr) override;
44
45     IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
46                                       const char* name = nullptr) override;
47
48     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
49                                              const ConstTensor& weights,
50                                              const Optional<ConstTensor>& biases,
51                                              const char* name = nullptr) override;
52
53     ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
54     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
55                                              const ConstTensor& weights,
56                                              const char* name = nullptr) override;
57
58     ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
59     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
60                                              const ConstTensor& weights,
61                                              const ConstTensor& biases,
62                                              const char* name = nullptr) override;
63
64     IConnectableLayer* AddDepthwiseConvolution2dLayer(
65         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
66         const ConstTensor& weights,
67         const Optional<ConstTensor>& biases,
68         const char* name = nullptr) override;
69
70     ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
71     IConnectableLayer* AddDepthwiseConvolution2dLayer(
72         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
73         const ConstTensor& weights,
74         const char* name = nullptr) override;
75
76     ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
77     IConnectableLayer* AddDepthwiseConvolution2dLayer(
78         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
79         const ConstTensor& weights,
80         const ConstTensor& biases,
81         const char* name = nullptr) override;
82
83     IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
84
85     IConnectableLayer* AddDetectionPostProcessLayer(
86         const DetectionPostProcessDescriptor& descriptor,
87         const ConstTensor& anchors,
88         const char* name = nullptr) override;
89
90     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
91                                               const ConstTensor& weights,
92                                               const Optional<ConstTensor>& biases,
93                                               const char* name = nullptr) override;
94
95     ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
96     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
97                                               const ConstTensor& weights,
98                                               const char* name = nullptr) override;
99
100     ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
101     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
102                                               const ConstTensor& weights,
103                                               const ConstTensor& biases,
104                                               const char* name = nullptr) override;
105
106     IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
107
108     IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
109                                        const char* name = nullptr) override;
110
111     IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
112         const char* name = nullptr) override;
113
114     IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
115         const char* name = nullptr) override;
116
117     IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
118         const char* name = nullptr) override;
119
120     IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
121         const char* name = nullptr) override;
122
123     IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
124         const char* name = nullptr) override;
125
126     ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
127     IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
128                                       const char* name = nullptr) override;
129
130     IConnectableLayer* AddAbsLayer(const char* name = nullptr) override;
131
132     IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
133
134     IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
135
136     IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
137                                                   const ConstTensor&                  mean,
138                                                   const ConstTensor&                  variance,
139                                                   const ConstTensor&                  beta,
140                                                   const ConstTensor&                  gamma,
141                                                   const char*                         name = nullptr) override;
142
143     ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
144     IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
145                                               const char* name = nullptr) override;
146
147     IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
148                                       const char* name = nullptr) override;
149
150     IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
151                                                const char* name = nullptr) override;
152
153     IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
154
155     IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
156                                        const char* name = nullptr) override;
157
158     IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
159                                               const char* name = nullptr) override;
160
161     IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
162                                             const char* name = nullptr) override;
163
164     IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
165
166     IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
167
168     IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
169                                     const LstmInputParams& params,
170                                     const char* name = nullptr) override;
171
172     IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
173
174     IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
175
176     IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
177
178     IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
179
180     IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
181
182     IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
183
184     IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
185                                             const char* name = nullptr) override;
186
187     IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
188
189     IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
190
191     IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
192
193     IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
194
195     IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
196
197     IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
198
199     IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
200
201     IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
202                                                       const ConstTensor& weights,
203                                                       const Optional<ConstTensor>& biases,
204                                                       const char* name = nullptr) override;
205
206     IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
207                                      const char* name = nullptr) override;
208
209     IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
210                                              const char* name = nullptr) override;
211
212     void Accept(ILayerVisitor& visitor) const override;
213
214 private:
215     IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
216                                                   const ConstTensor& weights,
217                                                   const Optional<ConstTensor>& biases,
218                                                   const char* name);
219
220     IConnectableLayer* AddConvolution2dLayerImpl(const Convolution2dDescriptor& convolution2dDescriptor,
221                                                  const ConstTensor& weights,
222                                                  const Optional<ConstTensor>& biases,
223                                                  const char* name);
224
225     IConnectableLayer* AddDepthwiseConvolution2dLayerImpl(
226         const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
227         const ConstTensor& weights,
228         const Optional<ConstTensor>& biases,
229         const char* name);
230
231     std::unique_ptr<Graph> m_Graph;
232 };
233
234 class OptimizedNetwork final : public IOptimizedNetwork
235 {
236 public:
237     OptimizedNetwork(std::unique_ptr<Graph> graph);
238     ~OptimizedNetwork();
239
240     Status PrintGraph() override;
241     Status SerializeToDot(std::ostream& stream) const override;
242
243     Graph& GetGraph() { return *m_Graph; }
244
245 private:
246     std::unique_ptr<Graph> m_Graph;
247 };
248
249
250
/// Pair of status flags reported by an optimization step.
///
/// The raw flags remain public for existing callers; the Is*() helpers give
/// the common combined interpretations without repeating the boolean logic.
struct OptimizationResult
{
    bool m_Warning;
    bool m_Error;

    /// Builds a result with explicit flag values.
    OptimizationResult(bool warning, bool error)
        : m_Warning(warning)
        , m_Error(error)
    {}

    /// Default result: neither the warning nor the error flag is set.
    OptimizationResult()
        : OptimizationResult(false, false)
    {}

    /// @return true when neither flag is set.
    bool IsOk() const { return !m_Warning && !m_Error; }

    /// @return true when the warning flag is set and the error flag is not.
    bool IsWarningOnly() const { return m_Warning && !m_Error; }

    /// @return true whenever the error flag is set, regardless of m_Warning.
    bool IsError() const { return m_Error; }
};
261
262 using BackendsMap = std::map<BackendId, std::unique_ptr<class IBackendInternal>>;
263
264 BackendsMap CreateSupportedBackends(TensorHandleFactoryRegistry& handleFactoryRegistry,
265                                     struct BackendSettings& backendSettings);
266
267 OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
268                                               BackendsMap& backends,
269                                               TensorHandleFactoryRegistry& registry,
270                                               Optional<std::vector<std::string>&> errMessages);
271
272 } // namespace armnn