MLCE-159 Add QAsymmS8 to ArmnnQuantizer
Author:     Francis Murtagh <francis.murtagh@arm.com>
AuthorDate: Tue, 10 Mar 2020 13:51:45 +0000
Commit:     Jim Flynn <jim.flynn@arm.com>
CommitDate: Tue, 10 Mar 2020 14:02:00 +0000
 * Allow per-layer quantization from Fp32 to Int8 (QAsymmS8), as TfLite does

Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: I5bbf770aa29d81af3568c15b47d2b2c18e55bb28
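
For context, a run of the quantizer with the new scheme could look like the
following (paths are placeholders; the -f, -s and -p options are the ones
visible in the CommandLineProcessor diff below):

    # Quantize a float32 ArmNN input graph to signed asymmetric int8,
    # preserving the original input/output data types of the network.
    ArmnnQuantizer -f model_fp32.armnn -s QAsymmS8 -p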

src/armnnDeserializer/Deserializer.cpp
src/armnnQuantizer/ArmNNQuantizerMain.cpp
src/armnnQuantizer/CommandLineProcessor.cpp
src/armnnSerializer/ArmnnSchema.fbs
src/armnnSerializer/SerializerUtils.cpp
src/backends/backendsCommon/WorkloadData.cpp
src/backends/backendsCommon/test/WorkloadTestUtils.hpp
src/backends/reference/RefLayerSupport.cpp

diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 1f7c360..bc6fbf0 100644
@@ -505,6 +505,9 @@ armnn::TensorInfo ToTensorInfo(Deserializer::TensorRawPtr tensorPtr)
 
     switch (tensorPtr->dataType())
     {
+        case DataType_QAsymmS8:
+            type = armnn::DataType::QAsymmS8;
+            break;
         case DataType_QuantisedAsymm8:
         case DataType_QAsymmU8:
             type = armnn::DataType::QAsymmU8;
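
With the new case above, a deserialized tensor can now carry the signed
asymmetric 8-bit type. For illustration, such a TensorInfo would look roughly
like this (shape, scale and offset are made-up values):

    #include <armnn/Tensor.hpp>

    // QAsymmS8 stores int8 data in [-128, 127], dequantized as
    // (value - offset) * scale.
    armnn::TensorInfo info({ 1, 224, 224, 3 },
                           armnn::DataType::QAsymmS8,
                           0.05f, // quantization scale
                           10);   // quantization offset (zero point)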
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 30167e7..219363e 100644
@@ -36,9 +36,19 @@ int main(int argc, char* argv[])
     inputFileStream.close();
 
     armnn::QuantizerOptions quantizerOptions;
-    quantizerOptions.m_ActivationFormat = cmdline.GetQuantizationScheme() == "QSymm16"
-                                          ? armnn::DataType::QSymmS16
-                                          : armnn::DataType::QAsymmU8;
+
+    if (cmdline.GetQuantizationScheme() == "QAsymmS8")
+    {
+        quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8;
+    }
+    else if (cmdline.GetQuantizationScheme() == "QSymmS16")
+    {
+        quantizerOptions.m_ActivationFormat = armnn::DataType::QSymmS16;
+    }
+    else
+    {
+        quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmU8;
+    }
 
     quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();
 
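The if/else chain keeps the old behaviour of defaulting to QAsymmU8 for any
unrecognised scheme string. The same mapping could also be written as a table
lookup; a sketch (the helper name is hypothetical):

    #include <map>
    #include <string>
    #include <armnn/Types.hpp>

    armnn::DataType SchemeToDataType(const std::string& scheme)
    {
        static const std::map<std::string, armnn::DataType> schemes =
        {
            { "QAsymmS8", armnn::DataType::QAsymmS8 },
            { "QAsymmU8", armnn::DataType::QAsymmU8 },
            { "QSymm16",  armnn::DataType::QSymmS16 }
        };
        const auto it = schemes.find(scheme);
        // Unknown strings fall back to the historical default.
        return it != schemes.end() ? it->second : armnn::DataType::QAsymmU8;
    }
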
diff --git a/src/armnnQuantizer/CommandLineProcessor.cpp b/src/armnnQuantizer/CommandLineProcessor.cpp
index d2163c0..0cccb66 100644
@@ -67,8 +67,10 @@ bool ValidateQuantizationScheme(const std::string& scheme)
         return false;
     }
 
-    std::vector<std::string> supportedSchemes = {
-        "QAsymm8",
+    std::vector<std::string> supportedSchemes =
+    {
+        "QAsymmS8",
+        "QAsymmU8",
         "QSymm16"
     };
 
@@ -93,8 +95,10 @@ bool CommandLineProcessor::ProcessCommandLine(int argc, char* argv[])
                 ("help,h", "Display help messages")
                 ("infile,f", po::value<std::string>(&m_InputFileName)->required(),
                              "Input file containing float 32 ArmNN Input Graph")
-                ("scheme,s", po::value<std::string>(&m_QuantizationScheme)->default_value("QAsymm8"),
-                              "Quantization scheme, \"QAsymm8\" or \"QSymm16\", default value QAsymm8")
+                ("scheme,s", po::value<std::string>(&m_QuantizationScheme)->default_value("QAsymmU8"),
+                              "Quantization scheme,"
+                              " \"QAsymmU8\" or \"QAsymmS8\" or \"QSymm16\","
+                              " default value QAsymmU8")
                 ("csvfile,c", po::value<std::string>(&m_CsvFileName)->default_value(""),
                              "CSV file containing paths for RAW input tensors")
                 ("preserve-data-type,p", po::bool_switch(&m_PreserveDataType)->default_value(false),
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index d7565a5..ca3db5d 100644
@@ -37,7 +37,8 @@ enum DataType : byte {
     Boolean = 4,
     QuantisedSymm16 = 5, // deprecated
     QAsymmU8 = 6,
-    QSymmS16 = 7
+    QSymmS16 = 7,
+    QAsymmS8 = 8
 }
 
 enum DataLayout : byte {
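
QAsymmS8 takes the next free enumerator value rather than being inserted
mid-enum, so models serialized before this change keep their numeric DataType
encoding. A compile-time check against the generated FlatBuffers header (the
header name is the flatc default, assumed here) would be:

    #include <ArmnnSchema_generated.h>

    static_assert(armnnSerializer::DataType::DataType_QSymmS16 == 7,
                  "existing enumerator values must not change");
    static_assert(armnnSerializer::DataType::DataType_QAsymmS8 == 8,
                  "new values are appended for backward compatibility");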
diff --git a/src/armnnSerializer/SerializerUtils.cpp b/src/armnnSerializer/SerializerUtils.cpp
index 02a5ed3..c184771 100644
@@ -58,6 +58,8 @@ armnnSerializer::DataType GetFlatBufferDataType(armnn::DataType dataType)
             return armnnSerializer::DataType::DataType_Signed32;
         case armnn::DataType::QSymmS16:
             return armnnSerializer::DataType::DataType_QSymmS16;
+        case armnn::DataType::QAsymmS8:
+            return armnnSerializer::DataType::DataType_QAsymmS8;
         case armnn::DataType::QAsymmU8:
             return armnnSerializer::DataType::DataType_QAsymmU8;
         case armnn::DataType::Boolean:
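
Together with the Deserializer change above, the serializer round-trip now
preserves QAsymmS8 tensor infos. A minimal sketch, assuming the public
ISerializer factory API:

    #include <armnnSerializer/ISerializer.hpp>
    #include <fstream>

    void SaveNetwork(const armnn::INetwork& network)
    {
        // QAsymmS8 tensors are written out as DataType_QAsymmS8 and
        // mapped back to armnn::DataType::QAsymmS8 on deserialization.
        auto serializer = armnnSerializer::ISerializer::Create();
        serializer->Serialize(network);
        std::ofstream out("model_qasymms8.armnn", std::ios::binary);
        serializer->SaveSerializedToStream(out);
    }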
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index dbd1158..bb0c21f 100644
@@ -994,6 +994,7 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
     {
         DataType::Float32,
         DataType::Float16,
+        DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
@@ -1183,8 +1184,8 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
     std::vector<DataType> supportedTypes =
     {
         DataType::Float32,
-        DataType::QAsymmU8,
         DataType::QAsymmS8,
+        DataType::QAsymmU8,
         DataType::QSymmS16,
         DataType::QSymmS8,
         DataType::Float16
diff --git a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
index 0b0f265..5168333 100644
@@ -98,6 +98,8 @@ inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Option
         case armnn::DataType::Float16:
         case armnn::DataType::Float32:
             return weightsType;
+        case armnn::DataType::QAsymmS8:
+            return armnn::DataType::Signed32;
         case armnn::DataType::QAsymmU8:
             return armnn::DataType::Signed32;
         case armnn::DataType::QSymmS16:
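
As the helper shows, 8-bit quantized weights, now including QAsymmS8, pair
with 32-bit integer biases. A hypothetical convenience function for building
the matching bias TensorInfo, using the common convention that the bias scale
is the product of the input and weights scales:

    #include <armnn/Tensor.hpp>

    // Quantized biases are Signed32 with scale = inputScale * weightsScale
    // and a zero offset.
    armnn::TensorInfo MakeBiasInfo(const armnn::TensorShape& shape,
                                   const armnn::TensorInfo& input,
                                   const armnn::TensorInfo& weights)
    {
        return armnn::TensorInfo(shape,
                                 armnn::DataType::Signed32,
                                 input.GetQuantizationScale() *
                                     weights.GetQuantizationScale(),
                                 0);
    }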
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index bd2e728..cb94955 100644
@@ -815,11 +815,12 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     bool supported = true;
 
     // Define supported types.
-    std::array<DataType,4> supportedTypes =
+    std::array<DataType,5> supportedTypes =
     {
             DataType::Float32,
             DataType::Float16,
             DataType::QAsymmU8,
+            DataType::QAsymmS8,
             DataType::QSymmS16
     };
 
@@ -835,8 +836,28 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
     supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
                                   "Reference Fully Connected: weights type not supported.");
 
-    supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
-                                  "Reference Fully Connected: input and weight types mismatched.");
+    ARMNN_NO_DEPRECATE_WARN_BEGIN
+    std::array<DataType, 3> supportedWeightTypes =
+    {
+            DataType::QAsymmU8,
+            DataType::QSymmS8,
+            DataType::QuantizedSymm8PerAxis // deprecated
+    };
+    ARMNN_NO_DEPRECATE_WARN_END
+
+    if (IsQuantized8BitType(input.GetDataType()))
+    {
+        supported &= CheckSupportRule(TypeAnyOf(weights, supportedWeightTypes), reasonIfUnsupported,
+                                      "Reference Fully Connected: weights type not supported for quantized input.");
+    }
+    else
+    {
+        supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
+                                      "Reference Fully Connected: weights is not a supported type.");
+
+        supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported,
+                                      "Reference Fully Connected: input and weights types mismatched.");
+    }
 
     if (descriptor.m_BiasEnabled)
     {
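
The quantized branch relies on IsQuantized8BitType to group all 8-bit
quantized inputs, so a QAsymmS8 input may pair with weights of a different
8-bit type (for example QSymmS8). Its behaviour is roughly the following
sketch (the real armnn predicate also covers the deprecated per-axis type):

    #include <armnn/Types.hpp>

    // True for any of the 8-bit quantized data types.
    constexpr bool IsQuantized8BitTypeSketch(armnn::DataType type)
    {
        return type == armnn::DataType::QAsymmS8
            || type == armnn::DataType::QAsymmU8
            || type == armnn::DataType::QSymmS8;
    }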