// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
7 #include <armnn/Deprecated.hpp>
8 #include <armnn/DescriptorsFwd.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/Optional.hpp>
11 #include <armnn/QuantizedLstmParams.hpp>
27 virtual ~ILayerSupport() {}
30 virtual bool IsAbsSupported(const TensorInfo& input,
31 const TensorInfo& output,
32 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
34 virtual bool IsActivationSupported(const TensorInfo& input,
35 const TensorInfo& output,
36 const ActivationDescriptor& descriptor,
37 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
39 virtual bool IsAdditionSupported(const TensorInfo& input0,
40 const TensorInfo& input1,
41 const TensorInfo& output,
42 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
44 virtual bool IsBatchNormalizationSupported(const TensorInfo& input,
45 const TensorInfo& output,
46 const TensorInfo& mean,
47 const TensorInfo& var,
48 const TensorInfo& beta,
49 const TensorInfo& gamma,
50 const BatchNormalizationDescriptor& descriptor,
51 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
53 virtual bool IsBatchToSpaceNdSupported(const TensorInfo& input,
54 const TensorInfo& output,
55 const BatchToSpaceNdDescriptor& descriptor,
56 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
58 virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
59 const TensorInfo& output,
60 const OriginsDescriptor& descriptor,
61 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
63 virtual bool IsConstantSupported(const TensorInfo& output,
64 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
66 virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
67 const TensorInfo& output,
68 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
70 virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
71 const TensorInfo& output,
72 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
74 virtual bool IsConvolution2dSupported(const TensorInfo& input,
75 const TensorInfo& output,
76 const Convolution2dDescriptor& descriptor,
77 const TensorInfo& weights,
78 const Optional<TensorInfo>& biases,
79 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
81 virtual bool IsDebugSupported(const TensorInfo& input,
82 const TensorInfo& output,
83 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
85 virtual bool IsDepthwiseConvolutionSupported(
86 const TensorInfo& input,
87 const TensorInfo& output,
88 const DepthwiseConvolution2dDescriptor& descriptor,
89 const TensorInfo& weights,
90 const Optional<TensorInfo>& biases,
91 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
93 virtual bool IsDequantizeSupported(const TensorInfo& input,
94 const TensorInfo& output,
95 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
97 virtual bool IsDetectionPostProcessSupported(
98 const TensorInfo& input0,
99 const TensorInfo& input1,
100 const DetectionPostProcessDescriptor& descriptor,
101 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
103 virtual bool IsDilatedDepthwiseConvolutionSupported(
104 const TensorInfo& input,
105 const TensorInfo& output,
106 const DepthwiseConvolution2dDescriptor& descriptor,
107 const TensorInfo& weights,
108 const Optional<TensorInfo>& biases,
109 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
111 virtual bool IsDivisionSupported(const TensorInfo& input0,
112 const TensorInfo& input1,
113 const TensorInfo& output,
114 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
116 virtual bool IsEqualSupported(const TensorInfo& input0,
117 const TensorInfo& input1,
118 const TensorInfo& output,
119 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
121 virtual bool IsFakeQuantizationSupported(const TensorInfo& input,
122 const FakeQuantizationDescriptor& descriptor,
123 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
125 virtual bool IsFloorSupported(const TensorInfo& input,
126 const TensorInfo& output,
127 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
129 virtual bool IsFullyConnectedSupported(const TensorInfo& input,
130 const TensorInfo& output,
131 const TensorInfo& weights,
132 const TensorInfo& biases,
133 const FullyConnectedDescriptor& descriptor,
134 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
136 virtual bool IsGatherSupported(const TensorInfo& input0,
137 const TensorInfo& input1,
138 const TensorInfo& output,
139 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
141 virtual bool IsGreaterSupported(const TensorInfo& input0,
142 const TensorInfo& input1,
143 const TensorInfo& ouput,
144 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
146 virtual bool IsInputSupported(const TensorInfo& input,
147 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
149 virtual bool IsL2NormalizationSupported(const TensorInfo& input,
150 const TensorInfo& output,
151 const L2NormalizationDescriptor& descriptor,
152 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
154 virtual bool IsLstmSupported(const TensorInfo& input,
155 const TensorInfo& outputStateIn,
156 const TensorInfo& cellStateIn,
157 const TensorInfo& scratchBuffer,
158 const TensorInfo& outputStateOut,
159 const TensorInfo& cellStateOut,
160 const TensorInfo& output,
161 const LstmDescriptor& descriptor,
162 const LstmInputParamsInfo& paramsInfo,
163 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
165 virtual bool IsMaximumSupported(const TensorInfo& input0,
166 const TensorInfo& input1,
167 const TensorInfo& output,
168 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
170 virtual bool IsMeanSupported(const TensorInfo& input,
171 const TensorInfo& output,
172 const MeanDescriptor& descriptor,
173 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
175 virtual bool IsMemCopySupported(const TensorInfo& input,
176 const TensorInfo& output,
177 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
179 virtual bool IsMemImportSupported(const TensorInfo& input,
180 const TensorInfo& output,
181 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
183 virtual bool IsMergeSupported(const TensorInfo& input0,
184 const TensorInfo& input1,
185 const TensorInfo& output,
186 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
188 ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead")
189 virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
190 const TensorInfo& output,
191 const OriginsDescriptor& descriptor,
192 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
194 virtual bool IsMinimumSupported(const TensorInfo& input0,
195 const TensorInfo& input1,
196 const TensorInfo& ouput,
197 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
199 virtual bool IsMultiplicationSupported(const TensorInfo& input0,
200 const TensorInfo& input1,
201 const TensorInfo& output,
202 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
204 virtual bool IsNormalizationSupported(const TensorInfo& input,
205 const TensorInfo& output,
206 const NormalizationDescriptor& descriptor,
207 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
209 virtual bool IsOutputSupported(const TensorInfo& output,
210 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
212 virtual bool IsPadSupported(const TensorInfo& input,
213 const TensorInfo& output,
214 const PadDescriptor& descriptor,
215 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
217 virtual bool IsPermuteSupported(const TensorInfo& input,
218 const TensorInfo& output,
219 const PermuteDescriptor& descriptor,
220 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
222 virtual bool IsPooling2dSupported(const TensorInfo& input,
223 const TensorInfo& output,
224 const Pooling2dDescriptor& descriptor,
225 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
227 virtual bool IsPreCompiledSupported(const TensorInfo& input,
228 const PreCompiledDescriptor& descriptor,
229 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
231 virtual bool IsPreluSupported(const TensorInfo& input,
232 const TensorInfo& alpha,
233 const TensorInfo& output,
234 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
236 virtual bool IsQuantizeSupported(const TensorInfo& input,
237 const TensorInfo& output,
238 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
240 virtual bool IsQuantizedLstmSupported(const TensorInfo& input,
241 const TensorInfo& previousCellStateIn,
242 const TensorInfo& previousOutputIn,
243 const TensorInfo& cellStateOut,
244 const TensorInfo& output,
245 const QuantizedLstmInputParamsInfo& paramsInfo,
246 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
248 virtual bool IsReshapeSupported(const TensorInfo& input,
249 const ReshapeDescriptor& descriptor,
250 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
252 ARMNN_DEPRECATED_MSG("Use IsResizeSupported instead")
253 virtual bool IsResizeBilinearSupported(const TensorInfo& input,
254 const TensorInfo& output,
255 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
257 virtual bool IsResizeSupported(const TensorInfo& input,
258 const TensorInfo& output,
259 const ResizeDescriptor& descriptor,
260 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
262 virtual bool IsRsqrtSupported(const TensorInfo& input,
263 const TensorInfo& output,
264 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
266 virtual bool IsSoftmaxSupported(const TensorInfo& input,
267 const TensorInfo& output,
268 const SoftmaxDescriptor& descriptor,
269 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
271 virtual bool IsSpaceToBatchNdSupported(const TensorInfo& input,
272 const TensorInfo& output,
273 const SpaceToBatchNdDescriptor& descriptor,
274 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
276 virtual bool IsSpaceToDepthSupported(const TensorInfo& input,
277 const TensorInfo& output,
278 const SpaceToDepthDescriptor& descriptor,
279 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
281 ARMNN_DEPRECATED_MSG("Use IsSplitterSupported with outputs instead")
282 virtual bool IsSplitterSupported(const TensorInfo& input,
283 const ViewsDescriptor& descriptor,
284 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
286 virtual bool IsSplitterSupported(const TensorInfo& input,
287 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
288 const ViewsDescriptor& descriptor,
289 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
291 virtual bool IsStackSupported(const std::vector<const TensorInfo*>& inputs,
292 const TensorInfo& output,
293 const StackDescriptor& descriptor,
294 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
296 virtual bool IsStridedSliceSupported(const TensorInfo& input,
297 const TensorInfo& output,
298 const StridedSliceDescriptor& descriptor,
299 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
301 virtual bool IsSubtractionSupported(const TensorInfo& input0,
302 const TensorInfo& input1,
303 const TensorInfo& output,
304 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
306 virtual bool IsSwitchSupported(const TensorInfo& input0,
307 const TensorInfo& input1,
308 const TensorInfo& output0,
309 const TensorInfo& output1,
310 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
312 virtual bool IsTransposeConvolution2dSupported(
313 const TensorInfo& input,
314 const TensorInfo& output,
315 const TransposeConvolution2dDescriptor& descriptor,
316 const TensorInfo& weights,
317 const Optional<TensorInfo>& biases,
318 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
320 }; // class ILayerSupport
322 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;