IVGCVSW-3592 Add Support for Quantize to HAL 1.2 Driver
[platform/upstream/armnn.git] / src / backends / reference / RefLayerSupport.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefLayerSupport.hpp"
7 #include "RefBackendId.hpp"
8
9 #include <InternalTypes.hpp>
10 #include <LayerSupportCommon.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/Descriptors.hpp>
13
14 #include <backendsCommon/BackendRegistry.hpp>
15 #include <backendsCommon/test/WorkloadTestUtils.hpp>
16
17 #include <boost/core/ignore_unused.hpp>
18
19 #include <vector>
20 #include <algorithm>
21 #include <array>
22
23 using namespace boost;
24
25 namespace armnn
26 {
27
28 namespace
29 {
30
// Dispatches a data-type based support query to the supplied Float32 and
// Uint8 handlers; every other data type slot is wired to FalseFunc and is
// therefore reported as unsupported. Returns the verdict of the handler
// selected by 'dataType'.
// NOTE(review): slot order of IsSupportedForDataTypeGeneric appears to be
// Float16, Float32, QuantisedAsymm8, Signed32, Boolean (inferred from the
// handler names used in the Convert* functions below) — confirm in
// LayerSupportCommon.hpp.
template<typename Float32Func, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
                               DataType dataType,
                               Float32Func floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
{
    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         dataType,
                                         &FalseFunc<Params...>,  // Float16
                                         floatFuncPtr,           // Float32
                                         uint8FuncPtr,           // QuantisedAsymm8
                                         &FalseFunc<Params...>,  // Signed32
                                         &FalseFunc<Params...>,  // Boolean
                                         std::forward<Params>(params)...);
}
47
48 } // anonymous namespace
49
50 namespace
51 {
52
// Builds the standard "wrong number of dimensions" diagnostic used by the
// rank checks below.
//
// @param expected   Number of dimensions the layer requires.
// @param actual     Number of dimensions the tensor actually has.
// @param layerStr   Human-readable layer name, e.g. "batchToSpaceNd".
// @param tensorName Role of the offending tensor, e.g. "input".
// @return           The formatted error message.
//
// The string parameters are taken by const reference: the function only
// reads them, and this additionally allows temporaries to be passed.
std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected,
                                              unsigned int actual,
                                              const std::string& layerStr,
                                              const std::string& tensorName)
{
    std::string errorMsg = "Reference " + layerStr + ": Expected " + std::to_string(expected) + " dimensions but got" +
                           " " + std::to_string(actual) + " dimensions instead, for the '" + tensorName + "' tensor.";

    return errorMsg;
}
63
64 } // anonymous namespace
65
66 namespace
67 {
68 template<typename F>
69 bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
70 {
71     bool supported = rule();
72     if (!supported && reason)
73     {
74         reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
75     }
76     return supported;
77 }
78
// Base class for all support-rule checks. A derived rule computes its
// verdict in its constructor and stores it in m_Res; callers obtain the
// result through operator(). The default verdict is 'supported'.
struct Rule
{
    bool operator()() const { return m_Res; }

    bool m_Res = true; // verdict filled in by the derived rule's constructor
};
88
89 template<typename T>
90 bool AllTypesAreEqualImpl(T t)
91 {
92     return true;
93 }
94
95 template<typename T, typename... Rest>
96 bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
97 {
98     static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
99
100     return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
101 }
102
103 struct TypesAreEqual : public Rule
104 {
105     template<typename ... Ts>
106     TypesAreEqual(const Ts&... ts)
107     {
108         m_Res = AllTypesAreEqualImpl(ts...);
109     }
110 };
111
112 struct QuantizationParametersAreEqual : public Rule
113 {
114     QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
115     {
116         m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
117                 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
118     }
119 };
120
121 struct TypeAnyOf : public Rule
122 {
123     template<typename Container>
124     TypeAnyOf(const TensorInfo& info, const Container& c)
125     {
126         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
127         {
128             return dt == info.GetDataType();
129         });
130     }
131 };
132
133 struct TypeIs : public Rule
134 {
135     TypeIs(const TensorInfo& info, DataType dt)
136     {
137         m_Res = dt == info.GetDataType();
138     }
139 };
140
141 struct BiasAndWeightsTypesMatch : public Rule
142 {
143     BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
144     {
145         m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
146     }
147 };
148
149 struct BiasAndWeightsTypesCompatible : public Rule
150 {
151     template<typename Container>
152     BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
153     {
154         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
155         {
156             return dt ==  GetBiasTypeFromWeightsType(info.GetDataType()).value();
157         });
158     }
159 };
160
161 struct ShapesAreSameRank : public Rule
162 {
163     ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
164     {
165         m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
166     }
167 };
168
169 struct ShapesAreSameTotalSize : public Rule
170 {
171     ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
172     {
173         m_Res = info0.GetNumElements() == info1.GetNumElements();
174     }
175 };
176
177 struct ShapesAreBroadcastCompatible : public Rule
178 {
179     unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
180     {
181         unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
182         unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
183         return sizeIn;
184     }
185
186     ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
187     {
188         const TensorShape& shape0 = in0.GetShape();
189         const TensorShape& shape1 = in1.GetShape();
190         const TensorShape& outShape = out.GetShape();
191
192         for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
193         {
194             unsigned int sizeOut = outShape[i];
195             unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
196             unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
197
198             m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
199                      ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
200         }
201     }
202 };
203
204 struct TensorNumDimensionsAreCorrect : public Rule
205 {
206     TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
207     {
208         m_Res = info.GetNumDimensions() == expectedNumDimensions;
209     }
210 };
211
212 } // namespace
213
214
215 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
216                                             const TensorInfo& output,
217                                             const ActivationDescriptor& descriptor,
218                                             Optional<std::string&> reasonIfUnsupported) const
219 {
220    bool supported = true;
221
222     // Define supported types.
223     std::array<DataType,3> supportedTypes = {
224         DataType::Float32,
225         DataType::QuantisedAsymm8,
226         DataType::QuantisedSymm16
227     };
228
229     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
230                                   "Reference activation: input type not supported.");
231
232     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
233                                   "Reference activation: output type not supported.");
234
235     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
236                                   "Reference activation: input and output types mismatched.");
237
238     supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
239                                   "Reference activation: input and output shapes are of different rank.");
240
241
242     struct ActivationFunctionSupported : public Rule
243     {
244         ActivationFunctionSupported(const ActivationDescriptor& desc)
245         {
246             switch(desc.m_Function)
247             {
248                 case ActivationFunction::Abs:
249                 case ActivationFunction::BoundedReLu:
250                 case ActivationFunction::LeakyReLu:
251                 case ActivationFunction::Linear:
252                 case ActivationFunction::ReLu:
253                 case ActivationFunction::Sigmoid:
254                 case ActivationFunction::SoftReLu:
255                 case ActivationFunction::Sqrt:
256                 case ActivationFunction::Square:
257                 case ActivationFunction::TanH:
258                 {
259                     m_Res = true;
260                     break;
261                 }
262                 default:
263                 {
264                     m_Res = false;
265                     break;
266                 }
267             }
268         }
269     };
270
271     // Function is supported
272     supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
273                                   "Reference activation: function not supported.");
274
275     return supported;
276 }
277
278 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
279                                           const TensorInfo& input1,
280                                           const TensorInfo& output,
281                                           Optional<std::string&> reasonIfUnsupported) const
282 {
283     bool supported = true;
284
285     std::array<DataType,3> supportedTypes = {
286         DataType::Float32,
287         DataType::QuantisedAsymm8,
288         DataType::QuantisedSymm16
289     };
290
291     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
292                                   "Reference addition: input 0 is not a supported type.");
293
294     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
295                                   "Reference addition: input 1 is not a supported type.");
296
297     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
298                                   "Reference addition: output is not a supported type.");
299
300     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
301                                   "Reference addition: input 0 and Input 1 types are mismatched");
302
303     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
304                                   "Reference addition: input and output types are mismatched");
305
306     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
307                                   "Reference addition: shapes are not suitable for implicit broadcast.");
308
309     return supported;
310 }
311
312 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
313                                                     const TensorInfo& output,
314                                                     const TensorInfo& mean,
315                                                     const TensorInfo& variance,
316                                                     const TensorInfo& beta,
317                                                     const TensorInfo& gamma,
318                                                     const BatchNormalizationDescriptor& descriptor,
319                                                     Optional<std::string&> reasonIfUnsupported) const
320 {
321     ignore_unused(descriptor);
322
323     std::array<DataType, 3> supportedTypes =
324     {
325         DataType::Float32,
326         DataType::QuantisedAsymm8,
327         DataType::QuantisedSymm16
328     };
329
330     bool supported = true;
331
332     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
333                                   "Reference batch normalization: input is not a supported type.");
334
335     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
336                                   "Reference batch normalization: output is not a supported type.");
337
338     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
339                                   "Reference batch normalization: input and output types are mismatched");
340
341     supported &= CheckSupportRule(TypeAnyOf(mean, supportedTypes), reasonIfUnsupported,
342                                   "Reference batch normalization: mean is not a supported type.");
343
344     supported &= CheckSupportRule(TypeAnyOf(variance, supportedTypes), reasonIfUnsupported,
345                                   "Reference batch normalization: variance is not a supported type.");
346
347     supported &= CheckSupportRule(TypeAnyOf(beta, supportedTypes), reasonIfUnsupported,
348                                   "Reference batch normalization: beta is not a supported type.");
349
350     supported &= CheckSupportRule(TypeAnyOf(gamma, supportedTypes), reasonIfUnsupported,
351                                   "Reference batch normalization: gamma is not a supported type.");
352
353     return supported;
354 }
355
356 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
357                                                 const TensorInfo& output,
358                                                 const BatchToSpaceNdDescriptor& descriptor,
359                                                 Optional<std::string&> reasonIfUnsupported) const
360 {
361     ignore_unused(descriptor);
362
363     bool supported = true;
364
365     std::string batchToSpaceNdLayerStr = "batchToSpaceNd";
366     std::string inputTensorStr = "input";
367     std::string outputTensorStr = "output";
368
369     // Define supported types.
370     std::array<DataType,3> supportedTypes =
371     {
372             DataType::Float32,
373             DataType::QuantisedAsymm8,
374             DataType::QuantisedSymm16
375     };
376
377     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
378                                   "Reference BatchToSpaceNd: input type not supported.");
379
380     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
381                                   "Reference BatchToSpaceNd: output type not supported.");
382
383     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
384                                   "Reference BatchToSpaceNd: input and output types mismatched.");
385
386     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 4),
387                                   reasonIfUnsupported,
388                                   CreateIncorrectDimensionsErrorMsg(4,
389                                                                     output.GetNumDimensions(),
390                                                                     batchToSpaceNdLayerStr,
391                                                                     outputTensorStr).data());
392
393     supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(input, 4),
394                                   reasonIfUnsupported,
395                                   CreateIncorrectDimensionsErrorMsg(4,
396                                                                     input.GetNumDimensions(),
397                                                                     batchToSpaceNdLayerStr,
398                                                                     inputTensorStr).data());
399
400     return supported;
401 }
402
403 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
404                                         const TensorInfo& output,
405                                         const ConcatDescriptor& descriptor,
406                                         Optional<std::string&> reasonIfUnsupported) const
407 {
408     ignore_unused(descriptor);
409
410     bool supported = true;
411     std::array<DataType,3> supportedTypes =
412     {
413             DataType::Float32,
414             DataType::QuantisedAsymm8,
415             DataType::QuantisedSymm16
416     };
417
418     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
419                                   "Reference concatenation: output type not supported");
420     for (const TensorInfo* input : inputs)
421     {
422         BOOST_ASSERT(input != nullptr);
423         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
424             "Reference concatenation: input type not supported");
425
426         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
427             "Reference concatenation: input and output types mismatched.");
428     }
429
430     return supported;
431 }
432
433 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
434                                           Optional<std::string&> reasonIfUnsupported) const
435 {
436     std::array<DataType,4> supportedTypes =
437     {
438         DataType::Float32,
439         DataType::Signed32,
440         DataType::QuantisedAsymm8,
441         DataType::QuantisedSymm16
442     };
443
444     return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
445                                   "Reference constant: output is not a supported type.");
446 }
447
// Supported iff the input is Float16 and the output is Float32.
// Each IsSupportedForDataTypeGeneric call takes one handler per data type;
// the handler names indicate the slot order: Float16, Float32,
// QuantisedAsymm8, Signed32, Boolean.
bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &TrueFunc<>,           // Float16 input accepted
                                          &FalseInputFuncF32<>,  // Float32 input rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &FalseOutputFuncF16<>, // Float16 output rejected with reason
                                          &TrueFunc<>,           // Float32 output accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
467
// Supported iff the input is Float32 and the output is Float16 - the mirror
// image of IsConvertFp16ToFp32Supported. Same handler slot order:
// Float16, Float32, QuantisedAsymm8, Signed32, Boolean.
bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
                                                   const TensorInfo& output,
                                                   Optional<std::string&> reasonIfUnsupported) const
{
    return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          input.GetDataType(),
                                          &FalseInputFuncF16<>,  // Float16 input rejected with reason
                                          &TrueFunc<>,           // Float32 input accepted
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>) &&
            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                          output.GetDataType(),
                                          &TrueFunc<>,           // Float16 output accepted
                                          &FalseOutputFuncF32<>, // Float32 output rejected with reason
                                          &FalseFuncU8<>,
                                          &FalseFuncI32<>,
                                          &FalseFuncU8<>));
}
487
488 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
489                                                const TensorInfo& output,
490                                                const Convolution2dDescriptor& descriptor,
491                                                const TensorInfo& weights,
492                                                const Optional<TensorInfo>& biases,
493                                                Optional<std::string&> reasonIfUnsupported) const
494 {
495     bool supported = true;
496
497     // Define supported types.
498     std::array<DataType,3> supportedTypes = {
499             DataType::Float32,
500             DataType::QuantisedAsymm8,
501             DataType::QuantisedSymm16
502     };
503
504     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
505                                   "Reference convolution2d: input is not a supported type.");
506
507     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
508                                   "Reference convolution2d: output is not a supported type.");
509
510     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
511                                   "Reference convolution2d: weights is not a supported type.");
512
513     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
514                                   "Reference convolution2d: input and output types mismatched.");
515
516     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
517                                   "Reference convolution2d: input and weights types mismatched.");
518
519     if (biases.has_value())
520     {
521         std::array<DataType,3> biasesSupportedTypes = {
522                 DataType::Float32,
523                 DataType::Signed32
524         };
525         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
526                                       "Reference convolution2d: biases is not a supported type.");
527     }
528     ignore_unused(descriptor);
529
530     return supported;
531 }
532
533 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
534                                        const TensorInfo& output,
535                                        Optional<std::string&> reasonIfUnsupported) const
536 {
537     bool supported = true;
538
539     std::array<DataType,3> supportedTypes =
540     {
541         DataType::Float32,
542         DataType::QuantisedAsymm8,
543         DataType::QuantisedSymm16
544     };
545
546     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
547                                   "Reference debug: input type not supported");
548
549     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
550                                   "Reference debug: output type not supported");
551
552     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
553                                   "Reference debug: input and output types are mismatched");
554
555     return supported;
556 }
557
558 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
559                                                       const TensorInfo& output,
560                                                       const DepthwiseConvolution2dDescriptor& descriptor,
561                                                       const TensorInfo& weights,
562                                                       const Optional<TensorInfo>& biases,
563                                                       Optional<std::string&> reasonIfUnsupported) const
564 {
565     bool supported = true;
566
567     // Define supported types.
568     std::array<DataType,3> supportedTypes =
569     {
570         DataType::Float32,
571         DataType::QuantisedAsymm8,
572         DataType::QuantisedSymm16
573     };
574
575     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
576                                   "Reference DepthwiseConvolution2d: input is not a supported type.");
577
578     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
579                                   "Reference DepthwiseConvolution2d: output is not a supported type.");
580
581     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
582                                   "Reference DepthwiseConvolution2d: weights is not a supported type.");
583
584     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
585                                   "Reference DepthwiseConvolution2d: input and output types mismatched.");
586
587     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
588                                   "Reference DepthwiseConvolution2d: input and weights types mismatched.");
589
590     if (biases.has_value())
591     {
592         std::array<DataType,2> biasesSupportedTypes =
593         {
594             DataType::Float32,
595             DataType::Signed32
596         };
597         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
598                                       "Reference DepthwiseConvolution2d: biases is not a supported type.");
599     }
600     ignore_unused(descriptor);
601
602     return supported;
603
604 }
605
606 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
607                                             const TensorInfo& output,
608                                             Optional<std::string&> reasonIfUnsupported) const
609 {
610    bool supported = true;
611
612     std::array<DataType,2> supportedInputTypes = {
613         DataType::QuantisedAsymm8,
614         DataType::QuantisedSymm16
615     };
616
617     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
618                                   "Reference dequantize: input type not supported.");
619
620     std::array<DataType,2> supportedOutputTypes = {
621         DataType::Float32,
622     };
623
624     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
625                                   "Reference dequantize: output type not supported.");
626
627     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
628                                   "Reference dequantize: input and output shapes have different num total elements.");
629
630     return supported;
631 }
632
633 bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
634                                                       const armnn::TensorInfo& input1,
635                                                       const armnn::DetectionPostProcessDescriptor& descriptor,
636                                                       armnn::Optional<std::string&> reasonIfUnsupported) const
637 {
638     bool supported = true;
639
640     std::vector<DataType> supportedInputTypes =
641     {
642         DataType::Float32,
643         DataType::QuantisedAsymm8,
644         DataType::QuantisedSymm16
645     };
646
647     supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
648                                   "Reference DetectionPostProcess: input 0 is not a supported type.");
649
650     supported &= CheckSupportRule(TypeAnyOf(input1, supportedInputTypes), reasonIfUnsupported,
651                                   "Reference DetectionPostProcess: input 1 is not a supported type.");
652
653     return supported;
654 }
655
// Dilated depthwise convolution delegates directly to the regular depthwise
// convolution check - presumably the dilation parameters are carried in the
// descriptor and impose no extra support constraints here (TODO confirm).
bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
                                                             const TensorInfo& output,
                                                             const DepthwiseConvolution2dDescriptor& descriptor,
                                                             const TensorInfo& weights,
                                                             const Optional<TensorInfo>& biases,
                                                             Optional<std::string&> reasonIfUnsupported) const
{
    return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
}
665
666 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
667                                           const TensorInfo& input1,
668                                           const TensorInfo& output,
669                                           Optional<std::string&> reasonIfUnsupported) const
670 {
671     bool supported = true;
672
673     std::array<DataType,3> supportedTypes = {
674         DataType::Float32,
675         DataType::QuantisedAsymm8,
676         DataType::QuantisedSymm16
677     };
678
679     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
680                                   "Reference division: input 0 is not a supported type.");
681
682     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
683                                   "Reference division: input 1 is not a supported type.");
684
685     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
686                                   "Reference division: output is not a supported type.");
687
688     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
689                                   "Reference division: input 0 and Input 1 types are mismatched");
690
691     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
692                                   "Reference division: input and output types are mismatched");
693
694     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
695                                   "Reference division: shapes are not suitable for implicit broadcast.");
696
697     return supported;
698 }
699
700 bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
701                                        const TensorInfo& input1,
702                                        const TensorInfo& output,
703                                        Optional<std::string&> reasonIfUnsupported) const
704 {
705     bool supported = true;
706
707     std::array<DataType,3> supportedTypes =
708     {
709         DataType::Float32,
710         DataType::QuantisedAsymm8,
711         DataType::QuantisedSymm16
712     };
713
714     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
715                                   "Reference equal: input 0 is not a supported type.");
716
717     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
718                                   "Reference equal: input 1 is not a supported type.");
719
720     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
721                                   "Reference equal: input 0 and Input 1 types are mismatched");
722
723     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
724                                   "Reference equal: shapes are not suitable for implicit broadcast.");
725
726     return supported;
727 }
728
729 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
730                                                   const FakeQuantizationDescriptor& descriptor,
731                                                   Optional<std::string&> reasonIfUnsupported) const
732 {
733     ignore_unused(descriptor);
734     bool supported = true;
735
736     std::array<DataType,1> supportedTypes =
737     {
738         DataType::Float32
739     };
740
741     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
742                                   "Reference fake quantization: input type not supported.");
743
744     return supported;
745 }
746
747 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
748                                        const TensorInfo& output,
749                                        Optional<std::string&> reasonIfUnsupported) const
750 {
751     ignore_unused(output);
752     bool supported = true;
753
754     std::array<DataType,2> supportedTypes =
755     {
756         DataType::Float32,
757         DataType::QuantisedSymm16
758     };
759
760     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
761                                   "Reference Floor: input type not supported.");
762
763     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
764                                   "Reference Floor: output type not supported.");
765
766     return supported;
767 }
768
769 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
770                                                 const TensorInfo& output,
771                                                 const TensorInfo& weights,
772                                                 const TensorInfo& biases,
773                                                 const FullyConnectedDescriptor& descriptor,
774                                                 Optional<std::string&> reasonIfUnsupported) const
775 {
776     bool supported = true;
777
778     // Define supported types.
779     std::array<DataType,3> supportedTypes =
780     {
781             DataType::Float32,
782             DataType::QuantisedAsymm8,
783             DataType::QuantisedSymm16
784     };
785
786     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
787                                   "Reference Fully Connected: input type not supported.");
788
789     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
790                                   "Reference Fully Connected: output type not supported.");
791
792     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
793                                   "Reference Fully Connected: input and output types mismatched.");
794
795     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
796                                   "Reference Fully Connected: weights type not supported.");
797
798     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
799                                   "Reference Fully Connected: input and weight types mismatched.");
800
801     if (descriptor.m_BiasEnabled)
802     {
803         // Defined supported types for bias
804         std::array<DataType, 2>
805         supportedBiasTypes =
806         {
807             DataType::Float32,
808             DataType::Signed32
809         };
810
811         supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
812                                       "Reference Fully Connected: bias type not supported.");
813
814         supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
815                                       "Reference Fully Connected: bias and weight types mismatch.");
816
817         supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
818                                       "Reference Fully Connected: bias type inferred from weights is incompatible.");
819
820     }
821
822     return supported;
823 }
824
825 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
826                                         const armnn::TensorInfo& input1,
827                                         const armnn::TensorInfo& output,
828                                         armnn::Optional<std::string&> reasonIfUnsupported) const
829 {
830     bool supported = true;
831     std::array<DataType,3> supportedTypes =
832     {
833         DataType::Float32,
834         DataType::QuantisedAsymm8,
835         DataType::QuantisedSymm16
836     };
837
838     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
839                                   "Reference Gather: input type not supported");
840
841     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
842                                   "Reference Gather: output type not supported");
843
844     supported &= CheckSupportRule(TypeIs(input1, DataType::Signed32), reasonIfUnsupported,
845                                   "Reference Gather: indices (input1) type not supported");
846
847     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
848                                   "Reference Gather: input and output types not matching");
849
850     return supported;
851 }
852
853 bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
854                                          const TensorInfo& input1,
855                                          const TensorInfo& output,
856                                          Optional<std::string&> reasonIfUnsupported) const
857 {
858     bool supported = true;
859
860     std::array<DataType,3> supportedTypes =
861     {
862         DataType::Float32,
863         DataType::QuantisedAsymm8,
864         DataType::QuantisedSymm16
865     };
866
867     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
868                                   "Reference greater: input 0 is not a supported type.");
869
870     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
871                                   "Reference greater: input 1 is not a supported type.");
872
873     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
874                                   "Reference greater: input 0 and Input 1 types are mismatched");
875
876     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
877                                   "Reference greater: shapes are not suitable for implicit broadcast.");
878
879     return supported;
880 }
881
882 bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
883                                        Optional<std::string&> reasonIfUnsupported) const
884 {
885     return true;
886 }
887
888 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
889                                                  const TensorInfo& output,
890                                                  const L2NormalizationDescriptor& descriptor,
891                                                  Optional<std::string&> reasonIfUnsupported) const
892 {
893     ignore_unused(descriptor);
894     // Define supported types
895     std::array<DataType, 3> supportedTypes =
896     {
897         DataType::Float32,
898         DataType::QuantisedAsymm8,
899         DataType::QuantisedSymm16
900     };
901
902     bool supported = true;
903
904     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
905                                   "Reference L2normalization: input type not supported.");
906
907     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
908                                   "Reference L2normalization: output type not supported.");
909
910     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
911                                   "Reference L2normalization: input and output types mismatched.");
912
913     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
914                                   "Reference L2normalization: input and output shapes have different "
915                                   "num total elements.");
916
917     return supported;
918 }
919
920 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
921                                       const TensorInfo& outputStateIn,
922                                       const TensorInfo& cellStateIn,
923                                       const TensorInfo& scratchBuffer,
924                                       const TensorInfo& outputStateOut,
925                                       const TensorInfo& cellStateOut,
926                                       const TensorInfo& output,
927                                       const LstmDescriptor& descriptor,
928                                       const LstmInputParamsInfo& paramsInfo,
929                                       Optional<std::string&> reasonIfUnsupported) const
930 {
931     ignore_unused(descriptor);
932     ignore_unused(paramsInfo);
933
934     bool supported = true;
935
936     std::array<DataType,2> supportedTypes = {
937         DataType::Float32,
938         DataType::QuantisedSymm16
939     };
940
941     // check inputs and outputs
942     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
943                                   "Reference Lstm: input is not a supported type.");
944     supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
945                                   "Reference Lstm: input and outputStateIn types are mismatched");
946     supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
947                                   "Reference Lstm: input and cellStateIn types are mismatched");
948     supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
949                                   "Reference Lstm: input and scratchBuffer types are mismatched");
950     supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
951                                   "Reference Lstm: input and outputStateOut types are mismatched");
952     supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
953                                   "Reference Lstm: input and cellStateOut types are mismatched");
954     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
955                                   "Reference Lstm: input and output types are mismatched");
956     // check layer parameters
957     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
958                                   "Reference Lstm: input and InputToForgetWeights types are mismatched");
959     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
960                                   "Reference Lstm: input and InputToCellWeights types are mismatched");
961     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
962                                   "Reference Lstm: input and InputToOutputWeights types are mismatched");
963     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
964                                   "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
965     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
966                                   "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
967     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
968                                   "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
969     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
970                                   "Reference Lstm: input and ForgetGateBias types are mismatched");
971     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
972                                   "Reference Lstm: input and CellBias types are mismatched");
973     supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
974                                   "Reference Lstm: input and OutputGateBias types are mismatched");
975     if (!descriptor.m_CifgEnabled)
976     {
977         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
978                                       "Reference Lstm: input and InputToInputWeights types are mismatched");
979         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
980                                       reasonIfUnsupported,
981                                       "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
982         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
983                                       "Reference Lstm: input and InputGateBias types are mismatched");
984         if (descriptor.m_PeepholeEnabled)
985         {
986             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
987                                           reasonIfUnsupported,
988                                           "Reference Lstm: input and CellToInputWeights types are mismatched");
989         }
990     }
991     if (descriptor.m_PeepholeEnabled)
992     {
993         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
994                                       "Reference Lstm: input and CellToForgetWeights types are mismatched");
995         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
996                                       "Reference Lstm: input and CellToOutputWeights types are mismatched");
997     }
998     if (descriptor.m_ProjectionEnabled)
999     {
1000         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
1001                                       "Reference Lstm: input and mProjectionWeights types are mismatched");
1002         if (paramsInfo.m_ProjectionBias != nullptr)
1003         {
1004             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
1005                                           "Reference Lstm: input and ProjectionBias types are mismatched");
1006         }
1007     }
1008     if (descriptor.m_LayerNormEnabled)
1009     {
1010         if (!descriptor.m_CifgEnabled)
1011         {
1012             supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
1013                                           reasonIfUnsupported,
1014                                           "Reference Lstm: input and InputLayerNormWeights types are mismatched");
1015         }
1016         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
1017                                       reasonIfUnsupported,
1018                                       "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
1019         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
1020                                       reasonIfUnsupported,
1021                                       "Reference Lstm: input and CellLayerNormWeights types are mismatched");
1022         supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
1023                                       reasonIfUnsupported,
1024                                       "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
1025     }
1026
1027     return supported;
1028 }
1029
1030 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1031                                          const TensorInfo& input1,
1032                                          const TensorInfo& output,
1033                                          Optional<std::string&> reasonIfUnsupported) const
1034 {
1035     bool supported = true;
1036
1037     std::array<DataType,3> supportedTypes = {
1038         DataType::Float32,
1039         DataType::QuantisedAsymm8,
1040         DataType::QuantisedSymm16
1041     };
1042
1043     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1044                                   "Reference maximum: input 0 is not a supported type.");
1045
1046     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1047                                   "Reference maximum: input 1 is not a supported type.");
1048
1049     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1050                                   "Reference maximum: output is not a supported type.");
1051
1052     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1053                                   "Reference maximum: input 0 and Input 1 types are mismatched");
1054
1055     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1056                                   "Reference maximum: input and output types are mismatched");
1057
1058     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1059                                   "Reference maximum: shapes are not suitable for implicit broadcast.");
1060
1061     return supported;
1062 }
1063
// Checks reference-backend support for a Mean (reduction) layer.
// Beyond the type checks, it validates that the output rank matches what
// the descriptor implies: unchanged rank when keeping dims, rank 1 when
// reducing over all axes, otherwise input rank minus the number of
// reduced axes.
bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
                                      const TensorInfo& output,
                                      const MeanDescriptor& descriptor,
                                      Optional<std::string&> reasonIfUnsupported) const
{
    bool supported = true;
    // Names fed into CreateIncorrectDimensionsErrorMsg for the reason text.
    std::string meanLayerStr = "Mean";
    std::string outputTensorStr = "output";

    std::array<DataType,3> supportedTypes =
    {
        DataType::Float32,
        DataType::QuantisedAsymm8,
        DataType::QuantisedSymm16
    };

    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
                                  "Reference Mean: input type not supported.");

    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
                                  "Reference Mean: input and output types are mismatched");

    if (descriptor.m_KeepDims)
    {
        // Reduced axes are kept with size 1, so the rank must not change.
        // NOTE: .data() on the temporary string is safe here - the temporary
        // lives until the end of the full expression containing the call.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, input.GetNumDimensions()),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(input.GetNumDimensions(),
                                                                        output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else if (descriptor.m_Axis.empty())
    {
        // Empty axis list means reduce over everything -> rank-1 output.
        supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                      reasonIfUnsupported,
                                      CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                        meanLayerStr, outputTensorStr).data());
    }
    else
    {
        // Each reduced axis drops one dimension.
        // NOTE(review): unsigned subtraction - assumes m_Axis.size() never
        // exceeds the input rank (would wrap otherwise); TODO confirm callers
        // validate this upstream.
        auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(descriptor.m_Axis.size());

        if (outputDim > 0)
        {
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, outputDim),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(outputDim, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
        else
        {
            // All dimensions reduced away: output collapses to rank 1.
            supported &= CheckSupportRule(TensorNumDimensionsAreCorrect(output, 1),
                                          reasonIfUnsupported,
                                          CreateIncorrectDimensionsErrorMsg(1, output.GetNumDimensions(),
                                                                            meanLayerStr, outputTensorStr).data());
        }
    }

    return supported;
}
1123
// Merger is the legacy name for Concat; this simply forwards to
// IsConcatSupported so both entry points share one implementation.
bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const MergerDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
{
    return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
}
1131
1132 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
1133                                          const TensorInfo &output,
1134                                          Optional<std::string &> reasonIfUnsupported) const
1135 {
1136     bool supported = true;
1137
1138     std::array<DataType,5> supportedTypes =
1139     {
1140         DataType::Float32,
1141         DataType::Float16,
1142         DataType::QuantisedAsymm8,
1143         DataType::QuantisedSymm16,
1144         DataType::Boolean
1145     };
1146
1147     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1148                                   "Reference MemCopy: input type not supported");
1149
1150     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1151                                   "Reference MemCopy: output type not supported");
1152
1153     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1154                                   "Reference MemCopy: input and output types are mismatched");
1155
1156     return supported;
1157 }
1158
1159 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1160                                          const TensorInfo& input1,
1161                                          const TensorInfo& output,
1162                                          Optional<std::string&> reasonIfUnsupported) const
1163 {
1164     bool supported = true;
1165
1166     std::array<DataType,3> supportedTypes = {
1167         DataType::Float32,
1168         DataType::QuantisedAsymm8,
1169         DataType::QuantisedSymm16
1170     };
1171
1172     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1173                                   "Reference minimum: input 0 is not a supported type.");
1174
1175     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1176                                   "Reference minimum: input 1 is not a supported type.");
1177
1178     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1179                                   "Reference minimum: output is not a supported type.");
1180
1181     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1182                                   "Reference minimum: input 0 and Input 1 types are mismatched");
1183
1184     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1185                                   "Reference minimum: input and output types are mismatched");
1186
1187     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1188                                   "Reference minimum: shapes are not suitable for implicit broadcast.");
1189
1190     return supported;
1191 }
1192
1193 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1194                                                 const TensorInfo& input1,
1195                                                 const TensorInfo& output,
1196                                                 Optional<std::string&> reasonIfUnsupported) const
1197 {
1198     bool supported = true;
1199
1200     std::array<DataType,3> supportedTypes = {
1201         DataType::Float32,
1202         DataType::QuantisedAsymm8,
1203         DataType::QuantisedSymm16
1204     };
1205
1206     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1207                                   "Reference multiplication: input 0 is not a supported type.");
1208
1209     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1210                                   "Reference multiplication: input 1 is not a supported type.");
1211
1212     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1213                                   "Reference multiplication: output is not a supported type.");
1214
1215     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1216                                   "Reference multiplication: input 0 and Input 1 types are mismatched");
1217
1218     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1219                                   "Reference multiplication: input and output types are mismatched");
1220
1221     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1222                                   "Reference multiplication: shapes are not suitable for implicit broadcast.");
1223
1224     return supported;
1225 }
1226
1227 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1228                                                const TensorInfo& output,
1229                                                const NormalizationDescriptor& descriptor,
1230                                                Optional<std::string&> reasonIfUnsupported) const
1231 {
1232     ignore_unused(descriptor);
1233
1234     // Define supported types
1235     std::array<DataType, 4> supportedTypes =
1236     {
1237         DataType::Float16,
1238         DataType::Float32,
1239         DataType::QuantisedAsymm8,
1240         DataType::QuantisedSymm16
1241     };
1242
1243     bool supported = true;
1244
1245     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1246                                   "Reference normalization: input type not supported.");
1247
1248     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1249                                   "Reference normalization: output type not supported.");
1250
1251     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1252                                   "Reference normalization: input and output shapes have different "
1253                                   "num total elements.");
1254
1255     return supported;
1256 }
1257
1258 bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
1259                                         Optional<std::string&> reasonIfUnsupported) const
1260 {
1261     return true;
1262 }
1263
1264 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
1265                                      const TensorInfo& output,
1266                                      const PadDescriptor& descriptor,
1267                                      Optional<std::string&> reasonIfUnsupported) const
1268 {
1269     ignore_unused(descriptor);
1270     bool supported = true;
1271
1272     // Define supported output and inputs types.
1273     std::array<DataType,3> supportedTypes =
1274     {
1275         DataType::Float32,
1276         DataType::QuantisedAsymm8,
1277         DataType::QuantisedSymm16
1278     };
1279
1280     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1281                                   "Reference pad: input is not a supported type.");
1282
1283     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1284                                   "Reference pad: output is not a supported type.");
1285
1286     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1287                                   "Reference pad: input and output types are mismatched.");
1288
1289     return supported;
1290 }
1291
1292 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
1293                                          const TensorInfo& output,
1294                                          const PermuteDescriptor& descriptor,
1295                                          Optional<std::string&> reasonIfUnsupported) const
1296 {
1297     ignore_unused(descriptor);
1298     bool supported = true;
1299
1300     // Define supported output and inputs types.
1301     std::array<DataType,3> supportedTypes =
1302     {
1303         DataType::Float32,
1304         DataType::QuantisedAsymm8,
1305         DataType::QuantisedSymm16
1306     };
1307
1308     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1309                                   "Reference permute: input is not a supported type.");
1310
1311     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1312                                   "Reference permute: output is not a supported type.");
1313
1314     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1315                                   "Reference permute: input and output types are mismatched.");
1316
1317     return supported;
1318 }
1319
1320 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1321                                            const TensorInfo& output,
1322                                            const Pooling2dDescriptor& descriptor,
1323                                            Optional<std::string&> reasonIfUnsupported) const
1324 {
1325     ignore_unused(descriptor);
1326     bool supported = true;
1327
1328     // Define supported output and inputs types.
1329     std::array<DataType,3> supportedTypes =
1330     {
1331         DataType::Float32,
1332         DataType::QuantisedAsymm8,
1333         DataType::QuantisedSymm16
1334     };
1335
1336     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1337                                   "Reference poolind2d: input is not a supported type.");
1338
1339     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1340                                   "Reference poolind2d: output is not a supported type.");
1341
1342     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1343                                   "Reference poolind2d: input and output types are mismatched.");
1344
1345     return supported;
1346 }
1347
1348 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1349                                           const TensorInfo& output,
1350                                           Optional<std::string&> reasonIfUnsupported) const
1351 {
1352    bool supported = true;
1353
1354     // Define supported output types.
1355     std::array<DataType,1> supportedInputTypes = {
1356         DataType::Float32,
1357     };
1358
1359     supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1360                                   "Reference quantize: input type not supported.");
1361
1362     // Define supported output types.
1363     std::array<DataType,2> supportedOutputTypes = {
1364         DataType::QuantisedAsymm8,
1365         DataType::QuantisedSymm16
1366     };
1367     supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1368                                   "Reference quantize: output type not supported.");
1369
1370     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1371                                   "Reference quantize: input and output shapes have different num total elements.");
1372
1373     return supported;
1374 }
1375
1376 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
1377                                          const ReshapeDescriptor& descriptor,
1378                                          Optional<std::string&> reasonIfUnsupported) const
1379 {
1380     ignore_unused(descriptor);
1381     // Define supported output types.
1382     std::array<DataType,4> supportedOutputTypes =
1383     {
1384         DataType::Float32,
1385         DataType::Float16,
1386         DataType::QuantisedAsymm8,
1387         DataType::QuantisedSymm16
1388     };
1389     return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1390         "Reference reshape: input type not supported.");
1391 }
1392
1393 bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
1394                                                 const TensorInfo& output,
1395                                                 Optional<std::string&> reasonIfUnsupported) const
1396 {
1397     bool supported = true;
1398     std::array<DataType,3> supportedTypes =
1399     {
1400         DataType::Float32,
1401         DataType::QuantisedAsymm8,
1402         DataType::QuantisedSymm16
1403     };
1404
1405     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1406                                   "Reference ResizeBilinear: input type not supported");
1407
1408     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1409                                   "Reference ResizeBilinear: output type not supported");
1410
1411     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1412                                   "Reference ResizeBilinear: input and output types not matching");
1413
1414     return supported;
1415 }
1416
1417 bool RefLayerSupport::IsResizeSupported(const TensorInfo& input,
1418                                         const TensorInfo& output,
1419                                         const ResizeDescriptor& descriptor,
1420                                         Optional<std::string&> reasonIfUnsupported) const
1421 {
1422     bool supported = true;
1423     std::array<DataType,3> supportedTypes =
1424     {
1425         DataType::Float32,
1426         DataType::QuantisedAsymm8,
1427         DataType::QuantisedSymm16
1428     };
1429
1430     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1431                                   "Reference Resize: input type not supported");
1432
1433     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1434                                   "Reference Resize: output type not supported");
1435
1436     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1437                                   "Reference Resize: input and output types not matching");
1438
1439     return supported;
1440 }
1441
1442 bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1443                                        const TensorInfo& output,
1444                                        Optional<std::string&> reasonIfUnsupported) const
1445 {
1446     bool supported = true;
1447     std::array<DataType,3> supportedTypes =
1448     {
1449             DataType::Float32,
1450             DataType::QuantisedAsymm8,
1451             DataType::QuantisedSymm16
1452     };
1453
1454     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1455                                   "Reference rsqrt: input type not supported");
1456
1457     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1458                                   "Reference rsqrt: output type not supported");
1459
1460     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1461                                   "Reference rsqrt: input and output types not matching");
1462
1463     supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1464                                   "Reference Rsqrt: input and output shapes have different number of total elements");
1465
1466     return supported;
1467 }
1468
1469 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1470                                          const TensorInfo& output,
1471                                          const SoftmaxDescriptor& descriptor,
1472                                          Optional<std::string&> reasonIfUnsupported) const
1473 {
1474     ignore_unused(output);
1475     bool supported = true;
1476     std::array<DataType,3> supportedTypes =
1477     {
1478             DataType::Float32,
1479             DataType::QuantisedAsymm8,
1480             DataType::QuantisedSymm16
1481     };
1482
1483     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1484                                   "Reference concatenation: output type not supported");
1485
1486     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1487                                   "Reference concatenation: input type not supported");
1488
1489     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1490                                   "Reference concatenation: input type not supported");
1491
1492     return supported;
1493 }
1494
1495 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1496                                                 const TensorInfo& output,
1497                                                 const SpaceToBatchNdDescriptor& descriptor,
1498                                                 Optional<std::string&> reasonIfUnsupported) const
1499 {
1500     ignore_unused(output);
1501     bool supported = true;
1502     std::array<DataType,3> supportedTypes =
1503     {
1504             DataType::Float32,
1505             DataType::QuantisedAsymm8,
1506             DataType::QuantisedSymm16
1507     };
1508
1509     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1510                                   "Reference SpaceToBatchNd: input type not supported");
1511
1512     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1513                                   "Reference SpaceToBatchNd: output type not supported");
1514
1515     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1516                                   "Reference SpaceToBatchNd: input and output types are mismatched");
1517
1518     return supported;
1519 }
1520
1521 bool RefLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1522                                               const TensorInfo& output,
1523                                               const SpaceToDepthDescriptor& descriptor,
1524                                               Optional<std::string&> reasonIfUnsupported) const
1525 {
1526
1527     ignore_unused(descriptor);
1528     bool supported = true;
1529
1530     std::array<DataType,3> supportedTypes =
1531     {
1532         DataType::Float32,
1533         DataType::QuantisedAsymm8,
1534         DataType::QuantisedSymm16
1535     };
1536
1537     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1538         "Reference SpaceToDepth: input type not supported");
1539
1540     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1541         "Reference SpaceToDepth: output type not supported");
1542
1543     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1544         "Reference SpaceToDepth: input and output types are mismatched");
1545
1546     return supported;
1547 }
1548
1549 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1550                                           const ViewsDescriptor& descriptor,
1551                                           Optional<std::string&> reasonIfUnsupported) const
1552 {
1553     ignore_unused(descriptor);
1554     bool supported = true;
1555     std::array<DataType,3> supportedTypes =
1556     {
1557         DataType::Float32,
1558         DataType::QuantisedAsymm8,
1559         DataType::QuantisedSymm16
1560     };
1561
1562     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1563                                   "Reference splitter: input type not supported");
1564
1565     return supported;
1566 }
1567
1568 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1569                                           const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1570                                           const ViewsDescriptor& descriptor,
1571                                           Optional<std::string&> reasonIfUnsupported) const
1572 {
1573     ignore_unused(descriptor);
1574     bool supported = true;
1575     std::array<DataType,3> supportedTypes =
1576     {
1577         DataType::Float32,
1578         DataType::QuantisedAsymm8,
1579         DataType::QuantisedSymm16
1580     };
1581
1582     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1583                                   "Reference splitter: output type not supported");
1584     for (const TensorInfo output : outputs)
1585     {
1586         supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1587                                       "Reference splitter: input type not supported");
1588
1589         supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1590                                       "Reference splitter: input and output types mismatched.");
1591     }
1592
1593     return supported;
1594 }
1595
1596 bool RefLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1597                                        const TensorInfo& output,
1598                                        const StackDescriptor& descriptor,
1599                                        Optional<std::string&> reasonIfUnsupported) const
1600 {
1601     ignore_unused(descriptor);
1602
1603     bool supported = true;
1604     std::array<DataType,3> supportedTypes =
1605     {
1606         DataType::Float32,
1607         DataType::QuantisedAsymm8,
1608         DataType::QuantisedSymm16
1609     };
1610
1611     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1612                                   "Reference stack: output type not supported");
1613     for (const TensorInfo* input : inputs)
1614     {
1615         BOOST_ASSERT(input != nullptr);
1616         supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
1617             "Reference stack: input type not supported");
1618
1619         supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
1620             "Reference stack: input and output types mismatched.");
1621     }
1622
1623     return supported;
1624 }
1625
1626 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1627                                               const TensorInfo& output,
1628                                               const StridedSliceDescriptor& descriptor,
1629                                               Optional<std::string&> reasonIfUnsupported) const
1630 {
1631     ignore_unused(descriptor);
1632     bool supported = true;
1633
1634     std::array<DataType,3> supportedTypes =
1635     {
1636         DataType::Float32,
1637         DataType::QuantisedAsymm8,
1638         DataType::QuantisedSymm16
1639     };
1640
1641     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1642                                   "Reference StridedSlice: input type not supported");
1643
1644     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1645                                   "Reference StridedSlice: output type not supported");
1646
1647     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1648                                   "Reference StridedSlice: input and output types are mismatched");
1649
1650     return supported;
1651 }
1652
1653 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1654                                              const TensorInfo& input1,
1655                                              const TensorInfo& output,
1656                                              Optional<std::string&> reasonIfUnsupported) const
1657 {
1658     bool supported = true;
1659
1660     std::array<DataType,3> supportedTypes = {
1661         DataType::Float32,
1662         DataType::QuantisedAsymm8,
1663         DataType::QuantisedSymm16
1664     };
1665
1666     supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1667                                   "Reference subtraction: input 0 is not a supported type.");
1668
1669     supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1670                                   "Reference subtraction: input 1 is not a supported type.");
1671
1672     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1673                                   "Reference subtraction: output is not a supported type.");
1674
1675     supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1676                                   "Reference subtraction: input 0 and Input 1 types are mismatched");
1677
1678     supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1679                                   "Reference subtraction: input and output types are mismatched");
1680
1681     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1682                                   "Reference subtraction: shapes are not suitable for implicit broadcast.");
1683
1684     return supported;
1685 }
1686
1687 bool RefLayerSupport::IsPreluSupported(const TensorInfo& input,
1688                                        const TensorInfo& alpha,
1689                                        const TensorInfo& output,
1690                                        Optional<std::string&> reasonIfUnsupported) const
1691 {
1692     bool supported = true;
1693
1694     std::array<DataType, 3> supportedTypes
1695     {
1696         DataType::Float32,
1697         DataType::QuantisedAsymm8,
1698         DataType::QuantisedSymm16
1699     };
1700
1701     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1702                                   "PReLU: input is not a supported type.");
1703
1704     supported &= CheckSupportRule(TypeAnyOf(alpha, supportedTypes), reasonIfUnsupported,
1705                                   "PReLU: alpha is not a supported type.");
1706
1707     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1708                                   "PReLU: output is not a supported type.");
1709
1710     supported &= CheckSupportRule(TypesAreEqual(input, alpha, output), reasonIfUnsupported,
1711                                   "PReLU: input, alpha and output types are mismatched");
1712
1713     supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input, alpha, output), reasonIfUnsupported,
1714                                   "PReLU: shapes are not suitable for implicit broadcast");
1715
1716     return supported;
1717 }
1718
1719 bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1720                                                         const TensorInfo& output,
1721                                                         const TransposeConvolution2dDescriptor& descriptor,
1722                                                         const TensorInfo& weights,
1723                                                         const Optional<TensorInfo>& biases,
1724                                                         Optional<std::string&> reasonIfUnsupported) const
1725 {
1726     ignore_unused(descriptor);
1727
1728     bool supported = true;
1729
1730     std::array<DataType,3> supportedTypes =
1731     {
1732             DataType::Float32,
1733             DataType::QuantisedAsymm8,
1734             DataType::QuantisedSymm16
1735     };
1736
1737     supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1738                                   "Reference TransposeConvolution2d: input is not a supported type.");
1739
1740     supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1741                                   "Reference TransposeConvolution2d: output is not a supported type.");
1742
1743     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
1744                                   "Reference TransposeConvolution2d: weights is not a supported type.");
1745
1746     supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1747                                   "Reference TransposeConvolution2d: input and output types mismatched.");
1748
1749     supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
1750                                   "Reference TransposeConvolution2d: input and weights types mismatched.");
1751
1752     if (biases.has_value())
1753     {
1754         std::array<DataType,3> biasesSupportedTypes = {
1755                 DataType::Float32,
1756                 DataType::Signed32
1757         };
1758         supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
1759                                       "Reference TransposeConvolution2d: biases is not a supported type.");
1760     }
1761
1762     return supported;
1763 }
1764
1765 } // namespace armnn