MLCE-160 Error loading quantized model containing BatchNorm Layer
authorMike Kelly <mike.kelly@arm.com>
Wed, 4 Mar 2020 18:01:13 +0000 (18:01 +0000)
committermike.kelly <mike.kelly@arm.com>
Mon, 9 Mar 2020 15:22:58 +0000 (15:22 +0000)
 * Relaxed restrictions in BatchNormalizationQueueDescriptor::Validate

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I3101971c2101e90144bbbf7b63367cb0ef09573f

src/backends/backendsCommon/WorkloadData.cpp
src/backends/backendsCommon/test/WorkloadDataValidation.cpp

index 9b7a242..dbd1158 100644 (file)
@@ -1120,7 +1120,6 @@ void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInf
     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
 
     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
-    ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
 
     ValidatePointer(m_Mean,     descriptorName, "mean");
index 5c60e9e..d48a2bb 100644 (file)
@@ -19,7 +19,31 @@ using namespace armnn;
 
 BOOST_AUTO_TEST_SUITE(WorkloadInfoValidation)
 
+BOOST_AUTO_TEST_CASE(BatchNormalizationQueueDescriptor_Validate_DifferentQuantizationData)
+{
+    TensorShape inputShape { 1, 3, 2, 2 };
+    TensorShape outputShape { 1, 3, 2, 2 };
+
+    TensorInfo inputTensorInfo(inputShape, armnn::DataType::QAsymmU8, .1f, 125);
+    TensorInfo outputTensorInfo(outputShape, armnn::DataType::QAsymmU8, .2f, 120);
+
+    BatchNormalizationQueueDescriptor invalidData;
+    WorkloadInfo                      invalidInfo;
 
+    unsigned int sameShape[] = { 10 };
+    TensorInfo sameInfo = armnn::TensorInfo(1, sameShape, armnn::DataType::QAsymmU8);
+    ScopedCpuTensorHandle sameTensor(sameInfo);
+
+    AddInputToWorkload(invalidData, invalidInfo, inputTensorInfo, nullptr);
+    AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
+
+    invalidData.m_Mean = &sameTensor;
+    invalidData.m_Variance = &sameTensor;
+    invalidData.m_Beta= &sameTensor;
+    invalidData.m_Gamma = &sameTensor;
+
+    BOOST_CHECK_NO_THROW(RefBatchNormalizationWorkload(invalidData, invalidInfo));
+}
 
 BOOST_AUTO_TEST_CASE(QueueDescriptor_Validate_WrongNumOfInputsOutputs)
 {