IVGCVSW-4375 Add parser support for Transpose
authorMike Kelly <mike.kelly@arm.com>
Mon, 2 Mar 2020 11:41:31 +0000 (11:41 +0000)
committermike.kelly <mike.kelly@arm.com>
Tue, 3 Mar 2020 10:40:38 +0000 (10:40 +0000)
 * Changed TfParser::ParseTranspose to use Transpose instead of Permute
 * Changed TfLiteParser::ParseTranspose to use Transpose instead of Permute

!armnn:2787

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: If48f2fb88d97d31d66b6b1e631b41637d8e4c8f0

src/armnnTfLiteParser/TfLiteParser.cpp
src/armnnTfParser/TfParser.cpp

index f5c01f2..56b59a1 100644 (file)
@@ -1011,7 +1011,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
     armnn::IConnectableLayer* layer = nullptr;
     auto layerName = boost::str(boost::format("Transpose:%1%:%2%") % subgraphIndex % operatorIndex);
 
-    PermuteDescriptor desc;
+    TransposeDescriptor desc;
 
     if (inputs.size() == 2)
     {
@@ -1020,23 +1020,12 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
         auto numPermVecElements = permuteTensorInfo.GetNumElements();
         std::vector<unsigned int> permuteShape(numPermVecElements);
         ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
+        PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
 
-        // permuteShape assumes Tf/Np permute vectors, we must translate to armnn expected form
-        // to do so we find the perm vector which would invert what a tf perm vector would do (ex 3,0,1,2 -> 1,2,3,0)
-        std::vector<unsigned int> armnnPermuteShape(numPermVecElements);
-        std::vector<unsigned int>::iterator it;
-        for (unsigned int i = 0u; i < numPermVecElements; ++i)
-        {
-            it = std::find(permuteShape.begin(), permuteShape.end(), i);
-            armnnPermuteShape[i] = static_cast<unsigned int>(std::distance(permuteShape.begin(), it));
-        }
-
-        PermutationVector permutationVector(armnnPermuteShape.data(), permuteTensorInfo.GetNumElements());
-
-        desc = PermuteDescriptor(permutationVector);
+        desc = TransposeDescriptor(permutationVector);
     }
 
-    layer = m_Network->AddPermuteLayer(desc, layerName.c_str());
+    layer = m_Network->AddTransposeLayer(desc, layerName.c_str());
 
     BOOST_ASSERT(layer != nullptr);
 
index 124c5fd..1383331 100755 (executable)
@@ -10,6 +10,7 @@
 
 #include <armnnUtils/Permute.hpp>
 #include <armnnUtils/DataLayoutIndexed.hpp>
+#include <armnnUtils/Transpose.hpp>
 
 #include <GraphTopologicalSort.hpp>
 #include <ParserHelper.hpp>
@@ -2084,26 +2085,19 @@ ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef
     std::vector<int32_t> permuteVectorData;
     permuteVectorInput->GetConstTensor(permuteVectorData);
 
-    std::vector<unsigned int>      armnnPermuteVectorData(permuteVectorData.size());
-    std::vector<int32_t>::iterator it;
-
-    for (unsigned int i = 0u; i < permuteVectorData.size(); ++i)
-    {
-        it                        = std::find(permuteVectorData.begin(), permuteVectorData.end(), i);
-        armnnPermuteVectorData[i] = static_cast<unsigned int>(std::distance(permuteVectorData.begin(), it));
-    }
+    std::vector<unsigned int>      armnnPermuteVectorData(permuteVectorData.begin(), permuteVectorData.end());
 
     const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements());
-    const auto desc              = PermuteDescriptor(permutationVector);
+    const auto desc              = TransposeDescriptor(permutationVector);
 
-    auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str());
+    auto* layer = m_Network->AddTransposeLayer(desc, nodeDef.name().c_str());
     BOOST_ASSERT(layer);
 
     input0Slot->Connect(layer->GetInputSlot(0));
 
     const auto&       input0Info = input0Slot->GetTensorInfo();
     armnn::TensorInfo outputInfo {input0Info};
-    outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings));
+    outputInfo.SetShape(armnnUtils::TransposeTensorShape(input0Info.GetShape(), desc.m_DimMappings));
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
 
     return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);