IVGCVSW-4108 Fixed invalid data type exception
authorMike Kelly <mike.kelly@arm.com>
Fri, 8 Nov 2019 12:08:35 +0000 (12:08 +0000)
committermike.kelly <mike.kelly@arm.com>
Fri, 8 Nov 2019 14:13:13 +0000 (14:13 +0000)
 * Added support for QuantizedSymm8PerAxis to ArmComputeTensorUtils.

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ib8662f216bc4b6b54e0099780f73bcf6ef05384b

src/backends/aclCommon/ArmComputeTensorUtils.cpp

index b2955b9..b0a8ba1 100644 (file)
@@ -17,6 +17,8 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
 {
     switch(dataType)
     {
+        case armnn::DataType::Boolean:
+            return arm_compute::DataType::U8;
         case armnn::DataType::Float16:
             return arm_compute::DataType::F16;
         case armnn::DataType::Float32:
@@ -25,10 +27,10 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
             return arm_compute::DataType::QASYMM8;
         case armnn::DataType::QuantisedSymm16:
             return arm_compute::DataType::QSYMM16;
+        case armnn::DataType::QuantizedSymm8PerAxis:
+            return arm_compute::DataType::QSYMM8_PER_CHANNEL;
         case armnn::DataType::Signed32:
             return arm_compute::DataType::S32;
-        case armnn::DataType::Boolean:
-            return arm_compute::DataType::U8;
         default:
             BOOST_ASSERT_MSG(false, "Unknown data type");
             return arm_compute::DataType::UNKNOWN;
@@ -212,14 +214,16 @@ arm_compute::PixelValue GetPixelValue(arm_compute::ITensor& input, float pixelVa
 {
     switch (input.info()->data_type())
     {
-        case arm_compute::DataType::QASYMM8:
-            return arm_compute::PixelValue(static_cast<uint8_t>(pixelValue));
-        case arm_compute::DataType::QSYMM16:
-            return arm_compute::PixelValue(static_cast<int16_t>(pixelValue));
         case arm_compute::DataType::F16:
             return arm_compute::PixelValue(static_cast<Half>(pixelValue));
         case arm_compute::DataType::F32:
             return arm_compute::PixelValue(pixelValue);
+        case arm_compute::DataType::QASYMM8:
+            return arm_compute::PixelValue(static_cast<uint8_t>(pixelValue));
+        case arm_compute::DataType::QSYMM16:
+            return arm_compute::PixelValue(static_cast<int16_t>(pixelValue));
+        case arm_compute::DataType::QSYMM8_PER_CHANNEL:
+            return arm_compute::PixelValue(static_cast<int8_t>(pixelValue));
         default:
             throw InvalidArgumentException("Unsupported DataType: [" +
                                            std::to_string(static_cast<int>(input.info()->data_type())) + "]");