IVGCVSW-3909 Fix Transpose perm vector not parsed by Tflite parser
authorKevin May <kevin.may@arm.com>
Fri, 27 Sep 2019 16:21:06 +0000 (17:21 +0100)
committerKevin May <kevin.may@arm.com>
Fri, 27 Sep 2019 16:21:06 +0000 (17:21 +0100)
    * Add permute vector to descriptor if present
    * Refactor test to check with and without permute vector

Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: Ic8d882bb0f982fd00bb2854c18ea316b1b2cde2b

src/armnnTfLiteParser/TfLiteParser.cpp
src/armnnTfLiteParser/test/Transpose.cpp

index 939640a..da81c0a 100644 (file)
@@ -871,7 +871,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
 
     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
-    CHECK_VALID_SIZE(inputs.size(), 2);
+    CHECK_VALID_SIZE(inputs.size(), 1, 2);
 
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), 1);
@@ -881,6 +881,19 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
 
     PermuteDescriptor desc;
 
+    if(inputs.size() == 2)
+    {
+        armnn::TensorInfo permuteTensorInfo = ToTensorInfo(inputs[1]);
+        BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+        std::vector<unsigned int> permuteShape(permuteTensorInfo.GetNumElements());
+        ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
+
+        PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
+
+        desc =  PermuteDescriptor(permutationVector);
+    }
+
     layer = m_Network->AddPermuteLayer(desc, layerName.c_str());
 
     BOOST_ASSERT(layer != nullptr);
index 4430438..2e3190b 100644 (file)
@@ -12,6 +12,7 @@ BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
 struct TransposeFixture : public ParserFlatbuffersFixture
 {
     explicit TransposeFixture(const std::string & inputShape,
+                              const std::string & permuteData,
                               const std::string & outputShape)
     {
         m_JsonString = R"(
@@ -29,8 +30,8 @@ struct TransposeFixture : public ParserFlatbuffersFixture
                         {
                           "shape": )" + inputShape + R"(,
                           "type": "FLOAT32",
-                          "buffer": 3,
-                          "name": "Placeholder",
+                          "buffer": 0,
+                          "name": "inputTensor",
                           "quantization": {
                             "min": [
                               0.0
@@ -46,28 +47,33 @@ struct TransposeFixture : public ParserFlatbuffersFixture
                         {
                           "shape": )" + outputShape + R"(,
                           "type": "FLOAT32",
-                          "buffer": 2,
-                          "name": "transpose",
-                          "quantization": {
-                            "details_type": 0,
-                            "quantized_dimension": 0
-                          },
-                          "is_variable": false
-                        },
-                        {
-                          "shape": [
-                            3
-                          ],
-                          "type": "INT32",
                           "buffer": 1,
-                          "name": "transpose/perm",
+                          "name": "outputTensor",
                           "quantization": {
                             "details_type": 0,
                             "quantized_dimension": 0
                           },
                           "is_variable": false
-                        }
-                      ],
+                        })";
+        if (!permuteData.empty())
+        {
+            m_JsonString += R"(,
+                              {
+                                "shape": [
+                                  3
+                                ],
+                                "type": "INT32",
+                                "buffer": 2,
+                                "name": "permuteTensor",
+                                "quantization": {
+                                  "details_type": 0,
+                                  "quantized_dimension": 0
+                                },
+                                "is_variable": false
+                              })";
+        }
+
+        m_JsonString += R"(],
                       "inputs": [
                         0
                       ],
@@ -78,9 +84,12 @@ struct TransposeFixture : public ParserFlatbuffersFixture
                         {
                           "opcode_index": 0,
                           "inputs": [
-                            0,
-                            2
-                          ],
+                            0)";
+        if (!permuteData.empty())
+        {
+            m_JsonString += R"(,2)";
+        }
+        m_JsonString += R"(],
                           "outputs": [
                             1
                           ],
@@ -95,9 +104,12 @@ struct TransposeFixture : public ParserFlatbuffersFixture
                   "description": "TOCO Converted.",
                   "buffers": [
                     { },
-                    { },
-                    { },
-                    { }
+                    { })";
+        if (!permuteData.empty())
+        {
+            m_JsonString += R"(,{"data": )" + permuteData + R"( })";
+        }
+        m_JsonString += R"(
                   ]
                 }
         )";
@@ -105,20 +117,39 @@ struct TransposeFixture : public ParserFlatbuffersFixture
     }
 };
 
-struct SimpleTransposeFixture : TransposeFixture
+struct TransposeFixtureWithPermuteData : TransposeFixture
 {
-    SimpleTransposeFixture() : TransposeFixture("[ 2, 2, 3 ]",
-                                                "[ 2, 3, 2 ]") {}
+    TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
+                                                         "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
+                                                         "[ 2, 3, 2 ]") {}
 };
 
-BOOST_FIXTURE_TEST_CASE(SimpleTranspose, SimpleTransposeFixture)
+BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
 {
     RunTest<3, armnn::DataType::Float32>(
       0,
-      {{"Placeholder", {  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+      {{"inputTensor", {  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+      {{"outputTensor", {  1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
+
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+                == armnn::TensorShape({2,3,2})));
+}
+
+struct TransposeFixtureWithoutPermuteData : TransposeFixture
+{
+    TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
+                                                            "",
+                                                            "[ 2, 3, 2 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData)
+{
+    RunTest<3, armnn::DataType::Float32>(
+        0,
+        {{"inputTensor", {  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+        {{"outputTensor", {  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
 
-      {{"transpose", {  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
-    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "transpose").second.GetShape()
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
                 == armnn::TensorShape({2,3,2})));
 }