IVGCVSW-5197 Add support for 2nd input to ExpandDims of TfParser
authorJan Eilers <jan.eilers@arm.com>
Tue, 8 Sep 2020 07:57:40 +0000 (08:57 +0100)
committerKeithARM <keith.davis@arm.com>
Thu, 10 Sep 2020 09:23:30 +0000 (09:23 +0000)
 * ParseExpandDims did not support to pass the axis parameter as
   a second input tensor
 * Added related unit tests

Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I8217950f0b42beaf5b9eaebdcad04267e4443ba3

src/armnnTfParser/TfParser.cpp
src/armnnTfParser/test/ExpandDims.cpp

index 38202fc..0d7c371 100755 (executable)
@@ -24,7 +24,7 @@
 
 #include <boost/format.hpp>
 #include <boost/numeric/conversion/cast.hpp>
-#include <armnn/utility/PolymorphicDowncast.hpp>
+#include <fmt/core.h>
 #include <numeric>
 
 using namespace armnnUtils;
@@ -1464,7 +1464,9 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
     return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
 }
 
-TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo)
+TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef,
+                                   TensorInfo inputTensorInfo,
+                                   std::int32_t expandDim)
 {
     ARMNN_ASSERT(nodeDef.op() == "ExpandDims");
 
@@ -1478,8 +1480,6 @@ TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInf
                         % CHECK_LOCATION().AsString()));
     }
 
-    std::int32_t expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
-
     std::int32_t inputDimSize = boost::numeric_cast<int32_t>(inputTensorInfo.GetNumDimensions());
     std::vector<uint32_t> outputDims;
 
@@ -1542,13 +1542,78 @@ TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInf
 ParsedTfOperationPtr TfParser::ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
 {
     IgnoreUnused(graphDef);
-    std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
 
+    // Number of inputs can either
+    // be 1 - that indicates that the axis parameter is passed as an attribute of the operation
+    // or 2 - which means that the axis parameter is passed as a second input
+    std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+    const std::size_t numInputs = nodes.size();
+    std::vector<OutputOfParsedTfOperation> inputs;
+    std::int32_t expandDim; // axis or dim parameter. Describes which dimension to expand.
+    if (numInputs == 1)
+    {
+        inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+        expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
+    }
+    else
+    {
+        inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+
+        // make sure data type is int32
+        IOutputSlot& prevLayerOutputSlot = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+        TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+
+        if (inputTensorInfo.GetDataType()!=armnn::DataType::Signed32)
+        {
+            throw ParseException(
+                    fmt::format(
+                            "The axis parameter of ExpandDims operation given as second input is not of type int32. "
+                            "Input {0} Node {1} {2}",
+                            inputs[1].m_IndexedValue->GetNode().name(),
+                            nodeDef.name(),
+                            CHECK_LOCATION().AsString()));
+        }
+
+        // ensure the second input is a constant value
+        if (!HasParsedConstTensor<int32_t>(inputs[1].m_IndexedValue->GetNode().name()))
+        {
+            throw ParseException(
+                    fmt::format(
+                            "ArmNN only supports ExpandDims layers with constant axis/dim parameter. "
+                            "Input {0} Node {1} {2}",
+                            inputs[1].m_IndexedValue->GetNode().name(),
+                            nodeDef.name(),
+                            CHECK_LOCATION().AsString()));
+        }
+
+        // make sure the second input is scalar or contains only a single value
+        // (we don't support expand dims for multiple axis but we don't care what shape the
+        //  given tensor has as long as there is only a single value in it
+        //  e.g. a tensor like this [[[1]]] is completely fine)
+        if (inputTensorInfo.GetNumElements() != 1)
+        {
+            throw ParseException(
+                    fmt::format(
+                            "The axis parameter of ExpandDims operation given as second input is not "
+                            "allowed to hold more than one value. "
+                            "Input {0} Node {1} {2}",
+                            inputs[1].m_IndexedValue->GetNode().name(),
+                            nodeDef.name(),
+                            CHECK_LOCATION().AsString()));
+        }
+
+        ParsedConstTfOperation<int32_t>* expandDimsNode =
+                PolymorphicDowncast<ParsedConstTfOperation<int32_t>*>(inputs[1].m_IndexedValue);
+
+        memcpy(&expandDim, expandDimsNode->GetStorage(), sizeof(expandDim));
+    }
+
+    // First input is the vector that should be expanded by another dimension
     IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
     TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
 
     TensorInfo outputInfo;
-    outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo);
+    outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo, expandDim);
 
     ReshapeDescriptor reshapeDesc;
     reshapeDesc.m_TargetShape = outputInfo.GetShape();
index 57d472d..ad95641 100644 (file)
@@ -109,4 +109,205 @@ BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
                 armnn::TensorShape({2, 1, 3, 5})));
 }
 
+struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+    ExpandDimsAsInputFixture(const std::string& expandDim,
+                             const bool wrongDataType = false,
+                             const std::string& numElements = "1")
+    {
+        std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
+        std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
+
+        m_Prototext = R"(
+        node {
+            name: "a"
+            op: "Placeholder"
+            attr {
+                key: "dtype"
+                value {
+                    type: DT_FLOAT
+                }
+            }
+            attr {
+                key: "shape"
+                value {
+                    shape {
+                        dim {
+                            size: 1
+                        }
+                        dim {
+                            size: 4
+                        }
+                    }
+                }
+            }
+        }
+        node {
+            name: "b"
+            op: "Const"
+            attr {
+                key: "dtype"
+                value {
+                    type:  )" + dataType + R"(
+                }
+            }
+            attr {
+                key: "value"
+                value {
+                    tensor {
+                        dtype: )" + dataType + R"(
+                        tensor_shape {
+                            dim {
+                                size: )" + numElements + R"(
+                            }
+                        }
+                        )" + val + R"(
+                    }
+                }
+            }
+        }
+        node {
+            name: "ExpandDims"
+            op: "ExpandDims"
+            input: "a"
+            input: "b"
+            attr {
+                key: "T"
+                value {
+                    type: DT_FLOAT
+                }
+            }
+            attr {
+                key: "Tdim"
+                value {
+                    type: DT_INT32
+                }
+            }
+        }
+        versions {
+            producer: 134
+        })";
+    }
+};
+
+struct ExpandDimAsInput : ExpandDimsAsInputFixture
+{
+    ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
+    {
+        Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
+    }
+};
+
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
+{
+    // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+    BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+                armnn::TensorShape({1, 1, 4})));
+}
+
+struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
+{
+    ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
+{
+    // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+    // Axis parameter is of wrong data type (float instead of int32)
+    BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
+{
+    ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
+{
+    // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+    // Axis parameter is of wrong shape
+    BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+    ExpandDimsAsNotConstInputFixture()
+    {
+        m_Prototext = R"(
+            node {
+                name: "a"
+                op: "Placeholder"
+                attr {
+                    key: "dtype"
+                    value {
+                        type: DT_FLOAT
+                    }
+                }
+                attr {
+                    key: "shape"
+                    value {
+                        shape {
+                            dim {
+                                size: 1
+                            }
+                            dim {
+                            size: 4
+                            }
+                        }
+                    }
+                }
+            }
+            node {
+                name: "b"
+                op: "Placeholder"
+                attr {
+                    key: "dtype"
+                        value {
+                            type: DT_INT32
+                        }
+                }
+                attr {
+                    key: "shape"
+                    value {
+                        shape {
+                            dim {
+                                size: 1
+                            }
+                        }
+                    }
+                }
+            }
+            node {
+                name: "ExpandDims"
+                op: "ExpandDims"
+                input: "a"
+                input: "b"
+                attr {
+                    key: "T"
+                        value {
+                            type: DT_FLOAT
+                        }
+                    }
+                    attr {
+                        key: "Tdim"
+                        value {
+                            type: DT_INT32
+                        }
+                    }
+                }
+            versions {
+                producer: 134
+            })";
+    }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
+{
+    // Axis parameter that describes which axis/dim should be expanded is passed as a second input.
+    // But is not a constant tensor --> not supported
+    BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }),
+                        armnn::ParseException);
+}
+
 BOOST_AUTO_TEST_SUITE_END()