2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
8 #include "backends/WorkloadData.hpp"
10 #include <boost/cast.hpp>
11 #include <boost/format.hpp>
12 #include <boost/log/trivial.hpp>
19 void InputSlot::Insert(Layer& layer)
21 BOOST_ASSERT(layer.GetNumOutputSlots() == 1);
23 OutputSlot* const prevSlot = GetConnectedOutputSlot();
25 if (prevSlot != nullptr)
27 // Disconnect parent from this
28 prevSlot->Disconnect(*this);
30 // Connect inserted layer to parent
31 BOOST_ASSERT(layer.GetNumInputSlots() == 1);
32 prevSlot->Connect(layer.GetInputSlot(0));
34 // Set tensor info for inserted layer
35 const TensorInfo& tensorInfo = prevSlot->GetTensorInfo();
36 layer.GetOutputHandler().SetTensorInfo(tensorInfo);
39 // Connect inserted layer to this
40 layer.GetOutputSlot(0).Connect(*this);
43 const InputSlot* OutputSlot::GetConnection(unsigned int index) const
45 ValidateConnectionIndex(index);
46 return m_Connections[index];
49 InputSlot* OutputSlot::GetConnection(unsigned int index)
51 ValidateConnectionIndex(index);
52 return m_Connections[index];
55 void OutputSlot::SetTensorInfo(const TensorInfo& tensorInfo)
57 GetOutputHandler().SetTensorInfo(tensorInfo);
60 const TensorInfo& OutputSlot::GetTensorInfo() const
62 return GetOutputHandler().GetTensorInfo();
65 bool OutputSlot::IsTensorInfoSet() const
67 return GetOutputHandler().IsTensorInfoSet();
70 bool OutputSlot::ValidateTensorShape(const TensorShape& shape) const
72 BOOST_ASSERT_MSG(IsTensorInfoSet(), "TensorInfo must be set in order to validate the shape.");
73 return shape == m_OutputHandler.GetTensorInfo().GetShape();
76 int OutputSlot::Connect(InputSlot& destination)
78 destination.SetConnection(this);
79 m_Connections.push_back(&destination);
80 return boost::numeric_cast<int>(m_Connections.size() - 1);
83 void OutputSlot::Disconnect(InputSlot& slot)
85 slot.SetConnection(nullptr);
86 m_Connections.erase(std::remove(m_Connections.begin(), m_Connections.end(), &slot), m_Connections.end());
89 void OutputSlot::DisconnectAll()
91 while (GetNumConnections() > 0)
93 InputSlot& connection = *GetConnection(0);
94 Disconnect(connection);
98 void OutputSlot::MoveAllConnections(OutputSlot& destination)
100 while (GetNumConnections() > 0)
102 InputSlot& connection = *GetConnection(0);
103 Disconnect(connection);
104 destination.Connect(connection);
108 void OutputSlot::ValidateConnectionIndex(unsigned int index) const
110 if (boost::numeric_cast<std::size_t>(index) >= m_Connections.size())
112 throw InvalidArgumentException(
113 boost::str(boost::format("GetConnection: Invalid index %1% provided") % index));
/// Generates a unique id for a new layer from a monotonically increasing counter.
/// NOTE(review): the remainder of this definition (the returned value) is elided
/// from this view — confirm against the full file. The counter is static and
/// unsynchronized, so concurrent layer creation could yield duplicate guids.
LayerGuid GenerateLayerGuid()
    //Note: Not thread safe.
    static LayerGuid newGuid=0;
/// Constructs a layer with the requested numbers of input and output slots.
/// Slots are created in index order: input slot i carries index i, and output
/// slot i is bound to output handler i.
/// @param numInputSlots  how many InputSlots to create.
/// @param numOutputSlots how many OutputSlots (and OutputHandlers) to create.
/// @param type           the layer kind — NOTE(review): not visibly consumed in
///                       this view; presumably stored by an elided member
///                       initializer. Confirm against the full file.
/// @param name           optional layer name; nullptr maps to the empty string.
Layer::Layer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name)
: m_OutputHandlers(numOutputSlots)
, m_LayerName(name ? name : "")
, m_ComputeDevice(Compute::Undefined)
, m_Guid(GenerateLayerGuid())
    m_InputSlots.reserve(numInputSlots);
    for (unsigned int i = 0; i < numInputSlots; ++i)
        m_InputSlots.emplace_back(*this, i);

    m_OutputSlots.reserve(numOutputSlots);
    for (unsigned int i = 0; i < numOutputSlots; ++i)
        m_OutputSlots.emplace_back(*this, m_OutputHandlers[i]);
146 void Layer::CollectWorkloadInputs(WorkloadDataCollector& dataCollector, const Graph& graph) const
148 for (auto&& inputSlot : GetInputSlots())
150 // The graph must be well-formed at this point
151 BOOST_ASSERT(inputSlot.GetConnection());
152 const OutputHandler& outputHandler = inputSlot.GetConnectedOutputSlot()->GetOutputHandler();
153 dataCollector.Push(outputHandler.GetData(), outputHandler.GetTensorInfo());
157 void Layer::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const Graph& graph) const
159 for (auto&& outputHandler : m_OutputHandlers)
161 outputHandler.CollectWorkloadOutputs(dataCollector);
165 void Layer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
167 for (auto&& outputHandler : m_OutputHandlers)
169 outputHandler.CreateTensorHandles(factory);
173 DataType Layer::GetDataType() const
175 if (GetNumInputSlots() > 0) // Ignore the input layer
177 return GetInputSlot(0).GetConnection()->GetTensorInfo().GetDataType();
179 return DataType::Float32;
/// Resets the cached priority state so the next GetPriority() walk recomputes it.
/// NOTE(review): the body of this definition is elided from this view — confirm
/// its contents against the full file.
void Layer::ResetPriority() const
/// Lazily computes and caches this layer's scheduling priority.
/// Input layers take the lowest representable priority and Output layers the
/// highest; every other layer is one greater than the maximum priority of the
/// layers feeding its input slots (computed recursively via std::accumulate).
/// NOTE(review): parts of this definition (a visiting/guard condition around the
/// circular-dependency throw and the final return of m_Priority) are elided
/// from this view — confirm against the full file.
LayerPriority Layer::GetPriority() const
    constexpr LayerPriority inputPrio = std::numeric_limits<LayerPriority>::lowest();
    constexpr LayerPriority outputPrio = std::numeric_limits<LayerPriority>::max();

    if (GetType() == LayerType::Input)
        m_Priority = inputPrio;
    else if (GetType() == LayerType::Output)
        m_Priority = outputPrio;
    else if (m_Priority == 0)
        // Encountering a node whose priority is still unresolved mid-walk
        // indicates a cycle in the graph.
        throw GraphValidationException("Graph has circular dependencies: cannot walk");

        // Fold over the input slots, keeping the highest parent priority seen.
        auto maxPrio = [](const LayerPriority prio, const InputSlot& slot) -> LayerPriority
            const Layer& input = slot.GetConnectedOutputSlot()->GetOwningLayer();
            return std::max(prio, input.GetPriority());

        LayerPriority parentPrio = std::accumulate(GetInputSlots().cbegin(), GetInputSlots().cend(), 0U, maxPrio);

        // Saturation guard: a parent already at outputPrio would overflow on +1.
        if (parentPrio >= outputPrio)
            throw GraphValidationException("Graph has too many edges");

        m_Priority = parentPrio + 1U;