2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "RefLayerSupport.hpp"
7 #include "RefBackendId.hpp"
9 #include <InternalTypes.hpp>
10 #include <LayerSupportCommon.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/Descriptors.hpp>
14 #include <backendsCommon/BackendRegistry.hpp>
15 #include <backendsCommon/test/WorkloadTestUtils.hpp>
17 #include <boost/core/ignore_unused.hpp>
23 using namespace boost;
// Legacy dispatch helper: forwards to IsSupportedForDataTypeGeneric, supplying the
// given functors for the data types the reference backend handles and FalseFunc
// for the rest, so unsupported DataTypes report "not supported".
// NOTE(review): the DataType argument and the true-functor argument lines are not
// visible in this excerpt -- lines appear to have been elided; verify against the
// full file before editing.
31 template<typename Float32Func, typename Uint8Func, typename ... Params>
32 bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
34 Float32Func floatFuncPtr,
35 Uint8Func uint8FuncPtr,
38 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
40 &FalseFunc<Params...>,
43 &FalseFunc<Params...>,
44 &FalseFunc<Params...>,
45 std::forward<Params>(params)...);
48 } // anonymous namespace
// Evaluates one support rule; when it fails, appends the human-readable reason to
// the caller-provided string (one reason per line) and returns the rule result.
// NOTE(review): the guard below reads "reason" where upstream checks that
// reasonIfUnsupported has a value before dereferencing it -- the line looks
// truncated in this excerpt; confirm before relying on it.
54 bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
56 bool supported = rule();
57 if (!supported && reason)
59 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
// Base Rule call operator: reports the result computed by the derived rule's ctor.
66 bool operator()() const
// Recursion terminator: a single TensorInfo is trivially "all types equal".
75 bool AllTypesAreEqualImpl(T t)
// Variadic step: true when every adjacent pair of tensors shares one DataType.
80 template<typename T, typename... Rest>
81 bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
83 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
85 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
// Rule: all supplied tensors share the same DataType.
88 struct TypesAreEqual : public Rule
90 template<typename ... Ts>
91 TypesAreEqual(const Ts&... ts)
93 m_Res = AllTypesAreEqualImpl(ts...);
// Rule: both tensors carry identical quantization scale and offset.
97 struct QuantizationParametersAreEqual : public Rule
99 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
101 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
102 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
// Rule: the tensor's DataType is one of the types listed in container c.
106 struct TypeAnyOf : public Rule
108 template<typename Container>
109 TypeAnyOf(const TensorInfo& info, const Container& c)
111 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
113 return dt == info.GetDataType();
// Rule: the bias DataType equals the bias type implied by the weights type
// (e.g. quantized weights imply an integer bias type).
118 struct BiasAndWeightsTypesMatch : public Rule
120 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
122 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
// Rule: the bias type inferred from the weights type appears in container c
// of supported bias types.
126 struct BiasAndWeightsTypesCompatible : public Rule
128 template<typename Container>
129 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
131 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
133 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
// Rule: both tensors have the same number of dimensions.
138 struct ShapesAreSameRank : public Rule
140 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
142 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
// Rule: both tensors contain the same total number of elements (shape may differ).
146 struct ShapesAreSameTotalSize : public Rule
148 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
150 m_Res = info0.GetNumElements() == info1.GetNumElements();
// Rule: in0 and in1 can be implicitly broadcast to the output shape -- along every
// output dimension each input's size equals the output size or is 1, with inputs
// right-aligned against the output and missing leading dimensions treated as 1.
154 struct ShapesAreBroadcastCompatible : public Rule
// Size of 'in' along output axis idx after right-aligning 'in' against 'out';
// axes before the alignment offset are padded with 1.
156 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
// Assumes out has at least as many dimensions as in (otherwise offset underflows).
158 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
163 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
165 const TensorShape& shape0 = in0.GetShape();
166 const TensorShape& shape1 = in1.GetShape();
167 const TensorShape& outShape = out.GetShape();
// Loop stops early once any dimension fails (m_Res becomes false).
169 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
171 unsigned int sizeOut = outShape[i];
172 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
// Each input dimension must match the output or be broadcastable (== 1).
175 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
// Activation layer: supported when input/output are of a supported (matching)
// type, have the same rank, and the activation function is one the reference
// workload implements.
183 bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
184 const TensorInfo& output,
185 const ActivationDescriptor& descriptor,
186 Optional<std::string&> reasonIfUnsupported) const
188 bool supported = true;
190 // Define supported types.
// NOTE(review): array declares 3 entries but only 2 are visible here -- the
// remaining entry (presumably Float32) appears to be elided from this excerpt.
191 std::array<DataType,3> supportedTypes = {
193 DataType::QuantisedAsymm8,
194 DataType::QuantisedSymm16
197 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
198 "Reference activation: input type not supported.");
200 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
201 "Reference activation: output type not supported.");
203 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
204 "Reference activation: input and output types mismatched.");
206 supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
207 "Reference activation: input and output shapes are of different rank.");
// Local rule: the descriptor's activation function is in the accepted set below.
// (The switch's result-assigning statements are not visible in this excerpt.)
210 struct ActivationFunctionSupported : public Rule
212 ActivationFunctionSupported(const ActivationDescriptor& desc)
214 switch(desc.m_Function)
216 case ActivationFunction::Abs:
217 case ActivationFunction::BoundedReLu:
218 case ActivationFunction::LeakyReLu:
219 case ActivationFunction::Linear:
220 case ActivationFunction::ReLu:
221 case ActivationFunction::Sigmoid:
222 case ActivationFunction::SoftReLu:
223 case ActivationFunction::Sqrt:
224 case ActivationFunction::Square:
225 case ActivationFunction::TanH:
239 // Function is supported
240 supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
241 "Reference activation: function not supported.");
// Addition layer: inputs/output must share one supported type and the input
// shapes must be implicitly broadcastable to the output shape.
246 bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
247 const TensorInfo& input1,
248 const TensorInfo& output,
249 Optional<std::string&> reasonIfUnsupported) const
251 bool supported = true;
// NOTE(review): 3-entry array with only 2 visible entries -- one (presumably
// Float32) elided from this excerpt.
253 std::array<DataType,3> supportedTypes = {
255 DataType::QuantisedAsymm8,
256 DataType::QuantisedSymm16
259 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
260 "Reference addition: input 0 is not a supported type.");
262 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
263 "Reference addition: input 1 is not a supported type.");
265 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
266 "Reference addition: output is not a supported type.");
268 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
269 "Reference addition: input 0 and Input 1 types are mismatched");
271 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
272 "Reference addition: input and output types are mismatched");
274 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
275 "Reference addition: shapes are not suitable for implicit broadcast.");
// Batch normalization: decided purely on the input's DataType via the legacy
// per-type dispatch helper; output/statistics tensors are not inspected.
280 bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
281 const TensorInfo& output,
282 const TensorInfo& mean,
283 const TensorInfo& var,
284 const TensorInfo& beta,
285 const TensorInfo& gamma,
286 const BatchNormalizationDescriptor& descriptor,
287 Optional<std::string&> reasonIfUnsupported) const
289 ignore_unused(output);
293 ignore_unused(gamma);
294 ignore_unused(descriptor);
// (The helper's functor arguments are not visible in this excerpt.)
295 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// BatchToSpaceNd: both the input and the output DataType must pass the legacy
// per-type check (note the && of the two dispatch calls).
301 bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
302 const TensorInfo& output,
303 const BatchToSpaceNdDescriptor& descriptor,
304 Optional<std::string&> reasonIfUnsupported) const
306 ignore_unused(descriptor);
307 return (IsSupportedForDataTypeRef(reasonIfUnsupported,
311 IsSupportedForDataTypeRef(reasonIfUnsupported,
312 output.GetDataType(),
// Concat layer: output must be a supported type, and every input must be a
// supported type matching the output's type.
317 bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
318 const TensorInfo& output,
319 const ConcatDescriptor& descriptor,
320 Optional<std::string&> reasonIfUnsupported) const
322 ignore_unused(descriptor);
324 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
325 std::array<DataType,3> supportedTypes =
328 DataType::QuantisedAsymm8,
329 DataType::QuantisedSymm16
332 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
333 "Reference concatenation: output type not supported");
334 for (const TensorInfo* input : inputs)
336 supported &= CheckSupportRule(TypeAnyOf(*input, supportedTypes), reasonIfUnsupported,
337 "Reference concatenation: input type not supported");
339 supported &= CheckSupportRule(TypesAreEqual(*input, output), reasonIfUnsupported,
340 "Reference concatenation: input and output types mismatched.");
// Constant layer: only the output's DataType is checked against the supported set.
346 bool RefLayerSupport::IsConstantSupported(const TensorInfo& output,
347 Optional<std::string&> reasonIfUnsupported) const
// NOTE(review): 4-entry array, 2 entries visible -- two elided in this excerpt.
349 std::array<DataType,4> supportedTypes =
353 DataType::QuantisedAsymm8,
354 DataType::QuantisedSymm16
357 return CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
358 "Reference constant: output is not a supported type.");
// Fp16->Fp32 conversion: the generic dispatcher is used twice -- once to reject an
// Fp32 *input* and once to reject an Fp16 *output* -- so only Fp16-in/Fp32-out
// passes. (Several functor-argument lines are elided in this excerpt.)
361 bool RefLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
362 const TensorInfo& output,
363 Optional<std::string&> reasonIfUnsupported) const
365 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
368 &FalseInputFuncF32<>,
372 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
373 output.GetDataType(),
374 &FalseOutputFuncF16<>,
// Fp32->Fp16 conversion: mirror of the above -- rejects Fp16 input and Fp32 output.
381 bool RefLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
382 const TensorInfo& output,
383 Optional<std::string&> reasonIfUnsupported) const
385 return (IsSupportedForDataTypeGeneric(reasonIfUnsupported,
387 &FalseInputFuncF16<>,
392 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
393 output.GetDataType(),
395 &FalseOutputFuncF32<>,
// Convolution2d layer: input/output/weights must be of a supported, mutually
// matching type; when a bias is present it must be of a supported bias type.
// NOTE(review): the failure messages below say "Reference addition:" /
// "Reference activation:" but this is the convolution check -- looks like a
// copy-paste slip; the strings should read "Reference convolution2d:".
401 bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
402 const TensorInfo& output,
403 const Convolution2dDescriptor& descriptor,
404 const TensorInfo& weights,
405 const Optional<TensorInfo>& biases,
406 Optional<std::string&> reasonIfUnsupported) const
408 bool supported = true;
410 // Define supported types.
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
411 std::array<DataType,3> supportedTypes = {
413 DataType::QuantisedAsymm8,
414 DataType::QuantisedSymm16
417 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
418 "Reference addition: input is not a supported type.");
420 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
421 "Reference addition: output is not a supported type.");
423 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
424 "Reference addition: weights is not a supported type.");
426 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
427 "Reference activation: input and output types mismatched.");
429 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
430 "Reference activation: input and weights types mismatched.");
// Bias is optional; only validate its type when supplied.
432 if (biases.has_value())
// NOTE(review): the biasesSupportedTypes entries are elided in this excerpt.
434 std::array<DataType,3> biasesSupportedTypes = {
438 supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported,
439 "Reference addition: biases is not a supported type.");
// Descriptor fields (stride/padding/dilation) are not validated here.
441 ignore_unused(descriptor);
// Debug layer: decided on the input's DataType via the legacy dispatch helper.
446 bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
447 const TensorInfo& output,
448 Optional<std::string&> reasonIfUnsupported) const
450 ignore_unused(output);
451 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Depthwise convolution (dilation == 1 path): only the input's DataType is
// checked; output/weights/biases/descriptor are deliberately ignored here.
457 bool RefLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
458 const TensorInfo& output,
459 const DepthwiseConvolution2dDescriptor& descriptor,
460 const TensorInfo& weights,
461 const Optional<TensorInfo>& biases,
462 Optional<std::string&> reasonIfUnsupported) const
464 ignore_unused(output);
465 ignore_unused(descriptor);
466 ignore_unused(weights);
467 ignore_unused(biases);
468 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Dequantize layer: input must be a quantized type, output a float type, and
// both must have the same total element count.
474 bool RefLayerSupport::IsDequantizeSupported(const TensorInfo& input,
475 const TensorInfo& output,
476 Optional<std::string&> reasonIfUnsupported) const
478 bool supported = true;
480 std::array<DataType,2> supportedInputTypes = {
481 DataType::QuantisedAsymm8,
482 DataType::QuantisedSymm16
485 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
486 "Reference dequantize: input type not supported.");
// NOTE(review): output-type entries are elided in this excerpt.
488 std::array<DataType,2> supportedOutputTypes = {
492 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
493 "Reference dequantize: output type not supported.");
// Shapes may differ as long as the element counts match.
495 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
496 "Reference dequantize: input and output shapes have different num total elements.");
// DetectionPostProcess layer: decided on input0's DataType via the legacy
// dispatch helper; input1 is not inspected.
501 bool RefLayerSupport::IsDetectionPostProcessSupported(const armnn::TensorInfo& input0,
502 const armnn::TensorInfo& input1,
503 const armnn::DetectionPostProcessDescriptor& descriptor,
504 armnn::Optional<std::string&> reasonIfUnsupported) const
506 ignore_unused(input1);
507 return IsSupportedForDataTypeRef(reasonIfUnsupported,
508 input0.GetDataType(),
513 bool RefLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
514 const TensorInfo& output,
515 const DepthwiseConvolution2dDescriptor& descriptor,
516 const TensorInfo& weights,
517 const Optional<TensorInfo>& biases,
518 Optional<std::string&> reasonIfUnsupported) const
520 if (descriptor.m_DilationY == 1 && descriptor.m_DilationY == 1)
522 return IsDepthwiseConvolutionSupported(input, output, descriptor, weights, biases, reasonIfUnsupported);
526 if (reasonIfUnsupported)
528 reasonIfUnsupported.value() = "Reference Depthwise Convolution: Dilation parameters must be 1";
// Division layer: same rule set as addition -- supported, mutually matching
// types plus implicit-broadcast-compatible shapes.
535 bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
536 const TensorInfo& input1,
537 const TensorInfo& output,
538 Optional<std::string&> reasonIfUnsupported) const
540 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
542 std::array<DataType,3> supportedTypes = {
544 DataType::QuantisedAsymm8,
545 DataType::QuantisedSymm16
548 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
549 "Reference division: input 0 is not a supported type.");
551 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
552 "Reference division: input 1 is not a supported type.");
554 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
555 "Reference division: output is not a supported type.");
557 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
558 "Reference division: input 0 and Input 1 types are mismatched");
560 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
561 "Reference division: input and output types are mismatched");
563 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
564 "Reference division: shapes are not suitable for implicit broadcast.");
// Equal layer: decided on input0's DataType via the legacy dispatch helper.
// NOTE(review): ignore_unused(reasonIfUnsupported) is redundant -- the parameter
// IS used on the very next line; harmless but worth cleaning up.
569 bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
570 const TensorInfo& input1,
571 const TensorInfo& output,
572 Optional<std::string&> reasonIfUnsupported) const
574 ignore_unused(input0);
575 ignore_unused(input1);
576 ignore_unused(output);
577 ignore_unused(reasonIfUnsupported);
578 return IsSupportedForDataTypeRef(reasonIfUnsupported,
579 input0.GetDataType(),
// FakeQuantization layer: decided on the input's DataType; descriptor ignored.
584 bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
585 const FakeQuantizationDescriptor& descriptor,
586 Optional<std::string&> reasonIfUnsupported) const
588 ignore_unused(descriptor);
589 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Floor layer: decided on the input's DataType; output ignored.
595 bool RefLayerSupport::IsFloorSupported(const TensorInfo& input,
596 const TensorInfo& output,
597 Optional<std::string&> reasonIfUnsupported) const
599 ignore_unused(output);
600 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Fully connected layer: input/output/weights must be of a supported, mutually
// matching type; when the descriptor enables bias, the bias must be a supported
// bias type consistent with the weights type.
606 bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
607 const TensorInfo& output,
608 const TensorInfo& weights,
609 const TensorInfo& biases,
610 const FullyConnectedDescriptor& descriptor,
611 Optional<std::string&> reasonIfUnsupported) const
613 bool supported = true;
615 // Define supported types.
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
616 std::array<DataType,3> supportedTypes =
619 DataType::QuantisedAsymm8,
620 DataType::QuantisedSymm16
623 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
624 "Reference Fully Connected: input type not supported.");
626 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
627 "Reference Fully Connected: output type not supported.");
629 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
630 "Reference Fully Connected: input and output types mismatched.");
632 supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
633 "Reference Fully Connected: weights type not supported.");
635 supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
636 "Reference Fully Connected: input and weight types mismatched.");
// Bias rules only apply when the layer was configured with a bias.
638 if (descriptor.m_BiasEnabled)
640 // Defined supported types for bias
// NOTE(review): the supportedBiasTypes entries are elided in this excerpt.
641 std::array<DataType, 2>
648 supported &= CheckSupportRule(TypeAnyOf(biases, supportedBiasTypes), reasonIfUnsupported,
649 "Reference Fully Connected: bias type not supported.");
651 supported &= CheckSupportRule(BiasAndWeightsTypesMatch(biases, weights), reasonIfUnsupported,
652 "Reference Fully Connected: bias and weight types mismatch.");
654 supported &= CheckSupportRule(BiasAndWeightsTypesCompatible(weights, supportedBiasTypes), reasonIfUnsupported,
655 "Reference Fully Connected: bias type inferred from weights is incompatible.");
// Gather layer: decided on input0's DataType; indices (input1) and output ignored.
662 bool RefLayerSupport::IsGatherSupported(const armnn::TensorInfo& input0,
663 const armnn::TensorInfo& input1,
664 const armnn::TensorInfo& output,
665 armnn::Optional<std::string&> reasonIfUnsupported) const
667 ignore_unused(input1);
668 ignore_unused(output);
669 return IsSupportedForDataTypeRef(reasonIfUnsupported,
670 input0.GetDataType(),
// Greater layer: decided on input0's DataType.
// NOTE(review): ignore_unused(reasonIfUnsupported) is redundant -- it is used
// immediately below; same pattern as IsEqualSupported.
675 bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
676 const TensorInfo& input1,
677 const TensorInfo& output,
678 Optional<std::string&> reasonIfUnsupported) const
680 ignore_unused(input0);
681 ignore_unused(input1);
682 ignore_unused(output);
683 ignore_unused(reasonIfUnsupported);
684 return IsSupportedForDataTypeRef(reasonIfUnsupported,
685 input0.GetDataType(),
// Input layer: decided on the input's DataType via the legacy dispatch helper.
690 bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
691 Optional<std::string&> reasonIfUnsupported) const
693 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// L2 normalization: decided on the input's DataType; output/descriptor ignored.
699 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
700 const TensorInfo& output,
701 const L2NormalizationDescriptor& descriptor,
702 Optional<std::string&> reasonIfUnsupported) const
704 ignore_unused(output);
705 ignore_unused(descriptor);
706 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// LSTM layer: the input must be of a supported type and every state/output
// tensor must share the input's type.  Weight/bias tensors and the descriptor
// are not type-checked here (explicitly ignored below).
712 bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
713 const TensorInfo& outputStateIn,
714 const TensorInfo& cellStateIn,
715 const TensorInfo& scratchBuffer,
716 const TensorInfo& outputStateOut,
717 const TensorInfo& cellStateOut,
718 const TensorInfo& output,
719 const LstmDescriptor& descriptor,
720 const TensorInfo& inputToForgetWeights,
721 const TensorInfo& inputToCellWeights,
722 const TensorInfo& inputToOutputWeights,
723 const TensorInfo& recurrentToForgetWeights,
724 const TensorInfo& recurrentToCellWeights,
725 const TensorInfo& recurrentToOutputWeights,
726 const TensorInfo& forgetGateBias,
727 const TensorInfo& cellBias,
728 const TensorInfo& outputGateBias,
729 const TensorInfo* inputToInputWeights,
730 const TensorInfo* recurrentToInputWeights,
731 const TensorInfo* cellToInputWeights,
732 const TensorInfo* inputGateBias,
733 const TensorInfo* projectionWeights,
734 const TensorInfo* projectionBias,
735 const TensorInfo* cellToForgetWeights,
736 const TensorInfo* cellToOutputWeights,
737 Optional<std::string&> reasonIfUnsupported) const
// Weight/bias parameters are accepted but intentionally not validated.
739 ignore_unused(descriptor);
740 ignore_unused(inputToForgetWeights);
741 ignore_unused(inputToCellWeights);
742 ignore_unused(inputToOutputWeights);
743 ignore_unused(recurrentToForgetWeights);
744 ignore_unused(recurrentToCellWeights);
745 ignore_unused(recurrentToOutputWeights);
746 ignore_unused(forgetGateBias);
747 ignore_unused(cellBias);
748 ignore_unused(outputGateBias);
749 ignore_unused(inputToInputWeights);
750 ignore_unused(recurrentToInputWeights);
751 ignore_unused(cellToInputWeights);
752 ignore_unused(inputGateBias);
753 ignore_unused(projectionWeights);
754 ignore_unused(projectionBias);
755 ignore_unused(cellToForgetWeights);
756 ignore_unused(cellToOutputWeights);
758 bool supported = true;
// NOTE(review): 2-entry array, 1 entry visible -- one elided in this excerpt.
760 std::array<DataType,2> supportedTypes = {
762 DataType::QuantisedSymm16
765 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
766 "Reference Lstm: input is not a supported type.");
768 supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
769 "Reference Lstm: input and outputStateIn types are mismatched");
771 supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
772 "Reference Lstm: input and cellStateIn types are mismatched");
774 supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
775 "Reference Lstm: input and scratchBuffer types are mismatched");
777 supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
778 "Reference Lstm: input and outputStateOut types are mismatched");
780 supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
781 "Reference Lstm: input and cellStateOut types are mismatched");
783 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
784 "Reference Lstm: input and output types are mismatched");
// Maximum layer: same rule set as addition -- supported, mutually matching types
// plus implicit-broadcast-compatible shapes.
789 bool RefLayerSupport::IsMaximumSupported(const TensorInfo& input0,
790 const TensorInfo& input1,
791 const TensorInfo& output,
792 Optional<std::string&> reasonIfUnsupported) const
794 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
796 std::array<DataType,3> supportedTypes = {
798 DataType::QuantisedAsymm8,
799 DataType::QuantisedSymm16
802 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
803 "Reference maximum: input 0 is not a supported type.");
805 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
806 "Reference maximum: input 1 is not a supported type.");
808 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
809 "Reference maximum: output is not a supported type.");
811 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
812 "Reference maximum: input 0 and Input 1 types are mismatched");
814 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
815 "Reference maximum: input and output types are mismatched");
817 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
818 "Reference maximum: shapes are not suitable for implicit broadcast.");
// Mean layer: decided on the input's DataType; output/descriptor ignored.
823 bool RefLayerSupport::IsMeanSupported(const TensorInfo& input,
824 const TensorInfo& output,
825 const MeanDescriptor& descriptor,
826 Optional<std::string&> reasonIfUnsupported) const
828 ignore_unused(output);
829 ignore_unused(descriptor);
830 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Merger is the deprecated name for Concat -- delegate to the concat check.
836 bool RefLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
837 const TensorInfo& output,
838 const MergerDescriptor& descriptor,
839 Optional<std::string&> reasonIfUnsupported) const
841 return IsConcatSupported(inputs, output, descriptor, reasonIfUnsupported);
// MemCopy layer: uses the generic per-type dispatcher on the input's type.
// (The functor arguments are not visible in this excerpt.)
844 bool RefLayerSupport::IsMemCopySupported(const TensorInfo &input,
845 const TensorInfo &output,
846 Optional<std::string &> reasonIfUnsupported) const
848 ignore_unused(output);
849 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
// Minimum layer: same rule set as addition -- supported, mutually matching types
// plus implicit-broadcast-compatible shapes.
858 bool RefLayerSupport::IsMinimumSupported(const TensorInfo& input0,
859 const TensorInfo& input1,
860 const TensorInfo& output,
861 Optional<std::string&> reasonIfUnsupported) const
863 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
865 std::array<DataType,3> supportedTypes = {
867 DataType::QuantisedAsymm8,
868 DataType::QuantisedSymm16
871 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
872 "Reference minimum: input 0 is not a supported type.");
874 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
875 "Reference minimum: input 1 is not a supported type.");
877 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
878 "Reference minimum: output is not a supported type.");
880 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
881 "Reference minimum: input 0 and Input 1 types are mismatched");
883 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
884 "Reference minimum: input and output types are mismatched");
886 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
887 "Reference minimum: shapes are not suitable for implicit broadcast.");
// Multiplication layer: same rule set as addition -- supported, mutually matching
// types plus implicit-broadcast-compatible shapes.
892 bool RefLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
893 const TensorInfo& input1,
894 const TensorInfo& output,
895 Optional<std::string&> reasonIfUnsupported) const
897 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
899 std::array<DataType,3> supportedTypes = {
901 DataType::QuantisedAsymm8,
902 DataType::QuantisedSymm16
905 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
906 "Reference multiplication: input 0 is not a supported type.");
908 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
909 "Reference multiplication: input 1 is not a supported type.");
911 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
912 "Reference multiplication: output is not a supported type.");
914 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
915 "Reference multiplication: input 0 and Input 1 types are mismatched");
917 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
918 "Reference multiplication: input and output types are mismatched");
920 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
921 "Reference multiplication: shapes are not suitable for implicit broadcast.");
// Normalization layer: decided on the input's DataType; output/descriptor ignored.
926 bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
927 const TensorInfo& output,
928 const NormalizationDescriptor& descriptor,
929 Optional<std::string&> reasonIfUnsupported) const
931 ignore_unused(output);
932 ignore_unused(descriptor);
933 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Output layer: uses the generic dispatcher on the output's DataType.
939 bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
940 Optional<std::string&> reasonIfUnsupported) const
942 return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
943 output.GetDataType(),
// Pad layer: decided on the input's DataType; output/descriptor ignored.
951 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
952 const TensorInfo& output,
953 const PadDescriptor& descriptor,
954 Optional<std::string&> reasonIfUnsupported) const
956 ignore_unused(output);
957 ignore_unused(descriptor);
958 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Permute layer: decided on the input's DataType; output/descriptor ignored.
964 bool RefLayerSupport::IsPermuteSupported(const TensorInfo& input,
965 const TensorInfo& output,
966 const PermuteDescriptor& descriptor,
967 Optional<std::string&> reasonIfUnsupported) const
969 ignore_unused(output);
970 ignore_unused(descriptor);
971 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Pooling2d layer: decided on the input's DataType; output/descriptor ignored.
977 bool RefLayerSupport::IsPooling2dSupported(const TensorInfo& input,
978 const TensorInfo& output,
979 const Pooling2dDescriptor& descriptor,
980 Optional<std::string&> reasonIfUnsupported) const
982 ignore_unused(output);
983 ignore_unused(descriptor);
984 return IsSupportedForDataTypeRef(reasonIfUnsupported,
// Quantize layer: input must be a supported (float) type, output a quantized
// type, and both must have the same total element count.
990 bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
991 const TensorInfo& output,
992 Optional<std::string&> reasonIfUnsupported) const
994 bool supported = true;
// NOTE(review): comment says "output types" but this array is for the INPUT;
// its entries are elided in this excerpt.
996 // Define supported output types.
997 std::array<DataType,2> supportedInputTypes = {
1001 supported &= CheckSupportRule(TypeAnyOf(input, supportedInputTypes), reasonIfUnsupported,
1002 "Reference quantize: input type not supported.");
1004 // Define supported output types.
1005 std::array<DataType,2> supportedOutputTypes = {
1006 DataType::QuantisedAsymm8,
1007 DataType::QuantisedSymm16
1009 supported &= CheckSupportRule(TypeAnyOf(output, supportedOutputTypes), reasonIfUnsupported,
1010 "Reference quantize: output type not supported.");
// Shapes may differ as long as the element counts match.
1012 supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
1013 "Reference quantize: input and output shapes have different num total elements.");
// Reshape layer: only the input's DataType is checked; any shape is accepted.
1018 bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
1019 const ReshapeDescriptor& descriptor,
1020 Optional<std::string&> reasonIfUnsupported) const
1022 ignore_unused(descriptor);
1023 // Define supported output types.
// NOTE(review): 4-entry array, 2 entries visible -- two elided in this excerpt.
1024 std::array<DataType,4> supportedOutputTypes =
1028 DataType::QuantisedAsymm8,
1029 DataType::QuantisedSymm16
1031 return CheckSupportRule(TypeAnyOf(input, supportedOutputTypes), reasonIfUnsupported,
1032 "Reference reshape: input type not supported.");
// ResizeBilinear layer: decided on the input's DataType; output ignored.
1035 bool RefLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
1036 const TensorInfo& output,
1037 Optional<std::string&> reasonIfUnsupported) const
1039 ignore_unused(output);
1040 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1041 input.GetDataType(),
// Rsqrt layer: decided on the input's DataType; output ignored.
1046 bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
1047 const TensorInfo& output,
1048 Optional<std::string&> reasonIfUnsupported) const
1050 ignore_unused(output);
1051 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1052 input.GetDataType(),
// Softmax layer: input and output must each be of a supported type and must
// match each other.
// NOTE(review): all three failure messages below say "Reference concatenation"
// -- a copy-paste slip; they should read "Reference softmax", and the first two
// have input/output swapped relative to the rule they report on.
1057 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1058 const TensorInfo& output,
1059 const SoftmaxDescriptor& descriptor,
1060 Optional<std::string&> reasonIfUnsupported) const
1062 ignore_unused(output);
1063 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
1064 std::array<DataType,3> supportedTypes =
1067 DataType::QuantisedAsymm8,
1068 DataType::QuantisedSymm16
1071 supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
1072 "Reference concatenation: output type not supported");
1074 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1075 "Reference concatenation: input type not supported");
1077 supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
1078 "Reference concatenation: input type not supported");
// SpaceToBatchNd layer: decided on the input's DataType; output/descriptor ignored.
1083 bool RefLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1084 const TensorInfo& output,
1085 const SpaceToBatchNdDescriptor& descriptor,
1086 Optional<std::string&> reasonIfUnsupported) const
1088 ignore_unused(output);
1089 ignore_unused(descriptor);
1090 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1091 input.GetDataType(),
// Splitter layer (legacy, no explicit outputs): decided on the input's DataType.
1096 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1097 const ViewsDescriptor& descriptor,
1098 Optional<std::string&> reasonIfUnsupported) const
1100 ignore_unused(descriptor);
1101 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1102 input.GetDataType(),
// Splitter layer (multi-output overload): still decided only on the input's
// DataType; the per-view output infos are not inspected.
1107 bool RefLayerSupport::IsSplitterSupported(const TensorInfo& input,
1108 const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1109 const ViewsDescriptor& descriptor,
1110 Optional<std::string&> reasonIfUnsupported) const
1112 ignore_unused(descriptor);
1113 ignore_unused(outputs);
1114 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1115 input.GetDataType(),
// StridedSlice layer: decided on the input's DataType; output/descriptor ignored.
1120 bool RefLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1121 const TensorInfo& output,
1122 const StridedSliceDescriptor& descriptor,
1123 Optional<std::string&> reasonIfUnsupported) const
1125 ignore_unused(output);
1126 ignore_unused(descriptor);
1127 return IsSupportedForDataTypeRef(reasonIfUnsupported,
1128 input.GetDataType(),
// Subtraction layer: same rule set as addition -- supported, mutually matching
// types plus implicit-broadcast-compatible shapes.
1133 bool RefLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1134 const TensorInfo& input1,
1135 const TensorInfo& output,
1136 Optional<std::string&> reasonIfUnsupported) const
1138 bool supported = true;
// NOTE(review): 3-entry array, 2 entries visible -- one elided in this excerpt.
1140 std::array<DataType,3> supportedTypes = {
1142 DataType::QuantisedAsymm8,
1143 DataType::QuantisedSymm16
1146 supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
1147 "Reference subtraction: input 0 is not a supported type.");
1149 supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
1150 "Reference subtraction: input 1 is not a supported type.");
1152 supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
1153 "Reference subtraction: output is not a supported type.");
1155 supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
1156 "Reference subtraction: input 0 and Input 1 types are mismatched");
1158 supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
1159 "Reference subtraction: input and output types are mismatched");
1161 supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
1162 "Reference subtraction: shapes are not suitable for implicit broadcast.");
1167 } // namespace armnn