Release 18.08
[platform/upstream/armnn.git] / src / armnn / layers / SplitterLayer.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "SplitterLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/TypesUtils.hpp>
10 #include <backends/WorkloadData.hpp>
11 #include <backends/WorkloadFactory.hpp>
12
13 namespace armnn
14 {
15
// Constructs a splitter layer with a single input slot and one output slot
// per view described by the descriptor (GetNumViews()).
// param: descriptor holding the number of views plus each view's origin/sizes.
// name:  optional layer name forwarded to the base Layer (may be null).
SplitterLayer::SplitterLayer(const ViewsDescriptor& param, const char* name)
    : LayerWithParameters(1, param.GetNumViews(), LayerType::Splitter, param, name)
{
}
20
21 std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const
22 {
23     SplitterQueueDescriptor descriptor;
24
25     // Copies the window origins to the descriptor.
26     for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
27     {
28         descriptor.m_ViewOrigins.emplace_back(
29             std::vector<unsigned int>(m_Param.GetViewOrigin(i), m_Param.GetViewOrigin(i) + m_Param.GetNumDimensions()));
30     }
31
32     return factory.CreateSplitter(descriptor, PrepInfoAndDesc(descriptor, graph));
33 }
34
35 void SplitterLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
36 {
37     //If sub tensors are supported than all the "splitter" need to do is to
38     //set the outputs to be appropriate sub tensors of the input.
39     if (factory.SupportsSubTensors())
40     {
41         const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
42
43         ITensorHandle* inputData = outputHandler.GetData();
44         //Creates the outputs as subtensors of the input.
45         for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
46         {
47             m_OutputHandlers[i].SetData(factory.CreateSubTensorHandle(*inputData,
48                                                                       m_OutputHandlers[i].GetTensorInfo().GetShape(),
49                                                                       m_Param.GetViewOrigin(i)));
50         }
51     }
52     else
53     {
54         for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
55         {
56             m_OutputHandlers[i].CreateTensorHandles(factory);
57         }
58     }
59 }
60
// Creates a copy of this layer in the given graph, reusing the same
// descriptor (m_Param) and layer name. Returns the newly created layer.
SplitterLayer* SplitterLayer::Clone(Graph& graph) const
{
    return CloneBase<SplitterLayer>(graph, m_Param, GetName());
}
65
66 std::vector<TensorShape> SplitterLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
67 {
68     BOOST_ASSERT(inputShapes.size() ==  m_Param.GetNumViews());
69     std::vector<TensorShape> outShapes;
70     //Output shapes must match View shapes.
71     for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
72     {
73         const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
74         outShapes.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
75     }
76     return outShapes;
77 }
78
79 void SplitterLayer::ValidateTensorShapesFromInputs()
80 {
81     std::vector<TensorShape> views;
82     for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
83     {
84         const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
85         views.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
86     }
87
88     auto inferredShapes = InferOutputShapes(views);
89
90     BOOST_ASSERT(inferredShapes.size() == m_Param.GetNumViews());
91
92     for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
93     {
94         ConditionalThrowIfNotEqual<LayerValidationException>(
95             "SplitterLayer: View sizes must match output tensor shapes.",
96             GetOutputSlot(viewIdx).GetTensorInfo().GetShape(),
97             inferredShapes[viewIdx]);
98     }
99 }
100
101 } // namespace armnn