Add support for all data type for input and output layers
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>
Tue, 4 Jun 2019 10:22:00 +0000 (11:22 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Tue, 4 Jun 2019 15:07:40 +0000 (15:07 +0000)
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I688f4db5f5950877ad88f637cf71c05270fd5338

src/backends/reference/RefLayerSupport.cpp
src/backends/reference/RefWorkloadFactory.cpp
src/backends/reference/test/RefEndToEndTests.cpp

index 1d0b230..edd552b 100644 (file)
@@ -701,10 +701,7 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
 bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                        Optional<std::string&> reasonIfUnsupported) const
 {
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &TrueFunc<>);
+    return true;
 }
 
 bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
@@ -950,13 +947,7 @@ bool RefLayerSupport::IsNormalizationSupported(const TensorInfo& input,
 bool RefLayerSupport::IsOutputSupported(const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
 {
-    return IsSupportedForDataTypeGeneric(reasonIfUnsupported,
-                                         output.GetDataType(),
-                                         &TrueFunc<>,
-                                         &TrueFunc<>,
-                                         &TrueFunc<>,
-                                         &FalseFuncI32<>,
-                                         &TrueFunc<>);
+    return true;
 }
 
 bool RefLayerSupport::IsPadSupported(const TensorInfo& input,
index 500c14b..1610655 100644 (file)
@@ -87,7 +87,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescr
         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
     }
 
-    return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
+    return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
@@ -106,8 +106,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDes
         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
     }
 
-    return MakeWorkloadHelper<CopyMemGenericWorkload, CopyMemGenericWorkload,
-                              CopyMemGenericWorkload, NullWorkload, CopyMemGenericWorkload>(descriptor, info);
+    return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
index 8e75eba..885773d 100644 (file)
@@ -467,6 +467,16 @@ BOOST_AUTO_TEST_CASE(DequantizeEndToEndOffsetTest)
     DequantizeEndToEndOffset<armnn::DataType::QuantisedAsymm8>(defaultBackends);
 }
 
+BOOST_AUTO_TEST_CASE(DequantizeEndToEndSimpleInt16Test)
+{
+    DequantizeEndToEndSimple<armnn::DataType::QuantisedSymm16>(defaultBackends);
+}
+
+BOOST_AUTO_TEST_CASE(DequantizeEndToEndOffsetInt16Test)
+{
+    DequantizeEndToEndOffset<armnn::DataType::QuantisedSymm16>(defaultBackends);
+}
+
 BOOST_AUTO_TEST_CASE(RefDetectionPostProcessRegularNmsTest)
 {
     std::vector<float> boxEncodings({