From: narpra01 Date: Wed, 16 Jan 2019 17:22:19 +0000 (+0000) Subject: IVGCVSW-2509 Add GatherLayer implementation X-Git-Tag: submit/tizen/20200316.035456~945 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=33f8e3b6c71070fd867809ca6934069a950081dc;p=platform%2Fupstream%2Farmnn.git IVGCVSW-2509 Add GatherLayer implementation * implementation of ValidateTensorShapesFromInputs * unit tests Change-Id: I1ed88f8ba0ea20329a259c5f36caea4b1fbeb013 --- diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index 2e5d011..d7ed4b2 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -32,6 +32,32 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const void GatherLayer::ValidateTensorShapesFromInputs() { + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo(); + + const unsigned int paramsDim = params.GetNumDimensions(); + const unsigned int indicesDim = indices.GetNumDimensions(); + const unsigned int outputDim = paramsDim - 1 + indicesDim; + + std::vector dimSizes; + + for (unsigned int i = 0; i < indicesDim; ++i) + { + dimSizes.push_back(indices.GetShape()[i]); + } + for (unsigned int i = 1; i < paramsDim; ++i) + { + dimSizes.push_back(params.GetShape()[i]); + } + + const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + + ConditionalThrowIfNotEqual( + "GatherLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", + GetOutputSlot(0).GetTensorInfo().GetShape(), + inferredShape); } } // namespace armnn diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp index 80addb4..3b07986 100644 --- a/src/armnn/test/OptimizerTests.cpp +++ b/src/armnn/test/OptimizerTests.cpp @@ -995,4 +995,59 @@ BOOST_AUTO_TEST_CASE(ResizeBilinearValidateTensorShapesFromInputsNhwc) BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); } + +void CreateGatherGraph(Graph& graph, const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo, + const armnn::TensorInfo& outputInfo) +{ + Layer* input0 = graph.AddLayer(0, "params"); + input0->GetOutputSlot().SetTensorInfo(paramsInfo); + + Layer* input1 = graph.AddLayer(1, "indices"); + input1->GetOutputSlot().SetTensorInfo(indicesInfo); + + GatherLayer* layer = graph.AddLayer("gather"); + layer->GetOutputSlot().SetTensorInfo(outputInfo); + + Layer* output = graph.AddLayer(0, "output"); + input0->GetOutputSlot().Connect(layer->GetInputSlot(0)); + input1->GetOutputSlot().Connect(layer->GetInputSlot(1)); + layer->GetOutputSlot().Connect(output->GetInputSlot(0)); +} + +BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs) +{ + Graph graph; + armnn::TensorInfo paramsInfo({10, 5}, DataType::Float32); + armnn::TensorInfo indicesInfo({3}, DataType::Signed32); + armnn::TensorInfo outputInfo({3, 5}, DataType::Float32); + + CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo); + + BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); +} + +BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputs1DParams) +{ + Graph graph; + armnn::TensorInfo paramsInfo({8}, DataType::Float32); + armnn::TensorInfo indicesInfo({5}, DataType::Signed32); + armnn::TensorInfo outputInfo( {5}, DataType::Float32); + + CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo); + + BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); +} + +BOOST_AUTO_TEST_CASE(GatherValidateTensorShapesFromInputsMultiDimIndices) +{ + Graph graph; + armnn::TensorInfo paramsInfo({3, 2, 5}, DataType::Float32); + armnn::TensorInfo indicesInfo({2, 2}, DataType::Signed32); + armnn::TensorInfo outputInfo({2, 2, 2, 5}, DataType::Float32); + + CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo); + + BOOST_CHECK_NO_THROW(graph.InferTensorInfos()); +} + BOOST_AUTO_TEST_SUITE_END()