2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
8 #include "backends/WorkloadData.hpp"
10 #include <boost/cast.hpp>
11 #include <boost/format.hpp>
12 #include <boost/log/trivial.hpp>
19 void InputSlot::Insert(Layer& layer)
21 BOOST_ASSERT(layer.GetNumInputSlots() <= 1);
22 BOOST_ASSERT(layer.GetNumOutputSlots() == 1);
24 OutputSlot* const prevSlot = GetConnectedOutputSlot();
26 if (prevSlot != nullptr)
28 // Disconnect parent from this
29 prevSlot->Disconnect(*this);
31 // Connect inserted layer to parent
32 BOOST_ASSERT(layer.GetNumInputSlots() == 1);
33 prevSlot->Connect(layer.GetInputSlot(0));
35 // Set tensor info for inserted layer
36 const TensorInfo& tensorInfo = prevSlot->GetTensorInfo();
37 layer.GetOutputHandler().SetTensorInfo(tensorInfo);
40 // Connect inserted layer to this
41 layer.GetOutputSlot(0).Connect(*this);
44 const InputSlot* OutputSlot::GetConnection(unsigned int index) const
46 ValidateConnectionIndex(index);
47 return m_Connections[index];
50 InputSlot* OutputSlot::GetConnection(unsigned int index)
52 ValidateConnectionIndex(index);
53 return m_Connections[index];
56 void OutputSlot::SetTensorInfo(const TensorInfo& tensorInfo)
58 GetOutputHandler().SetTensorInfo(tensorInfo);
61 const TensorInfo& OutputSlot::GetTensorInfo() const
63 return GetOutputHandler().GetTensorInfo();
66 bool OutputSlot::IsTensorInfoSet() const
68 return GetOutputHandler().IsTensorInfoSet();
71 bool OutputSlot::ValidateTensorShape(const TensorShape& shape) const
73 BOOST_ASSERT_MSG(IsTensorInfoSet(), "TensorInfo must be set in order to validate the shape.");
74 return shape == m_OutputHandler.GetTensorInfo().GetShape();
77 int OutputSlot::Connect(InputSlot& destination)
79 destination.SetConnection(this);
80 m_Connections.push_back(&destination);
81 return boost::numeric_cast<int>(m_Connections.size() - 1);
84 void OutputSlot::Disconnect(InputSlot& slot)
86 slot.SetConnection(nullptr);
87 m_Connections.erase(std::remove(m_Connections.begin(), m_Connections.end(), &slot), m_Connections.end());
90 void OutputSlot::DisconnectAll()
92 while (GetNumConnections() > 0)
94 InputSlot& connection = *GetConnection(0);
95 Disconnect(connection);
99 void OutputSlot::MoveAllConnections(OutputSlot& destination)
101 while (GetNumConnections() > 0)
103 InputSlot& connection = *GetConnection(0);
104 Disconnect(connection);
105 destination.Connect(connection);
109 void OutputSlot::ValidateConnectionIndex(unsigned int index) const
111 if (boost::numeric_cast<std::size_t>(index) >= m_Connections.size())
113 throw InvalidArgumentException(
114 boost::str(boost::format("GetConnection: Invalid index %1% provided") % index));
// Constructs a layer with the requested numbers of input and output
// slots. Every output slot is bound to the output handler at the same
// index, and the compute device defaults to Undefined.
// NOTE(review): the `type` parameter is not consumed by any visible
// member initializer — an initializer line appears to have been elided
// from this chunk; confirm against the full source.
Layer::Layer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name)
: m_OutputHandlers(numOutputSlots)
, m_LayerName(name ? name : "")  // guard against a null name pointer
, m_ComputeDevice(Compute::Undefined)
m_InputSlots.reserve(numInputSlots);
// Each input slot knows its owning layer and its index within it.
for (unsigned int i = 0; i < numInputSlots; ++i)
m_InputSlots.emplace_back(*this, i);
m_OutputSlots.reserve(numOutputSlots);
// Each output slot is paired with the handler that stores its tensor data.
for (unsigned int i = 0; i < numOutputSlots; ++i)
m_OutputSlots.emplace_back(*this, m_OutputHandlers[i]);
137 void Layer::CollectWorkloadInputs(WorkloadDataCollector& dataCollector, const Graph& graph) const
139 for (auto&& inputSlot : GetInputSlots())
141 // The graph must be well-formed at this point
142 BOOST_ASSERT(inputSlot.GetConnection());
143 const OutputHandler& outputHandler = inputSlot.GetConnectedOutputSlot()->GetOutputHandler();
144 dataCollector.Push(outputHandler.GetData(), outputHandler.GetTensorInfo());
148 void Layer::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const Graph& graph) const
150 for (auto&& outputHandler : m_OutputHandlers)
152 outputHandler.CollectWorkloadOutputs(dataCollector);
156 void Layer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
158 for (auto&& outputHandler : m_OutputHandlers)
160 outputHandler.CreateTensorHandles(factory);
164 DataType Layer::GetDataType() const
166 if (GetNumInputSlots() > 0) // Ignore the input layer
168 return GetInputSlot(0).GetConnection()->GetTensorInfo().GetDataType();
170 return DataType::Float32;
// Presumably resets the mutable priority cache that GetPriority()
// populates, forcing a recomputation on the next call — the function
// body is not visible in this chunk; confirm against the full source.
void Layer::ResetPriority() const
// Returns this layer's scheduling priority, computed lazily and cached
// in the mutable m_Priority member (hence const). Input layers are
// pinned to the lowest possible priority and Output layers to the
// highest; interior layers take 1 + the maximum priority among their
// parents, so priorities increase along the data-flow direction.
// NOTE(review): lines appear to have been elided from this chunk — a
// re-entrancy/visiting guard likely sits between the m_Priority == 0
// test and the throw below, and the final `return m_Priority;` is not
// visible here; confirm against the full source.
LayerPriority Layer::GetPriority() const
constexpr LayerPriority inputPrio = std::numeric_limits<LayerPriority>::lowest();
constexpr LayerPriority outputPrio = std::numeric_limits<LayerPriority>::max();
if (GetType() == LayerType::Input)
m_Priority = inputPrio;
else if (GetType() == LayerType::Output)
m_Priority = outputPrio;
else if (m_Priority == 0)  // 0 means "not computed yet"
// Raised when the walk detects a cycle (circular dependency).
throw GraphValidationException("Graph has circular dependencies: cannot walk");
// Folds each parent's priority into a running maximum.
auto maxPrio = [](const LayerPriority prio, const InputSlot& slot) -> LayerPriority
const Layer& input = slot.GetConnectedOutputSlot()->GetOwningLayer();
return std::max(prio, input.GetPriority());
LayerPriority parentPrio = std::accumulate(GetInputSlots().cbegin(), GetInputSlots().cend(), 0U, maxPrio);
// A parent at outputPrio would overflow the +1 below.
if (parentPrio >= outputPrio)
throw GraphValidationException("Graph has too many edges");
m_Priority = parentPrio + 1U;