c67569bf008baa9df2a128d4f5b986485bcc5596
[platform/upstream/armnn.git] / include / armnn / ILayerSupport.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/Deprecated.hpp>
8 #include <armnn/DescriptorsFwd.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/Optional.hpp>
11 #include <armnn/QuantizedLstmParams.hpp>
12
13 #include <cctype>
14 #include <functional>
15 #include <memory>
16 #include <vector>
17
18 namespace armnn
19 {
20
// Forward declaration only: the declarations in this header name TensorInfo
// solely behind references, pointers and template arguments.
21 class TensorInfo;
22
23 class ILayerSupport
24 {
25 protected:
26     ILayerSupport() {}
27     virtual ~ILayerSupport() {}
28
29 public:
30     virtual bool IsAbsSupported(const TensorInfo& input,
31                                 const TensorInfo& output,
32                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
33
34     virtual bool IsActivationSupported(const TensorInfo& input,
35                                        const TensorInfo& output,
36                                        const ActivationDescriptor& descriptor,
37                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
38
39     virtual bool IsAdditionSupported(const TensorInfo& input0,
40                                      const TensorInfo& input1,
41                                      const TensorInfo& output,
42                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
43
44     virtual bool IsBatchNormalizationSupported(const TensorInfo& input,
45                                                const TensorInfo& output,
46                                                const TensorInfo& mean,
47                                                const TensorInfo& var,
48                                                const TensorInfo& beta,
49                                                const TensorInfo& gamma,
50                                                const BatchNormalizationDescriptor& descriptor,
51                                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
52
53     virtual bool IsBatchToSpaceNdSupported(const TensorInfo& input,
54                                            const TensorInfo& output,
55                                            const BatchToSpaceNdDescriptor& descriptor,
56                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
57
58     virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
59                                    const TensorInfo& output,
60                                    const OriginsDescriptor& descriptor,
61                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
62
63     virtual bool IsConstantSupported(const TensorInfo& output,
64                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
65
66     virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
67                                               const TensorInfo& output,
68                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
69
70     virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
71                                               const TensorInfo& output,
72                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
73
74     virtual bool IsConvolution2dSupported(const TensorInfo& input,
75                                           const TensorInfo& output,
76                                           const Convolution2dDescriptor& descriptor,
77                                           const TensorInfo& weights,
78                                           const Optional<TensorInfo>& biases,
79                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
80
81     virtual bool IsDebugSupported(const TensorInfo& input,
82                                   const TensorInfo& output,
83                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
84
85     virtual bool IsDepthwiseConvolutionSupported(
86                      const TensorInfo& input,
87                      const TensorInfo& output,
88                      const DepthwiseConvolution2dDescriptor& descriptor,
89                      const TensorInfo& weights,
90                      const Optional<TensorInfo>& biases,
91                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
92
93     virtual bool IsDequantizeSupported(const TensorInfo& input,
94                                        const TensorInfo& output,
95                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
96
97     virtual bool IsDetectionPostProcessSupported(
98                      const TensorInfo& input0,
99                      const TensorInfo& input1,
100                      const DetectionPostProcessDescriptor& descriptor,
101                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
102
103     virtual bool IsDilatedDepthwiseConvolutionSupported(
104                     const TensorInfo& input,
105                     const TensorInfo& output,
106                     const DepthwiseConvolution2dDescriptor& descriptor,
107                     const TensorInfo& weights,
108                     const Optional<TensorInfo>& biases,
109                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
110
111     virtual bool IsDivisionSupported(const TensorInfo& input0,
112                                      const TensorInfo& input1,
113                                      const TensorInfo& output,
114                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
115
116     virtual bool IsEqualSupported(const TensorInfo& input0,
117                                   const TensorInfo& input1,
118                                   const TensorInfo& output,
119                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
120
121     virtual bool IsFakeQuantizationSupported(const TensorInfo& input,
122                                              const FakeQuantizationDescriptor& descriptor,
123                                              Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
124
125     virtual bool IsFloorSupported(const TensorInfo& input,
126                                   const TensorInfo& output,
127                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
128
129     virtual bool IsFullyConnectedSupported(const TensorInfo& input,
130                                            const TensorInfo& output,
131                                            const TensorInfo& weights,
132                                            const TensorInfo& biases,
133                                            const FullyConnectedDescriptor& descriptor,
134                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
135
136     virtual bool IsGatherSupported(const TensorInfo& input0,
137                                    const TensorInfo& input1,
138                                    const TensorInfo& output,
139                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
140
141     virtual bool IsGreaterSupported(const TensorInfo& input0,
142                                     const TensorInfo& input1,
143                                     const TensorInfo& ouput,
144                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
145
146     virtual bool IsInputSupported(const TensorInfo& input,
147                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
148
149     virtual bool IsL2NormalizationSupported(const TensorInfo& input,
150                                             const TensorInfo& output,
151                                             const L2NormalizationDescriptor& descriptor,
152                                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
153
154     virtual bool IsLstmSupported(const TensorInfo& input,
155                                  const TensorInfo& outputStateIn,
156                                  const TensorInfo& cellStateIn,
157                                  const TensorInfo& scratchBuffer,
158                                  const TensorInfo& outputStateOut,
159                                  const TensorInfo& cellStateOut,
160                                  const TensorInfo& output,
161                                  const LstmDescriptor& descriptor,
162                                  const LstmInputParamsInfo& paramsInfo,
163                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
164
165     virtual bool IsMaximumSupported(const TensorInfo& input0,
166                                     const TensorInfo& input1,
167                                     const TensorInfo& output,
168                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
169
170     virtual bool IsMeanSupported(const TensorInfo& input,
171                                  const TensorInfo& output,
172                                  const MeanDescriptor& descriptor,
173                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
174
175     virtual bool IsMemCopySupported(const TensorInfo& input,
176                                     const TensorInfo& output,
177                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
178
179     virtual bool IsMemImportSupported(const TensorInfo& input,
180                                       const TensorInfo& output,
181                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
182
183     virtual bool IsMergeSupported(const TensorInfo& input0,
184                                   const TensorInfo& input1,
185                                   const TensorInfo& output,
186                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
187
188     ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead")
189     virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
190                                    const TensorInfo& output,
191                                    const OriginsDescriptor& descriptor,
192                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
193
194     virtual bool IsMinimumSupported(const TensorInfo& input0,
195                                     const TensorInfo& input1,
196                                     const TensorInfo& ouput,
197                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
198
199     virtual bool IsMultiplicationSupported(const TensorInfo& input0,
200                                            const TensorInfo& input1,
201                                            const TensorInfo& output,
202                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
203
204     virtual bool IsNormalizationSupported(const TensorInfo& input,
205                                           const TensorInfo& output,
206                                           const NormalizationDescriptor& descriptor,
207                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
208
209     virtual bool IsOutputSupported(const TensorInfo& output,
210                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
211
212     virtual bool IsPadSupported(const TensorInfo& input,
213                                 const TensorInfo& output,
214                                 const PadDescriptor& descriptor,
215                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
216
217     virtual bool IsPermuteSupported(const TensorInfo& input,
218                                     const TensorInfo& output,
219                                     const PermuteDescriptor& descriptor,
220                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
221
222     virtual bool IsPooling2dSupported(const TensorInfo& input,
223                                       const TensorInfo& output,
224                                       const Pooling2dDescriptor& descriptor,
225                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
226
227     virtual bool IsPreCompiledSupported(const TensorInfo& input,
228                                         const PreCompiledDescriptor& descriptor,
229                                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
230
231     virtual bool IsPreluSupported(const TensorInfo& input,
232                                   const TensorInfo& alpha,
233                                   const TensorInfo& output,
234                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
235
236     virtual bool IsQuantizeSupported(const TensorInfo& input,
237                                      const TensorInfo& output,
238                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
239
240     virtual bool IsQuantizedLstmSupported(const TensorInfo& input,
241                                           const TensorInfo& previousCellStateIn,
242                                           const TensorInfo& previousOutputIn,
243                                           const TensorInfo& cellStateOut,
244                                           const TensorInfo& output,
245                                           const QuantizedLstmInputParamsInfo& paramsInfo,
246                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
247
248     virtual bool IsReshapeSupported(const TensorInfo& input,
249                                     const ReshapeDescriptor& descriptor,
250                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
251
252     ARMNN_DEPRECATED_MSG("Use IsResizeSupported instead")
253     virtual bool IsResizeBilinearSupported(const TensorInfo& input,
254                                            const TensorInfo& output,
255                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
256
257     virtual bool IsResizeSupported(const TensorInfo& input,
258                                    const TensorInfo& output,
259                                    const ResizeDescriptor& descriptor,
260                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
261
262     virtual bool IsRsqrtSupported(const TensorInfo& input,
263                                   const TensorInfo& output,
264                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
265
266     virtual bool IsSoftmaxSupported(const TensorInfo& input,
267                                     const TensorInfo& output,
268                                     const SoftmaxDescriptor& descriptor,
269                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
270
271     virtual bool IsSpaceToBatchNdSupported(const TensorInfo& input,
272                                            const TensorInfo& output,
273                                            const SpaceToBatchNdDescriptor& descriptor,
274                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
275
276     virtual bool IsSpaceToDepthSupported(const TensorInfo& input,
277                                          const TensorInfo& output,
278                                          const SpaceToDepthDescriptor& descriptor,
279                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
280
281     ARMNN_DEPRECATED_MSG("Use IsSplitterSupported with outputs instead")
282     virtual bool IsSplitterSupported(const TensorInfo& input,
283                                      const ViewsDescriptor& descriptor,
284                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
285
286     virtual bool IsSplitterSupported(const TensorInfo& input,
287                                      const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
288                                      const ViewsDescriptor& descriptor,
289                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
290
291     virtual bool IsStackSupported(const std::vector<const TensorInfo*>& inputs,
292                                   const TensorInfo& output,
293                                   const StackDescriptor& descriptor,
294                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
295
296     virtual bool IsStridedSliceSupported(const TensorInfo& input,
297                                          const TensorInfo& output,
298                                          const StridedSliceDescriptor& descriptor,
299                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
300
301     virtual bool IsSubtractionSupported(const TensorInfo& input0,
302                                         const TensorInfo& input1,
303                                         const TensorInfo& output,
304                                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
305
306     virtual bool IsSwitchSupported(const TensorInfo& input0,
307                                    const TensorInfo& input1,
308                                    const TensorInfo& output0,
309                                    const TensorInfo& output1,
310                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
311
312     virtual bool IsTransposeConvolution2dSupported(
313         const TensorInfo& input,
314         const TensorInfo& output,
315         const TransposeConvolution2dDescriptor& descriptor,
316         const TensorInfo& weights,
317         const Optional<TensorInfo>& biases,
318         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
319
320 }; // class ILayerSupport
321
// Shared-ownership handle to an ILayerSupport implementation.
// ILayerSupport's destructor is protected, so a std::shared_ptr<ILayerSupport>
// cannot be built directly from an ILayerSupport*; it has to be created from a
// pointer to a derived type (or with a custom deleter).
322 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
323
324 } // namespace armnn