IVGCVSW-2509 Add GatherLayer implementation
authornarpra01 <narumol.prangnawarat@arm.com>
Wed, 16 Jan 2019 17:22:19 +0000 (17:22 +0000)
committerAron Virginas-Tar <aron.virginas-tar@arm.com>
Tue, 22 Jan 2019 17:08:42 +0000 (17:08 +0000)
 * implementation of ValidateTensorShapesFromInputs
 * unit tests

Change-Id: I1ed88f8ba0ea20329a259c5f36caea4b1fbeb013

src/armnn/layers/GatherLayer.cpp
src/armnn/test/OptimizerTests.cpp

index 2e5d011..d7ed4b2 100644 (file)
@@ -32,6 +32,32 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const
 
 void GatherLayer::ValidateTensorShapesFromInputs()
 {
+    VerifyLayerConnections(2, CHECK_LOCATION());
+
+    const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo();
+    const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo();
+
+    const unsigned int paramsDim = params.GetNumDimensions();
+    const unsigned int indicesDim = indices.GetNumDimensions();
+    const unsigned int outputDim = paramsDim - 1 + indicesDim;
+
+    std::vector<unsigned int> dimSizes;
+
+    for (unsigned int i = 0; i < indicesDim; ++i)
+    {
+        dimSizes.push_back(indices.GetShape()[i]);
+    }
+    for (unsigned int i = 1; i < paramsDim; ++i)
+    {
+        dimSizes.push_back(params.GetShape()[i]);
+    }
+
+    const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data());
+
+    ConditionalThrowIfNotEqual<LayerValidationException>(
+        "GatherLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+        GetOutputSlot(0).GetTensorInfo().GetShape(),
+        inferredShape);
 }
 
 } // namespace armnn
index 80addb4..3b07986 100644 (file)
@@ -995,4 +995,59 @@ BOOST_AUTO_TEST_CASE(ResizeBilinearValidateTensorShapesFromInputsNhwc)
     BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
 }
 
+
+void CreateGatherGraph(Graph& graph, const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo,
+                       const armnn::TensorInfo& outputInfo)
+{
+    Layer* input0 = graph.AddLayer<InputLayer>(0, "params");
+    input0->GetOutputSlot().SetTensorInfo(paramsInfo);
+
+    Layer* input1 = graph.AddLayer<InputLayer>(1, "indices");
+    input1->GetOutputSlot().SetTensorInfo(indicesInfo);
+
+    GatherLayer* layer = graph.AddLayer<GatherLayer>("gather");
+    layer->GetOutputSlot().SetTensorInfo(outputInfo);
+
+    Layer* output = graph.AddLayer<OutputLayer>(0, "output");
+    input0->GetOutputSlot().Connect(layer->GetInputSlot(0));
+    input1->GetOutputSlot().Connect(layer->GetInputSlot(1));
+    layer->GetOutputSlot().Connect(output->GetInputSlot(0));
+}
+
+BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs)
+{
+    Graph graph;
+    armnn::TensorInfo paramsInfo({10, 5}, DataType::Float32);
+    armnn::TensorInfo indicesInfo({3}, DataType::Signed32);
+    armnn::TensorInfo outputInfo({3, 5}, DataType::Float32);
+
+    CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
+
+    BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
+}
+
+BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs1DParams)
+{
+    Graph graph;
+    armnn::TensorInfo paramsInfo({8}, DataType::Float32);
+    armnn::TensorInfo indicesInfo({5}, DataType::Signed32);
+    armnn::TensorInfo outputInfo( {5}, DataType::Float32);
+
+    CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
+
+    BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
+}
+
+BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputsMultiDimIndices)
+{
+    Graph graph;
+    armnn::TensorInfo paramsInfo({3, 2, 5}, DataType::Float32);
+    armnn::TensorInfo indicesInfo({2, 2}, DataType::Signed32);
+    armnn::TensorInfo outputInfo({2, 2, 2, 5}, DataType::Float32);
+
+    CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
+
+    BOOST_CHECK_NO_THROW(graph.InferTensorInfos());
+}
+
 BOOST_AUTO_TEST_SUITE_END()