MLCE-103 Remove hardcoded output shape in ModelAccuracyTool
authorSiCong Li <sicong.li@arm.com>
Fri, 21 Jun 2019 15:02:40 +0000 (16:02 +0100)
committerJames Conroy <james.conroy@arm.com>
Tue, 9 Jul 2019 14:14:41 +0000 (14:14 +0000)
We can obtain the output tensor shape from the model provided by
the user.

Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: I5074734315174c1b5dc8eea1eff18a4a1c566f2a

tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp

index bb0d824..85241e8 100644 (file)
@@ -194,6 +194,9 @@ int main(int argc, char* argv[])
                 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
             const unsigned int inputTensorHeight =
                 inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
+            // Get output tensor info
+            const unsigned int outputNumElements = model.GetOutputSize();
+
             const unsigned int batchSize = 1;
             // Get normalisation parameters
             SupportedFrontend modelFrontend;
@@ -232,7 +235,7 @@ int main(int argc, char* argv[])
                             normParams,
                             batchSize,
                             inputTensorDataLayout));
-                        outputDataContainers = {vector<int>(1001)};
+                        outputDataContainers = { vector<int>(outputNumElements) };
                         break;
                     case armnn::DataType::QuantisedAsymm8:
                         inputDataContainers.push_back(
@@ -241,7 +244,7 @@ int main(int argc, char* argv[])
                             normParams,
                             batchSize,
                             inputTensorDataLayout));
-                        outputDataContainers = {vector<uint8_t>(1001)};
+                        outputDataContainers = { vector<uint8_t>(outputNumElements) };
                         break;
                     case armnn::DataType::Float32:
                     default:
@@ -251,7 +254,7 @@ int main(int argc, char* argv[])
                             normParams,
                             batchSize,
                             inputTensorDataLayout));
-                        outputDataContainers = {vector<float>(1001)};
+                        outputDataContainers = { vector<float>(outputNumElements) };
                         break;
                 }