IVGCVSW-3173 Extend reference softmax workload to support qsymm16
[platform/upstream/armnn.git] / src / backends / reference / RefLayerSupport.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefLayerSupport.hpp"
7 #include "RefBackendId.hpp"
8
9 #include <InternalTypes.hpp>
10 #include <LayerSupportCommon.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/Descriptors.hpp>
13
14 #include <backendsCommon/BackendRegistry.hpp>
15 #include <backendsCommon/test/WorkloadTestUtils.hpp>
16
17 #include <boost/core/ignore_unused.hpp>
18
19 #include <vector>
20 #include <algorithm>
21 #include <array>
22
23 using namespace boost;
24
25 namespace armnn
26 {
27
28 namespace
29 {
30
31 template<typename Float32Func, typename Uint8Func, typename ... Params>
32 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
33                                DataType dataType,
34                                Float32Func floatFuncPtr,
35                                Uint8Func uint8FuncPtr,
36                                Params&&... params)
37 {
38     return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
39                                          dataType,
40                                          &FalseFunc<Params...>,
41                                          floatFuncPtr,
42                                          uint8FuncPtr,
43                                          &FalseFunc<Params...>,
44                                          &FalseFunc<Params...>,
45                                          std::forward<Params>(params)...);
46 }
47
48 } // anonymous namespace
49
50
51 namespace
52 {
53 template<typename F>
54 bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
55 {
56     bool supported = rule();
57     if (!supported && reason)
58     {
59         reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
60     }
61     return supported;
62 }
63
// Base for all support rules: a callable wrapper around a boolean verdict.
// Derived rules compute their verdict in the constructor and store it in m_Res.
struct Rule
{
    // Verdict of the rule; a freshly constructed Rule passes.
    bool m_Res = true;

    // Returns the stored verdict.
    bool operator()() const { return m_Res; }
};
73
// Recursion terminator: a single remaining element is trivially "all equal".
template<typename T>
bool AllTypesAreEqualImpl(T)
{
    return true;
}
79
80 template<typename T, typename... Rest>
81 bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
82 {
83     static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
84
85     return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
86 }
87
88 struct TypesAreEqual : public Rule
89 {
90     template<typename ... Ts>
91     TypesAreEqual(const Ts&... ts)
92     {
93         m_Res = AllTypesAreEqualImpl(ts...);
94     }
95 };
96
97 struct QuantizationParametersAreEqual : public Rule
98 {
99     QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
100     {
101         m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
102                 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
103     }
104 };
105
106 struct TypeAnyOf : public Rule
107 {
108     template<typename Container>
109     TypeAnyOf(const TensorInfo& info, const Container& c)
110     {
111         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
112         {
113             return dt == info.GetDataType();
114         });
115     }
116 };
117
118 struct BiasAndWeightsTypesMatch : public Rule
119 {
120     BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121     {
122         m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123     }
124 };
125
126 struct BiasAndWeightsTypesCompatible : public Rule
127 {
128     template<typename Container>
129     BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130     {
131         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132         {
133             return dt ==  GetBiasTypeFromWeightsType(info.GetDataType()).value();
134         });
135     }
136 };
137
138 struct ShapesAreSameRank : public Rule
139 {
140     ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141     {
142         m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143     }
144 };
145
146 struct ShapesAreSameTotalSize : public Rule
147 {
148     ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149     {
150         m_Res = info0.GetNumElements() == info1.GetNumElements();
151     }
152 };
153
154 struct ShapesAreBroadcastCompatible : public Rule
155 {
156     unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157     {
158         unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159         unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160         return sizeIn;
161     }
162
163     ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164     {
165         const TensorShape& shape0 = in0.GetShape();
166         const TensorShape& shape1 = in1.GetShape();
167         const TensorShape& outShape = out.GetShape();
168
169         for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170         {
171             unsigned int sizeOut = outShape[i];
172             unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173             unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174
175             m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176                      ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177         }
178     }
179 };
180 } // namespace
181
182
183 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
184                                             const TensorInfo& output,
185                                             const ActivationDescriptor& descriptor,
186                                             Optional<std::string&> reasonIfUnsupported) const
187 {
188    bool supported = true;
189
190     // Define supported types.
191     std::array<DataType,3> supportedTypes = {
192         DataType::Float32,
193         DataType::QuantisedAsymm8,
194         DataType::QuantisedSymm16
195     };
196
197     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
198                                   "Reference activation: input type not supported.");
199
200     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
201                                   "Reference activation: output type not supported.");
202
203     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
204                                   "Reference activation: input and output types mismatched.");
205
206     supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
207                                   "Reference activation: input and output shapes are of different rank.");
208
209
210     struct ActivationFunctionSupported : public Rule
211     {
212         ActivationFunctionSupported(const ActivationDescriptor& desc)
213         {
214             switch(desc.m_Function)
215             {
216                 case ActivationFunction::Abs:
217                 case ActivationFunction::BoundedReLu:
218                 case ActivationFunction::LeakyReLu:
219                 case ActivationFunction::Linear:
220                 case ActivationFunction::ReLu:
221                 case ActivationFunction::Sigmoid:
222                 case ActivationFunction::SoftReLu:
223                 case ActivationFunction::Sqrt:
224                 case ActivationFunction::Square:
225                 case ActivationFunction::TanH:
226                 {
227                     m_Res = true;
228                     break;
229                 }
230                 default:
231                 {
232                     m_Res = false;
233                     break;
234                 }
235             }
236         }
237     };
238
239     // Function is supported
240     supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
241                                   "Reference activation: function not supported.");
242
243     return supported;
244 }
245
246 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
247                                           const TensorInfo& input1,
248                                           const TensorInfo& output,
249                                           Optional<std::string&> reasonIfUnsupported) const
250 {
251     bool supported = true;
252
253     std::array<DataType,3> supportedTypes = {
254         DataType::Float32,
255         DataType::QuantisedAsymm8,
256         DataType::QuantisedSymm16
257     };
258
259     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
260                                   "Reference addition: input 0 is not a supported type.");
261
262     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
263                                   "Reference addition: input 1 is not a supported type.");
264
265     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266                                   "Reference addition: output is not a supported type.");
267
268     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
269                                   "Reference addition: input 0 and Input 1 types are mismatched");
270
271     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
272                                   "Reference addition: input and output types are mismatched");
273
274     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
275                                   "Reference addition: shapes are not suitable for implicit broadcast.");
276
277     return supported;
278 }
279
280 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
281                                                     const TensorInfo& output,
282                                                     const TensorInfo& mean,
283                                                     const TensorInfo& var,
284                                                     const TensorInfo& beta,
285                                                     const TensorInfo& gamma,
286                                                     const BatchNormalizationDescriptor& descriptor,
287                                                     Optional<std::string&> reasonIfUnsupported) const
288 {
289     ignore_unused(output);
290     ignore_unused(mean);
291     ignore_unused(var);
292     ignore_unused(beta);
293     ignore_unused(gamma);
294     ignore_unused(descriptor);
295     return IsSupportedForDataTypeRef(reasonIfUnsupported,
296                                      input.GetDataType(),
297                                      &TrueFunc<>,
298                                      &TrueFunc<>);
299 }
300
301 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
302                                                 const TensorInfo& output,
303                                                 const BatchToSpaceNdDescriptor& descriptor,
304                                                 Optional<std::string&> reasonIfUnsupported) const
305 {
306     ignore_unused(descriptor);
307     return (IsSupportedForDataTypeRef(reasonIfUnsupported,
308                                       input.GetDataType(),
309                                       &TrueFunc<>,
310                                       &TrueFunc<>) &&
311             IsSupportedForDataTypeRef(reasonIfUnsupported,
312                                       output.GetDataType(),
313                                       &TrueFunc<>,
314                                       &TrueFunc<>));
315 }
316
317 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
318                                         const TensorInfo& output,
319                                         const ConcatDescriptor& descriptor,
320                                         Optional<std::string&> reasonIfUnsupported) const
321 {
322     ignore_unused(descriptor);
323
324     bool supported = true;
325     std::array<DataType,3> supportedTypes =
326     {
327             DataType::Float32,
328             DataType::QuantisedAsymm8,
329             DataType::QuantisedSymm16
330     };
331
332     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
333                                   "Reference concatenation: output type not supported");
334     for (const TensorInfo* input : inputs)
335     {
336         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
337             "Reference concatenation: input type not supported");
338
339         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
340             "Reference concatenation: input and output types mismatched.");
341     }
342
343     return supported;
344 }
345
346 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
347                                           Optional<std::string&> reasonIfUnsupported) const
348 {
349     std::array<DataType,4> supportedTypes =
350     {
351         DataType::Float32,
352         DataType::Signed32,
353         DataType::QuantisedAsymm8,
354         DataType::QuantisedSymm16
355     };
356
357     return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
358                                   "Reference constant: output is not a supported type.");
359 }
360
361 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
362                                                    const TensorInfo& output,
363                                                    Optional<std::string&> reasonIfUnsupported) const
364 {
365     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
366                                           input.GetDataType(),
367                                           &TrueFunc<>,
368                                           &FalseInputFuncF32<>,
369                                           &FalseFuncU8<>,
370                                           &FalseFuncI32<>,
371                                           &FalseFuncU8<>) &&
372             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
373                                           output.GetDataType(),
374                                           &FalseOutputFuncF16<>,
375                                           &TrueFunc<>,
376                                           &FalseFuncU8<>,
377                                           &FalseFuncI32<>,
378                                           &FalseFuncU8<>));
379 }
380
381 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
382                                                    const TensorInfo& output,
383                                                    Optional<std::string&> reasonIfUnsupported) const
384 {
385     return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
386                                           input.GetDataType(),
387                                           &FalseInputFuncF16<>,
388                                           &TrueFunc<>,
389                                           &FalseFuncU8<>,
390                                           &FalseFuncI32<>,
391                                           &FalseFuncU8<>) &&
392             IsSupportedForDataTypeGeneric(reasonIfUnsupported,
393                                           output.GetDataType(),
394                                           &TrueFunc<>,
395                                           &FalseOutputFuncF32<>,
396                                           &FalseFuncU8<>,
397                                           &FalseFuncI32<>,
398                                           &FalseFuncU8<>));
399 }
400
401 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
402                                                const TensorInfo& output,
403                                                const Convolution2dDescriptor& descriptor,
404                                                const TensorInfo& weights,
405                                                const Optional<TensorInfo>& biases,
406                                                Optional<std::string&> reasonIfUnsupported) const
407 {
408     bool supported = true;
409
410     // Define supported types.
411     std::array<DataType,3> supportedTypes = {
412             DataType::Float32,
413             DataType::QuantisedAsymm8,
414             DataType::QuantisedSymm16
415     };
416
417     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
418                                   "Reference addition: input is not a supported type.");
419
420     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
421                                   "Reference addition: output is not a supported type.");
422
423     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
424                                   "Reference addition: weights is not a supported type.");
425
426     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
427                                   "Reference activation: input and output types mismatched.");
428
429     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
430                                   "Reference activation: input and weights types mismatched.");
431
432     if (biases.has_value())
433     {
434         std::array<DataType,3> biasesSupportedTypes = {
435                 DataType::Float32,
436                 DataType::Signed32
437         };
438         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
439                                       "Reference addition: biases is not a supported type.");
440     }
441     ignore_unused(descriptor);
442
443     return supported;
444 }
445
446 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
447                                        const TensorInfo& output,
448                                        Optional<std::string&> reasonIfUnsupported) const
449 {
450     ignore_unused(output);
451     return IsSupportedForDataTypeRef(reasonIfUnsupported,
452                                      input.GetDataType(),
453                                      &TrueFunc<>,
454                                      &TrueFunc<>);
455 }
456
457 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
458                                                       const TensorInfo& output,
459                                                       const DepthwiseConvolution2dDescriptor& descriptor,
460                                                       const TensorInfo& weights,
461                                                       const Optional<TensorInfo>& biases,
462                                                       Optional<std::string&> reasonIfUnsupported) const
463 {
464     ignore_unused(output);
465     ignore_unused(descriptor);
466     ignore_unused(weights);
467     ignore_unused(biases);
468     return IsSupportedForDataTypeRef(reasonIfUnsupported,
469                                      input.GetDataType(),
470                                      &TrueFunc<>,
471                                      &TrueFunc<>);
472 }
473
474 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
475                                             const TensorInfo& output,
476                                             Optional<std::string&> reasonIfUnsupported) const
477 {
478    bool supported = true;
479
480     std::array<DataType,2> supportedInputTypes = {
481         DataType::QuantisedAsymm8,
482         DataType::QuantisedSymm16
483     };
484
485     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
486                                   "Reference dequantize: input type not supported.");
487
488     std::array<DataType,2> supportedOutputTypes = {
489         DataType::Float32,
490     };
491
492     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
493                                   "Reference dequantize: output type not supported.");
494
495     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
496                                   "Reference dequantize: input and output shapes have different num total elements.");
497
498     return supported;
499 }
500
501 bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
502                                                       const armnn::TensorInfo& input1,
503                                                       const armnn::DetectionPostProcessDescriptor& descriptor,
504                                                       armnn::Optional<std::string&> reasonIfUnsupported) const
505 {
506     ignore_unused(input1);
507     return IsSupportedForDataTypeRef(reasonIfUnsupported,
508                                      input0.GetDataType(),
509                                      &TrueFunc<>,
510                                      &TrueFunc<>);
511 }
512
513 bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
514                                                              const TensorInfo& output,
515                                                              const DepthwiseConvolution2dDescriptor& descriptor,
516                                                              const TensorInfo& weights,
517                                                              const Optional<TensorInfo>& biases,
518                                                              Optional<std::string&> reasonIfUnsupported) const
519 {
520     if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
521     {
522         return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
523     }
524     else
525     {
526         if (reasonIfUnsupported)
527         {
528             reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
529         }
530         return false;
531     }
532 }
533
534
535     bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
536                                           const TensorInfo& input1,
537                                           const TensorInfo& output,
538                                           Optional<std::string&> reasonIfUnsupported) const
539 {
540     bool supported = true;
541
542     std::array<DataType,3> supportedTypes = {
543         DataType::Float32,
544         DataType::QuantisedAsymm8,
545         DataType::QuantisedSymm16
546     };
547
548     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
549                                   "Reference division: input 0 is not a supported type.");
550
551     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
552                                   "Reference division: input 1 is not a supported type.");
553
554     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
555                                   "Reference division: output is not a supported type.");
556
557     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
558                                   "Reference division: input 0 and Input 1 types are mismatched");
559
560     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
561                                   "Reference division: input and output types are mismatched");
562
563     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
564                                   "Reference division: shapes are not suitable for implicit broadcast.");
565
566     return supported;
567 }
568
569 bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
570                                        const TensorInfo& input1,
571                                        const TensorInfo& output,
572                                        Optional<std::string&> reasonIfUnsupported) const
573 {
574     ignore_unused(input0);
575     ignore_unused(input1);
576     ignore_unused(output);
577     ignore_unused(reasonIfUnsupported);
578     return IsSupportedForDataTypeRef(reasonIfUnsupported,
579                                      input0.GetDataType(),
580                                      &TrueFunc<>,
581                                      &TrueFunc<>);
582 }
583
584 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
585                                                   const FakeQuantizationDescriptor& descriptor,
586                                                   Optional<std::string&> reasonIfUnsupported) const
587 {
588     ignore_unused(descriptor);
589     return IsSupportedForDataTypeRef(reasonIfUnsupported,
590                                      input.GetDataType(),
591                                      &TrueFunc<>,
592                                      &FalseFuncU8<>);
593 }
594
595 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
596                                        const TensorInfo& output,
597                                        Optional<std::string&> reasonIfUnsupported) const
598 {
599     ignore_unused(output);
600     return IsSupportedForDataTypeRef(reasonIfUnsupported,
601                                      input.GetDataType(),
602                                      &TrueFunc<>,
603                                      &FalseFuncU8<>);
604 }
605
606 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
607                                                 const TensorInfo& output,
608                                                 const TensorInfo& weights,
609                                                 const TensorInfo& biases,
610                                                 const FullyConnectedDescriptor& descriptor,
611                                                 Optional<std::string&> reasonIfUnsupported) const
612 {
613     bool supported = true;
614
615     // Define supported types.
616     std::array<DataType,3> supportedTypes =
617     {
618             DataType::Float32,
619             DataType::QuantisedAsymm8,
620             DataType::QuantisedSymm16
621     };
622
623     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
624                                   "Reference Fully Connected: input type not supported.");
625
626     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
627                                   "Reference Fully Connected: output type not supported.");
628
629     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
630                                   "Reference Fully Connected: input and output types mismatched.");
631
632     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
633                                   "Reference Fully Connected: weights type not supported.");
634
635     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
636                                   "Reference Fully Connected: input and weight types mismatched.");
637
638     if (descriptor.m_BiasEnabled)
639     {
640         // Defined supported types for bias
641         std::array<DataType, 2>
642         supportedBiasTypes =
643         {
644             DataType::Float32,
645             DataType::Signed32
646         };
647
648         supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
649                                       "Reference Fully Connected: bias type not supported.");
650
651         supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
652                                       "Reference Fully Connected: bias and weight types mismatch.");
653
654         supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
655                                       "Reference Fully Connected: bias type inferred from weights is incompatible.");
656
657     }
658
659     return supported;
660 }
661
662 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
663                                         const armnn::TensorInfo& input1,
664                                         const armnn::TensorInfo& output,
665                                         armnn::Optional<std::string&> reasonIfUnsupported) const
666 {
667     ignore_unused(input1);
668     ignore_unused(output);
669     return IsSupportedForDataTypeRef(reasonIfUnsupported,
670                                      input0.GetDataType(),
671                                      &TrueFunc<>,
672                                      &TrueFunc<>);
673 }
674
675 bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
676                                          const TensorInfo& input1,
677                                          const TensorInfo& output,
678                                          Optional<std::string&> reasonIfUnsupported) const
679 {
680     ignore_unused(input0);
681     ignore_unused(input1);
682     ignore_unused(output);
683     ignore_unused(reasonIfUnsupported);
684     return IsSupportedForDataTypeRef(reasonIfUnsupported,
685                                      input0.GetDataType(),
686                                      &TrueFunc<>,
687                                      &TrueFunc<>);
688 }
689
690 bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
691                                        Optional<std::string&> reasonIfUnsupported) const
692 {
693     return IsSupportedForDataTypeRef(reasonIfUnsupported,
694                                      input.GetDataType(),
695                                      &TrueFunc<>,
696                                      &TrueFunc<>);
697 }
698
699 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
700                                                  const TensorInfo& output,
701                                                  const L2NormalizationDescriptor& descriptor,
702                                                  Optional<std::string&> reasonIfUnsupported) const
703 {
704     ignore_unused(output);
705     ignore_unused(descriptor);
706     return IsSupportedForDataTypeRef(reasonIfUnsupported,
707                                      input.GetDataType(),
708                                      &TrueFunc<>,
709                                      &FalseFuncU8<>);
710 }
711
712 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
713                                       const TensorInfo& outputStateIn,
714                                       const TensorInfo& cellStateIn,
715                                       const TensorInfo& scratchBuffer,
716                                       const TensorInfo& outputStateOut,
717                                       const TensorInfo& cellStateOut,
718                                       const TensorInfo& output,
719                                       const LstmDescriptor& descriptor,
720                                       const TensorInfo& inputToForgetWeights,
721                                       const TensorInfo& inputToCellWeights,
722                                       const TensorInfo& inputToOutputWeights,
723                                       const TensorInfo& recurrentToForgetWeights,
724                                       const TensorInfo& recurrentToCellWeights,
725                                       const TensorInfo& recurrentToOutputWeights,
726                                       const TensorInfo& forgetGateBias,
727                                       const TensorInfo& cellBias,
728                                       const TensorInfo& outputGateBias,
729                                       const TensorInfo* inputToInputWeights,
730                                       const TensorInfo* recurrentToInputWeights,
731                                       const TensorInfo* cellToInputWeights,
732                                       const TensorInfo* inputGateBias,
733                                       const TensorInfo* projectionWeights,
734                                       const TensorInfo* projectionBias,
735                                       const TensorInfo* cellToForgetWeights,
736                                       const TensorInfo* cellToOutputWeights,
737                                       Optional<std::string&> reasonIfUnsupported) const
738 {
739     ignore_unused(descriptor);
740     ignore_unused(inputToForgetWeights);
741     ignore_unused(inputToCellWeights);
742     ignore_unused(inputToOutputWeights);
743     ignore_unused(recurrentToForgetWeights);
744     ignore_unused(recurrentToCellWeights);
745     ignore_unused(recurrentToOutputWeights);
746     ignore_unused(forgetGateBias);
747     ignore_unused(cellBias);
748     ignore_unused(outputGateBias);
749     ignore_unused(inputToInputWeights);
750     ignore_unused(recurrentToInputWeights);
751     ignore_unused(cellToInputWeights);
752     ignore_unused(inputGateBias);
753     ignore_unused(projectionWeights);
754     ignore_unused(projectionBias);
755     ignore_unused(cellToForgetWeights);
756     ignore_unused(cellToOutputWeights);
757
758     bool supported = true;
759
760     std::array<DataType,2> supportedTypes = {
761         DataType::Float32,
762         DataType::QuantisedSymm16
763     };
764
765     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
766                                   "Reference Lstm: input is not a supported type.");
767
768     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
769                                   "Reference Lstm: input and outputStateIn types are mismatched");
770
771     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
772                                   "Reference Lstm: input and cellStateIn types are mismatched");
773
774     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
775                                   "Reference Lstm: input and scratchBuffer types are mismatched");
776
777     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
778                                   "Reference Lstm: input and outputStateOut types are mismatched");
779
780     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
781                                   "Reference Lstm: input and cellStateOut types are mismatched");
782
783     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
784                                   "Reference Lstm: input and output types are mismatched");
785
786     return supported;
787 }
788
789 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
790                                          const TensorInfo& input1,
791                                          const TensorInfo& output,
792                                          Optional<std::string&> reasonIfUnsupported) const
793 {
794     bool supported = true;
795
796     std::array<DataType,3> supportedTypes = {
797         DataType::Float32,
798         DataType::QuantisedAsymm8,
799         DataType::QuantisedSymm16
800     };
801
802     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
803                                   "Reference maximum: input 0 is not a supported type.");
804
805     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
806                                   "Reference maximum: input 1 is not a supported type.");
807
808     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
809                                   "Reference maximum: output is not a supported type.");
810
811     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
812                                   "Reference maximum: input 0 and Input 1 types are mismatched");
813
814     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
815                                   "Reference maximum: input and output types are mismatched");
816
817     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
818                                   "Reference maximum: shapes are not suitable for implicit broadcast.");
819
820     return supported;
821 }
822
823 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
824                                       const TensorInfo& output,
825                                       const MeanDescriptor& descriptor,
826                                       Optional<std::string&> reasonIfUnsupported) const
827 {
828     ignore_unused(output);
829     ignore_unused(descriptor);
830     return IsSupportedForDataTypeRef(reasonIfUnsupported,
831                                      input.GetDataType(),
832                                      &TrueFunc<>,
833                                      &TrueFunc<>);
834 }
835
836 bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
837                                         const TensorInfo& output,
838                                         const MergerDescriptor& descriptor,
839                                         Optional<std::string&> reasonIfUnsupported) const
840 {
841     return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
842 }
843
844 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
845                                          const TensorInfo &output,
846                                          Optional<std::string &> reasonIfUnsupported) const
847 {
848     ignore_unused(output);
849     return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
850                                          input.GetDataType(),
851                                          &TrueFunc<>,
852                                          &TrueFunc<>,
853                                          &TrueFunc<>,
854                                          &FalseFuncI32<>,
855                                          &TrueFunc<>);
856 }
857
858 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
859                                          const TensorInfo& input1,
860                                          const TensorInfo& output,
861                                          Optional<std::string&> reasonIfUnsupported) const
862 {
863     bool supported = true;
864
865     std::array<DataType,3> supportedTypes = {
866         DataType::Float32,
867         DataType::QuantisedAsymm8,
868         DataType::QuantisedSymm16
869     };
870
871     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
872                                   "Reference minimum: input 0 is not a supported type.");
873
874     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
875                                   "Reference minimum: input 1 is not a supported type.");
876
877     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
878                                   "Reference minimum: output is not a supported type.");
879
880     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
881                                   "Reference minimum: input 0 and Input 1 types are mismatched");
882
883     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
884                                   "Reference minimum: input and output types are mismatched");
885
886     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
887                                   "Reference minimum: shapes are not suitable for implicit broadcast.");
888
889     return supported;
890 }
891
892 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
893                                                 const TensorInfo& input1,
894                                                 const TensorInfo& output,
895                                                 Optional<std::string&> reasonIfUnsupported) const
896 {
897     bool supported = true;
898
899     std::array<DataType,3> supportedTypes = {
900         DataType::Float32,
901         DataType::QuantisedAsymm8,
902         DataType::QuantisedSymm16
903     };
904
905     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
906                                   "Reference multiplication: input 0 is not a supported type.");
907
908     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
909                                   "Reference multiplication: input 1 is not a supported type.");
910
911     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
912                                   "Reference multiplication: output is not a supported type.");
913
914     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
915                                   "Reference multiplication: input 0 and Input 1 types are mismatched");
916
917     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
918                                   "Reference multiplication: input and output types are mismatched");
919
920     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
921                                   "Reference multiplication: shapes are not suitable for implicit broadcast.");
922
923     return supported;
924 }
925
926 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
927                                                const TensorInfo& output,
928                                                const NormalizationDescriptor& descriptor,
929                                                Optional<std::string&> reasonIfUnsupported) const
930 {
931     ignore_unused(output);
932     ignore_unused(descriptor);
933     return IsSupportedForDataTypeRef(reasonIfUnsupported,
934                                      input.GetDataType(),
935                                      &TrueFunc<>,
936                                      &FalseFuncU8<>);
937 }
938
939 bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
940                                         Optional<std::string&> reasonIfUnsupported) const
941 {
942     return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
943                                          output.GetDataType(),
944                                          &TrueFunc<>,
945                                          &TrueFunc<>,
946                                          &TrueFunc<>,
947                                          &FalseFuncI32<>,
948                                          &TrueFunc<>);
949 }
950
951 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
952                                      const TensorInfo& output,
953                                      const PadDescriptor& descriptor,
954                                      Optional<std::string&> reasonIfUnsupported) const
955 {
956     ignore_unused(output);
957     ignore_unused(descriptor);
958     return IsSupportedForDataTypeRef(reasonIfUnsupported,
959                                      input.GetDataType(),
960                                      &TrueFunc<>,
961                                      &TrueFunc<>);
962 }
963
964 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
965                                          const TensorInfo& output,
966                                          const PermuteDescriptor& descriptor,
967                                          Optional<std::string&> reasonIfUnsupported) const
968 {
969     ignore_unused(output);
970     ignore_unused(descriptor);
971     return IsSupportedForDataTypeRef(reasonIfUnsupported,
972                                      input.GetDataType(),
973                                      &TrueFunc<>,
974                                      &TrueFunc<>);
975 }
976
977 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
978                                            const TensorInfo& output,
979                                            const Pooling2dDescriptor& descriptor,
980                                            Optional<std::string&> reasonIfUnsupported) const
981 {
982     ignore_unused(output);
983     ignore_unused(descriptor);
984     return IsSupportedForDataTypeRef(reasonIfUnsupported,
985                                      input.GetDataType(),
986                                      &TrueFunc<>,
987                                      &TrueFunc<>);
988 }
989
990 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
991                                           const TensorInfo& output,
992                                           Optional<std::string&> reasonIfUnsupported) const
993 {
994    bool supported = true;
995
996     // Define supported output types.
997     std::array<DataType,2> supportedInputTypes = {
998         DataType::Float32,
999     };
1000
1001     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1002                                   "Reference quantize: input type not supported.");
1003
1004     // Define supported output types.
1005     std::array<DataType,2> supportedOutputTypes = {
1006         DataType::QuantisedAsymm8,
1007         DataType::QuantisedSymm16
1008     };
1009     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1010                                   "Reference quantize: output type not supported.");
1011
1012     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1013                                   "Reference quantize: input and output shapes have different num total elements.");
1014
1015     return supported;
1016 }
1017
1018 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
1019                                          const ReshapeDescriptor& descriptor,
1020                                          Optional<std::string&> reasonIfUnsupported) const
1021 {
1022     ignore_unused(descriptor);
1023     // Define supported output types.
1024     std::array<DataType,4> supportedOutputTypes =
1025     {
1026         DataType::Float32,
1027         DataType::Float16,
1028         DataType::QuantisedAsymm8,
1029         DataType::QuantisedSymm16
1030     };
1031     return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1032         "Reference reshape: input type not supported.");
1033 }
1034
1035 bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
1036                                                 const TensorInfo& output,
1037                                                 Optional<std::string&> reasonIfUnsupported) const
1038 {
1039     ignore_unused(output);
1040     return IsSupportedForDataTypeRef(reasonIfUnsupported,
1041                                      input.GetDataType(),
1042                                      &TrueFunc<>,
1043                                      &TrueFunc<>);
1044 }
1045
1046 bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1047                                        const TensorInfo& output,
1048                                        Optional<std::string&> reasonIfUnsupported) const
1049 {
1050     ignore_unused(output);
1051     return IsSupportedForDataTypeRef(reasonIfUnsupported,
1052                                      input.GetDataType(),
1053                                      &TrueFunc<>,
1054                                      &FalseFuncU8<>);
1055 }
1056
1057 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1058                                          const TensorInfo& output,
1059                                          const SoftmaxDescriptor& descriptor,
1060                                          Optional<std::string&> reasonIfUnsupported) const
1061 {
1062     ignore_unused(output);
1063     bool supported = true;
1064     std::array<DataType,3> supportedTypes =
1065     {
1066             DataType::Float32,
1067             DataType::QuantisedAsymm8,
1068             DataType::QuantisedSymm16
1069     };
1070
1071     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1072                                   "Reference concatenation: output type not supported");
1073
1074     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1075                                   "Reference concatenation: input type not supported");
1076
1077     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1078                                   "Reference concatenation: input type not supported");
1079
1080     return supported;
1081 }
1082
1083 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1084                                                 const TensorInfo& output,
1085                                                 const SpaceToBatchNdDescriptor& descriptor,
1086                                                 Optional<std::string&> reasonIfUnsupported) const
1087 {
1088     ignore_unused(output);
1089     ignore_unused(descriptor);
1090     return IsSupportedForDataTypeRef(reasonIfUnsupported,
1091                                      input.GetDataType(),
1092                                      &TrueFunc<>,
1093                                      &TrueFunc<>);
1094 }
1095
1096 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1097                                           const ViewsDescriptor& descriptor,
1098                                           Optional<std::string&> reasonIfUnsupported) const
1099 {
1100     ignore_unused(descriptor);
1101     return IsSupportedForDataTypeRef(reasonIfUnsupported,
1102                                      input.GetDataType(),
1103                                      &TrueFunc<>,
1104                                      &TrueFunc<>);
1105 }
1106
1107 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1108                                           const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1109                                           const ViewsDescriptor& descriptor,
1110                                           Optional<std::string&> reasonIfUnsupported) const
1111 {
1112     ignore_unused(descriptor);
1113     ignore_unused(outputs);
1114     return IsSupportedForDataTypeRef(reasonIfUnsupported,
1115                                      input.GetDataType(),
1116                                      &TrueFunc<>,
1117                                      &TrueFunc<>);
1118 }
1119
1120 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1121                                               const TensorInfo& output,
1122                                               const StridedSliceDescriptor& descriptor,
1123                                               Optional<std::string&> reasonIfUnsupported) const
1124 {
1125     ignore_unused(output);
1126     ignore_unused(descriptor);
1127     return IsSupportedForDataTypeRef(reasonIfUnsupported,
1128                                      input.GetDataType(),
1129                                      &TrueFunc<>,
1130                                      &TrueFunc<>);
1131 }
1132
1133 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1134                                              const TensorInfo& input1,
1135                                              const TensorInfo& output,
1136                                              Optional<std::string&> reasonIfUnsupported) const
1137 {
1138     bool supported = true;
1139
1140     std::array<DataType,3> supportedTypes = {
1141         DataType::Float32,
1142         DataType::QuantisedAsymm8,
1143         DataType::QuantisedSymm16
1144     };
1145
1146     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1147                                   "Reference subtraction: input 0 is not a supported type.");
1148
1149     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1150                                   "Reference subtraction: input 1 is not a supported type.");
1151
1152     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1153                                   "Reference subtraction: output is not a supported type.");
1154
1155     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1156                                   "Reference subtraction: input 0 and Input 1 types are mismatched");
1157
1158     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1159                                   "Reference subtraction: input and output types are mismatched");
1160
1161     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1162                                   "Reference subtraction: shapes are not suitable for implicit broadcast.");
1163
1164     return supported;
1165 }
1166
1167 } // namespace armnn