IVGCVSW-3722 Add front end support for ArgMinMax
[platform/upstream/armnn.git] / include / armnn / ILayerSupport.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/Deprecated.hpp>
8 #include <armnn/DescriptorsFwd.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/Optional.hpp>
11 #include <armnn/QuantizedLstmParams.hpp>
12
13 #include <cctype>
14 #include <functional>
15 #include <memory>
16 #include <vector>
17 #include "ArmNN.hpp"
18
19 namespace armnn
20 {
21
22 class TensorInfo;
23
24 class ILayerSupport
25 {
26 protected:
27     ILayerSupport() {}
28     virtual ~ILayerSupport() {}
29
30 public:
31     virtual bool IsAbsSupported(const TensorInfo& input,
32                                 const TensorInfo& output,
33                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
34
35     virtual bool IsActivationSupported(const TensorInfo& input,
36                                        const TensorInfo& output,
37                                        const ActivationDescriptor& descriptor,
38                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
39
40     virtual bool IsAdditionSupported(const TensorInfo& input0,
41                                      const TensorInfo& input1,
42                                      const TensorInfo& output,
43                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
44
45     virtual bool IsArgMinMaxSupported(const TensorInfo& input,
46                                       const TensorInfo& output,
47                                       const ArgMinMaxDescriptor& descriptor,
48                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
49
50     virtual bool IsBatchNormalizationSupported(const TensorInfo& input,
51                                                const TensorInfo& output,
52                                                const TensorInfo& mean,
53                                                const TensorInfo& var,
54                                                const TensorInfo& beta,
55                                                const TensorInfo& gamma,
56                                                const BatchNormalizationDescriptor& descriptor,
57                                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
58
59     virtual bool IsBatchToSpaceNdSupported(const TensorInfo& input,
60                                            const TensorInfo& output,
61                                            const BatchToSpaceNdDescriptor& descriptor,
62                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
63
64     virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
65                                    const TensorInfo& output,
66                                    const OriginsDescriptor& descriptor,
67                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
68
69     virtual bool IsConstantSupported(const TensorInfo& output,
70                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
71
72     virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
73                                               const TensorInfo& output,
74                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
75
76     virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
77                                               const TensorInfo& output,
78                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
79
80     virtual bool IsConvolution2dSupported(const TensorInfo& input,
81                                           const TensorInfo& output,
82                                           const Convolution2dDescriptor& descriptor,
83                                           const TensorInfo& weights,
84                                           const Optional<TensorInfo>& biases,
85                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
86
87     virtual bool IsDebugSupported(const TensorInfo& input,
88                                   const TensorInfo& output,
89                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
90
91     virtual bool IsDepthwiseConvolutionSupported(
92                      const TensorInfo& input,
93                      const TensorInfo& output,
94                      const DepthwiseConvolution2dDescriptor& descriptor,
95                      const TensorInfo& weights,
96                      const Optional<TensorInfo>& biases,
97                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
98
99     virtual bool IsDequantizeSupported(const TensorInfo& input,
100                                        const TensorInfo& output,
101                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
102
103     virtual bool IsDetectionPostProcessSupported(
104                      const TensorInfo& input0,
105                      const TensorInfo& input1,
106                      const DetectionPostProcessDescriptor& descriptor,
107                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
108
109     virtual bool IsDilatedDepthwiseConvolutionSupported(
110                     const TensorInfo& input,
111                     const TensorInfo& output,
112                     const DepthwiseConvolution2dDescriptor& descriptor,
113                     const TensorInfo& weights,
114                     const Optional<TensorInfo>& biases,
115                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
116
117     virtual bool IsDivisionSupported(const TensorInfo& input0,
118                                      const TensorInfo& input1,
119                                      const TensorInfo& output,
120                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
121
122     virtual bool IsEqualSupported(const TensorInfo& input0,
123                                   const TensorInfo& input1,
124                                   const TensorInfo& output,
125                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
126
127     virtual bool IsFakeQuantizationSupported(const TensorInfo& input,
128                                              const FakeQuantizationDescriptor& descriptor,
129                                              Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
130
131     virtual bool IsFloorSupported(const TensorInfo& input,
132                                   const TensorInfo& output,
133                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
134
135     virtual bool IsFullyConnectedSupported(const TensorInfo& input,
136                                            const TensorInfo& output,
137                                            const TensorInfo& weights,
138                                            const TensorInfo& biases,
139                                            const FullyConnectedDescriptor& descriptor,
140                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
141
142     virtual bool IsGatherSupported(const TensorInfo& input0,
143                                    const TensorInfo& input1,
144                                    const TensorInfo& output,
145                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
146
147     virtual bool IsGreaterSupported(const TensorInfo& input0,
148                                     const TensorInfo& input1,
149                                     const TensorInfo& ouput,
150                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
151
152     virtual bool IsInputSupported(const TensorInfo& input,
153                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
154
155     virtual bool IsL2NormalizationSupported(const TensorInfo& input,
156                                             const TensorInfo& output,
157                                             const L2NormalizationDescriptor& descriptor,
158                                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
159
160     virtual bool IsLstmSupported(const TensorInfo& input,
161                                  const TensorInfo& outputStateIn,
162                                  const TensorInfo& cellStateIn,
163                                  const TensorInfo& scratchBuffer,
164                                  const TensorInfo& outputStateOut,
165                                  const TensorInfo& cellStateOut,
166                                  const TensorInfo& output,
167                                  const LstmDescriptor& descriptor,
168                                  const LstmInputParamsInfo& paramsInfo,
169                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
170
171     virtual bool IsMaximumSupported(const TensorInfo& input0,
172                                     const TensorInfo& input1,
173                                     const TensorInfo& output,
174                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
175
176     virtual bool IsMeanSupported(const TensorInfo& input,
177                                  const TensorInfo& output,
178                                  const MeanDescriptor& descriptor,
179                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
180
181     virtual bool IsMemCopySupported(const TensorInfo& input,
182                                     const TensorInfo& output,
183                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
184
185     virtual bool IsMemImportSupported(const TensorInfo& input,
186                                       const TensorInfo& output,
187                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
188
189     virtual bool IsMergeSupported(const TensorInfo& input0,
190                                   const TensorInfo& input1,
191                                   const TensorInfo& output,
192                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
193
194     ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead")
195     virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
196                                    const TensorInfo& output,
197                                    const OriginsDescriptor& descriptor,
198                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
199
200     virtual bool IsMinimumSupported(const TensorInfo& input0,
201                                     const TensorInfo& input1,
202                                     const TensorInfo& ouput,
203                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
204
205     virtual bool IsMultiplicationSupported(const TensorInfo& input0,
206                                            const TensorInfo& input1,
207                                            const TensorInfo& output,
208                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
209
210     virtual bool IsNormalizationSupported(const TensorInfo& input,
211                                           const TensorInfo& output,
212                                           const NormalizationDescriptor& descriptor,
213                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
214
215     virtual bool IsOutputSupported(const TensorInfo& output,
216                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
217
218     virtual bool IsPadSupported(const TensorInfo& input,
219                                 const TensorInfo& output,
220                                 const PadDescriptor& descriptor,
221                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
222
223     virtual bool IsPermuteSupported(const TensorInfo& input,
224                                     const TensorInfo& output,
225                                     const PermuteDescriptor& descriptor,
226                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
227
228     virtual bool IsPooling2dSupported(const TensorInfo& input,
229                                       const TensorInfo& output,
230                                       const Pooling2dDescriptor& descriptor,
231                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
232
233     virtual bool IsPreCompiledSupported(const TensorInfo& input,
234                                         const PreCompiledDescriptor& descriptor,
235                                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
236
237     virtual bool IsPreluSupported(const TensorInfo& input,
238                                   const TensorInfo& alpha,
239                                   const TensorInfo& output,
240                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
241
242     virtual bool IsQuantizeSupported(const TensorInfo& input,
243                                      const TensorInfo& output,
244                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
245
246     virtual bool IsQuantizedLstmSupported(const TensorInfo& input,
247                                           const TensorInfo& previousCellStateIn,
248                                           const TensorInfo& previousOutputIn,
249                                           const TensorInfo& cellStateOut,
250                                           const TensorInfo& output,
251                                           const QuantizedLstmInputParamsInfo& paramsInfo,
252                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
253
254     virtual bool IsReshapeSupported(const TensorInfo& input,
255                                     const ReshapeDescriptor& descriptor,
256                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
257
258     ARMNN_DEPRECATED_MSG("Use IsResizeSupported instead")
259     virtual bool IsResizeBilinearSupported(const TensorInfo& input,
260                                            const TensorInfo& output,
261                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
262
263     virtual bool IsResizeSupported(const TensorInfo& input,
264                                    const TensorInfo& output,
265                                    const ResizeDescriptor& descriptor,
266                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
267
268     virtual bool IsRsqrtSupported(const TensorInfo& input,
269                                   const TensorInfo& output,
270                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
271
272     virtual bool IsSoftmaxSupported(const TensorInfo& input,
273                                     const TensorInfo& output,
274                                     const SoftmaxDescriptor& descriptor,
275                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
276
277     virtual bool IsSpaceToBatchNdSupported(const TensorInfo& input,
278                                            const TensorInfo& output,
279                                            const SpaceToBatchNdDescriptor& descriptor,
280                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
281
282     virtual bool IsSpaceToDepthSupported(const TensorInfo& input,
283                                          const TensorInfo& output,
284                                          const SpaceToDepthDescriptor& descriptor,
285                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
286
287     ARMNN_DEPRECATED_MSG("Use IsSplitterSupported with outputs instead")
288     virtual bool IsSplitterSupported(const TensorInfo& input,
289                                      const ViewsDescriptor& descriptor,
290                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
291
292     virtual bool IsSplitterSupported(const TensorInfo& input,
293                                      const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
294                                      const ViewsDescriptor& descriptor,
295                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
296
297     virtual bool IsStackSupported(const std::vector<const TensorInfo*>& inputs,
298                                   const TensorInfo& output,
299                                   const StackDescriptor& descriptor,
300                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
301
302     virtual bool IsStridedSliceSupported(const TensorInfo& input,
303                                          const TensorInfo& output,
304                                          const StridedSliceDescriptor& descriptor,
305                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
306
307     virtual bool IsSubtractionSupported(const TensorInfo& input0,
308                                         const TensorInfo& input1,
309                                         const TensorInfo& output,
310                                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
311
312     virtual bool IsSwitchSupported(const TensorInfo& input0,
313                                    const TensorInfo& input1,
314                                    const TensorInfo& output0,
315                                    const TensorInfo& output1,
316                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
317
318     virtual bool IsTransposeConvolution2dSupported(
319         const TensorInfo& input,
320         const TensorInfo& output,
321         const TransposeConvolution2dDescriptor& descriptor,
322         const TensorInfo& weights,
323         const Optional<TensorInfo>& biases,
324         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
325
326 }; // class ILayerSupport
327
328 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
329
330 } // namespace armnn