IVGCVSW-4190 Add SplitV to Tflite Parser
authorRyan OShea <Ryan.OShea2@arm.com>
Tue, 26 May 2020 10:41:04 +0000 (11:41 +0100)
committerJim Flynn <jim.flynn@arm.com>
Tue, 2 Jun 2020 16:34:30 +0000 (16:34 +0000)
 * Refactored SplitV
 * Added unit tests
 * Updated Documentation

Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com>
Change-Id: If1dfa5a8780ddf3fe8788ed7bf7fa5fa8dfd14ec

CMakeLists.txt
docs/01_parsers.dox
src/armnnTfLiteParser/TfLiteParser.cpp
src/armnnTfLiteParser/test/SplitV.cpp [new file with mode: 0644]

index 8f8060a..a855cba 100644 (file)
@@ -793,6 +793,7 @@ if(BUILD_UNIT_TESTS)
              src/armnnTfLiteParser/test/SpaceToBatchND.cpp
              src/armnnTfLiteParser/test/Slice.cpp
              src/armnnTfLiteParser/test/Split.cpp
+             src/armnnTfLiteParser/test/SplitV.cpp
              src/armnnTfLiteParser/test/Squeeze.cpp
              src/armnnTfLiteParser/test/StridedSlice.cpp
              src/armnnTfLiteParser/test/Sub.cpp
index e6b4a28..1c52c4a 100644 (file)
@@ -179,6 +179,7 @@ The Arm NN SDK TensorFlow Lite parser currently supports the following operators
 - SOFTMAX
 - SPACE_TO_BATCH
 - SPLIT
+- SPLIT_V
 - SQUEEZE
 - STRIDED_SLICE
 - SUB
index 53b49f4..c695caa 100644 (file)
@@ -2683,7 +2683,7 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
 
     const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
-
+    const auto * options = operatorPtr->builtin_options.AsSplitVOptions();
 
     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(inputs.size(), 3);
@@ -2717,66 +2717,67 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
     ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
     const unsigned int splitDim = ComputeWrappedIndex(axisData[0], inputTensorInfo.GetNumDimensions());
 
-
     // Set split sizes
-    const auto * options = operatorPtr->builtin_options.AsSplitOptions();
     CHECK_VALID_SIZE(splitsInfo.GetNumDimensions(), 1);
-    unsigned int numSplits = 0;
     std::vector<int> splitsData(0);
-    if (options)
+    unsigned int numSplits{0};
+
+    if(options)
     {
         numSplits = CHECKED_NON_NEGATIVE(options->num_splits);
-        splitsData.resize(numSplits);
-
-        if (inputTensorInfo.GetShape()[splitDim] % numSplits != 0)
-        {
-            throw ParseException("Number of splits must evenly divide the split axis");
-        }
-        unsigned int splitSize = inputTensorInfo.GetShape()[splitDim] / numSplits;
-        for (auto& split : splitsData)
-        {
-            split = numeric_cast<int>(splitSize);
-        }
     }
     else
     {
-        numSplits = splitsInfo.GetShape()[0];
-        splitsData.resize(numSplits);
+        numSplits = splitsInfo.GetNumElements();
+    }
+
+    if (numSplits <=0)
+    {
+        throw ParseException("SplitV has invalid number of splits");
+    }
 
-        BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer);
-        ::memcpy(splitsData.data(), splitsBufferPtr->data.data(), splitsInfo.GetNumBytes());
+    splitsData.resize(numSplits);
+    BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer);
+    unsigned int idx{0};
 
-        int numInferred = 0;
-        int specifiedSizes = 0;
-        unsigned int inferIdx = 0;
-        unsigned int idx = 0;
-        for (auto split : splitsData)
+    for(auto& split: splitsData)
+    {
+        split = splitsBufferPtr->data[idx];
+        idx++;
+    }
+
+    idx = 0;
+    int numInferred{0};
+    unsigned int inferIdx{0};
+    int splitSum{0};
+    for (auto split : splitsData)
+    {
+        if (split < 0)
         {
-            if (split < 0)
-            {
-                numInferred++;
-                inferIdx = idx;
-            }
-            else
-            {
-                specifiedSizes += split;
-            }
-            idx++;
+            numInferred++;
+            inferIdx = idx;
         }
-
-        if (numInferred > 0)
+        else
         {
-            if (numInferred > 1)
-            {
-                throw ParseException("Cannot infer split size for more than one split");
-            }
-            splitsData[inferIdx] = numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - specifiedSizes;
+            splitSum += split;
         }
+        idx++;
     }
-
-    if (numSplits <=0)
+    // Check for inferred Axis
+    if (numInferred == 0)
     {
-        throw ParseException("SplitV has invalid number of splits");
+        if (splitSum != numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
+        {
+            throw ParseException("SplitV split_sizes does not sum to the dimension of value along split_dim.");
+        }
+    }
+    else if (numInferred == 1)
+    {
+        splitsData[inferIdx] = numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
+    }
+    else
+    {
+        throw ParseException("Cannot infer split size for more than one split");
     }
 
     //Ouput size validation
@@ -2805,7 +2806,7 @@ void TfLiteParser::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
         accumSplit += splitSize;
     }
 
-    auto layerName = boost::str(boost::format("Split:%1%:%2%") % subgraphIndex % operatorIndex);
+    auto layerName = boost::str(boost::format("SplitV:%1%:%2%") % subgraphIndex % operatorIndex);
     IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
 
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
diff --git a/src/armnnTfLiteParser/test/SplitV.cpp b/src/armnnTfLiteParser/test/SplitV.cpp
new file mode 100644 (file)
index 0000000..59afeec
--- /dev/null
@@ -0,0 +1,209 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
+
+struct SplitVFixture : public ParserFlatbuffersFixture
+{
+    explicit SplitVFixture(const std::string& inputShape,
+                           const std::string& splitValues,
+                           const std::string& sizeSplitsShape,
+                           const std::string& axisShape,
+                           const std::string& numSplits,
+                           const std::string& outputShape1,
+                           const std::string& outputShape2,
+                           const std::string& axisData,
+                           const std::string& dataType)
+    {
+        m_JsonString = R"(
+            {
+                "version": 3,
+                "operator_codes": [ { "builtin_code": "SPLIT_V" } ],
+                "subgraphs": [ {
+                    "tensors": [
+                        {
+                            "shape": )" + inputShape + R"(,
+                            "type": )" + dataType + R"(,
+                            "buffer": 0,
+                            "name": "inputTensor",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": )" + sizeSplitsShape + R"(,
+                            "type": "INT32",
+                            "buffer": 1,
+                            "name": "sizeSplits",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": )" + axisShape + R"(,
+                            "type": "INT32",
+                            "buffer": 2,
+                            "name": "axis",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": )" + outputShape1 + R"( ,
+                            "type":)" + dataType + R"(,
+                            "buffer": 3,
+                            "name": "outputTensor1",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        },
+                        {
+                            "shape": )" + outputShape2 + R"( ,
+                            "type":)" + dataType + R"(,
+                            "buffer": 4,
+                            "name": "outputTensor2",
+                            "quantization": {
+                                "min": [ 0.0 ],
+                                "max": [ 255.0 ],
+                                "scale": [ 1.0 ],
+                                "zero_point": [ 0 ],
+                            }
+                        }
+                    ],
+                    "inputs": [ 0, 1, 2 ],
+                    "outputs": [ 3, 4 ],
+                    "operators": [
+                        {
+                            "opcode_index": 0,
+                            "inputs": [ 0, 1, 2 ],
+                            "outputs": [ 3, 4 ],
+                            "builtin_options_type": "SplitVOptions",
+                            "builtin_options": {
+                                "num_splits": )" + numSplits + R"(
+                            },
+                            "custom_options_format": "FLEXBUFFERS"
+                        }
+                    ],
+                } ],
+                "buffers" : [ {}, { "data": )" + splitValues + R"( }, { "data": )" + axisData + R"( }, {}, {}]
+            }
+        )";
+
+        Setup();
+    }
+};
+
+/*
+ *  Tested inferred splitSizes with splitValues [-1, 1] locally.
+ */
+
+struct SimpleSplitVAxisOneFixture : SplitVFixture
+{
+    SimpleSplitVAxisOneFixture()
+        : SplitVFixture( "[ 4, 2, 2, 2 ]", "[ 1, 3 ]", "[ 2 ]","[ ]", "2",
+                         "[ 1, 2, 2, 2 ]", "[ 3, 2, 2, 2 ]", "[ 0, 0, 0, 0 ]", "FLOAT32")
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitVTwo, SimpleSplitVAxisOneFixture)
+{
+    RunTest<4, armnn::DataType::Float32>(
+        0,
+        { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                              9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
+                              17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
+                              25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
+        { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
+          {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
+                              17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
+                              25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
+}
+
+struct SimpleSplitVAxisTwoFixture : SplitVFixture
+{
+    SimpleSplitVAxisTwoFixture()
+        : SplitVFixture( "[ 2, 4, 2, 2 ]", "[ 3, 1 ]", "[ 2 ]","[ ]", "2",
+                         "[ 2, 3, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseAxisTwoSplitVTwo, SimpleSplitVAxisTwoFixture)
+{
+    RunTest<4, armnn::DataType::Float32>(
+        0,
+        { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                              9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
+                              17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
+                              25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
+        { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                              9.0f, 10.0f, 11.0f, 12.0f, 17.0f, 18.0f, 19.0f, 20.0f,
+                              21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f } },
+          {"outputTensor2", { 13.0f, 14.0f, 15.0f, 16.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
+}
+
+struct SimpleSplitVAxisThreeFixture : SplitVFixture
+{
+    SimpleSplitVAxisThreeFixture()
+        : SplitVFixture( "[ 2, 2, 4, 2 ]", "[ 1, 3 ]", "[ 2 ]","[ ]", "2",
+                         "[ 2, 2, 1, 2 ]", "[ 2, 2, 3, 2 ]", "[ 2, 0, 0, 0 ]", "FLOAT32")
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseAxisThreeSplitVTwo, SimpleSplitVAxisThreeFixture)
+{
+    RunTest<4, armnn::DataType::Float32>(
+        0,
+        { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                              9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
+                              17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
+                              25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
+        { {"outputTensor1", { 1.0f, 2.0f, 9.0f, 10.0f, 17.0f, 18.0f, 25.0f, 26.0f } },
+          {"outputTensor2", { 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 11.0f, 12.0f,
+                              13.0f, 14.0f, 15.0f, 16.0f, 19.0f, 20.0f, 21.0f, 22.0f,
+                              23.0f, 24.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
+}
+
+struct SimpleSplitVAxisFourFixture : SplitVFixture
+{
+    SimpleSplitVAxisFourFixture()
+        : SplitVFixture( "[ 2, 2, 2, 4 ]", "[ 3, 1 ]", "[ 2 ]","[ ]", "2",
+                         "[ 2, 2, 2, 3 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32")
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseAxisFourSplitVTwo, SimpleSplitVAxisFourFixture)
+{
+    RunTest<4, armnn::DataType::Float32>(
+        0,
+        { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
+                              9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
+                              17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
+                              25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
+        { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 5.0f, 6.0f, 7.0f, 9.0f, 10.0f,
+                              11.0f, 13.0f, 14.0f, 15.0f, 17.0f, 18.0f, 19.0f, 21.0f,
+                              22.0f, 23.0f, 25.0f, 26.0f, 27.0f, 29.0f, 30.0f, 31.0f} },
+          {"outputTensor2", { 4.0f, 8.0f, 12.0f, 16.0f, 20.0f, 24.0f, 28.0f, 32.0f } } } );
+}
+
+BOOST_AUTO_TEST_SUITE_END()