// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include "LayerFwd.hpp"

#include "backends/OutputHandler.hpp"
#include "backends/WorkloadDataCollector.hpp"
#include "backends/WorkloadInfo.hpp"
#include "InternalTypes.hpp"

#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/INetwork.hpp>

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include <boost/numeric/conversion/cast.hpp>
#include <boost/core/ignore_unused.hpp>
#include <boost/cast.hpp>
// Forward declarations.
// InputSlot stores OutputSlot pointers and Layer references before either class is defined below,
// and Layer refers to IWorkload, IWorkloadFactory and Graph only through pointers/references.
class IWorkload;
class IWorkloadFactory;
class Layer;
class Graph;
class OutputSlot;
35 class InputSlot final : public IInputSlot
38 explicit InputSlot(Layer& owner, unsigned int slotIndex)
39 : m_OwningLayer(owner)
40 , m_Connection(nullptr)
41 , m_SlotIndex(slotIndex)
46 Layer& GetOwningLayer() const { return m_OwningLayer; }
47 unsigned int GetSlotIndex() const { return m_SlotIndex; }
49 const OutputSlot* GetConnectedOutputSlot() const { return m_Connection; }
50 OutputSlot* GetConnectedOutputSlot() { return m_Connection; }
52 /// Links the slot to an output slot or breaks an existing link if passing nullptr
53 void SetConnection(OutputSlot* source)
55 if (m_Connection != nullptr && source != nullptr)
57 throw InvalidArgumentException("Tried to connect an output slot to an input slot, "
58 "but the latter already has a connection");
60 m_Connection = source;
63 // Insert single-output existing layer at this point in the graph.
64 void Insert(Layer& layer);
68 const IOutputSlot* GetConnection() const override;
69 IOutputSlot* GetConnection() override;
73 OutputSlot* m_Connection;
74 const unsigned int m_SlotIndex;
77 class OutputSlot final : public IOutputSlot
80 explicit OutputSlot(Layer& owner, OutputHandler& outputHandler)
81 : m_OwningLayer(owner)
82 , m_OutputHandler(outputHandler)
90 Layer& GetOwningLayer() const { return m_OwningLayer; }
92 const OutputHandler& GetOutputHandler() const { return m_OutputHandler; }
93 OutputHandler& GetOutputHandler() { return m_OutputHandler; }
95 int Connect(InputSlot& destination);
96 void Disconnect(InputSlot& slot);
98 const std::vector<InputSlot*>& GetConnections() const { return m_Connections; }
100 bool ValidateTensorShape(const TensorShape& shape) const;
102 // Disconnect all conections
103 void DisconnectAll();
105 /// Move all connections to another OutputSlot
106 void MoveAllConnections(OutputSlot& destination);
110 unsigned int GetNumConnections() const override { return boost::numeric_cast<unsigned int>(m_Connections.size()); }
111 const InputSlot* GetConnection(unsigned int index) const override;
112 InputSlot* GetConnection(unsigned int index) override;
114 void SetTensorInfo(const TensorInfo& tensorInfo) override;
115 const TensorInfo& GetTensorInfo() const override;
116 bool IsTensorInfoSet() const override;
118 int Connect(IInputSlot& destination) override
120 return Connect(*boost::polymorphic_downcast<InputSlot*>(&destination));
123 void Disconnect(IInputSlot& slot) override
125 return Disconnect(*boost::polymorphic_downcast<InputSlot*>(&slot));
129 void ValidateConnectionIndex(unsigned int index) const;
131 Layer& m_OwningLayer;
132 OutputHandler& m_OutputHandler;
133 std::vector<InputSlot*> m_Connections;
136 // InputSlot inlines that need OutputSlot declaration
138 inline InputSlot::~InputSlot()
140 if (m_Connection != nullptr)
142 m_Connection->Disconnect(*this);
146 inline const IOutputSlot* InputSlot::GetConnection() const { return GetConnectedOutputSlot(); }
147 inline IOutputSlot* InputSlot::GetConnection() { return GetConnectedOutputSlot(); }
151 using LayerPriority = unsigned int;
153 class Layer : public IConnectableLayer
156 /// @param name Optional name for the layer (may be nullptr)
157 Layer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name);
159 const std::string& GetNameStr() const
164 const OutputHandler& GetOutputHandler(unsigned int i = 0) const
166 return m_OutputHandlers[i];
169 OutputHandler& GetOutputHandler(unsigned int i = 0)
171 return const_cast<OutputHandler&>(const_cast<const Layer*>(this)->GetOutputHandler(i));
174 const std::vector<InputSlot>& GetInputSlots() const { return m_InputSlots; }
175 const std::vector<OutputSlot>& GetOutputSlots() const { return m_OutputSlots; }
177 // Allow non-const access to input slots, but don't expose vector (vector size is fixed at layer construction).
178 std::vector<InputSlot>::iterator BeginInputSlots() { return m_InputSlots.begin(); }
179 std::vector<InputSlot>::iterator EndInputSlots() { return m_InputSlots.end(); }
181 // Allow non-const access to output slots, but don't expose vector (vector size is fixed at layer construction).
182 std::vector<OutputSlot>::iterator BeginOutputSlots() { return m_OutputSlots.begin(); }
183 std::vector<OutputSlot>::iterator EndOutputSlots() { return m_OutputSlots.end(); }
185 // Check whether the outputs of this layer don't have any connection
186 bool IsOutputUnconnected()
188 unsigned int numConnections = 0;
190 for (auto&& output : GetOutputSlots())
192 numConnections += output.GetNumConnections();
195 return (GetNumOutputSlots() > 0) && (numConnections == 0);
199 void ResetPriority() const;
200 LayerPriority GetPriority() const;
202 LayerType GetType() const { return m_Type; }
204 DataType GetDataType() const;
206 Compute GetComputeDevice() const { return m_ComputeDevice; }
207 void SetComputeDevice(Compute device) { m_ComputeDevice = device; }
211 virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const = 0;
213 virtual void CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory);
215 /// Creates a dynamically-allocated copy of this layer
216 /// @param graph The Graph into which this Layer is being cloned
217 virtual Layer* Clone(Graph& graph) const = 0;
219 virtual void ValidateTensorShapesFromInputs() = 0;
223 const char* GetName() const override { return m_LayerName.c_str(); }
225 unsigned int GetNumInputSlots() const override { return static_cast<unsigned int>(m_InputSlots.size()); }
226 unsigned int GetNumOutputSlots() const override { return static_cast<unsigned int>(m_OutputSlots.size()); }
228 const InputSlot& GetInputSlot(unsigned int index) const override { return m_InputSlots.at(index); }
229 InputSlot& GetInputSlot(unsigned int index) override { return m_InputSlots.at(index); }
230 const OutputSlot& GetOutputSlot(unsigned int index = 0) const override { return m_OutputSlots.at(index); }
231 OutputSlot& GetOutputSlot(unsigned int index = 0) override { return m_OutputSlots.at(index); }
234 // Graph needs access to the virtual destructor
236 virtual ~Layer() = default;
238 template <typename QueueDescriptor>
239 void CollectQueueDescriptorInputs(QueueDescriptor& descriptor, WorkloadInfo& info, const Graph& graph) const
241 WorkloadDataCollector dataCollector(descriptor.m_Inputs, info.m_InputTensorInfos);
242 CollectWorkloadInputs(dataCollector, graph);
245 template <typename QueueDescriptor>
246 void CollectQueueDescriptorOutputs(QueueDescriptor& descriptor, WorkloadInfo& info, const Graph& graph) const
248 WorkloadDataCollector dataCollector(descriptor.m_Outputs, info.m_OutputTensorInfos);
249 CollectWorkloadOutputs(dataCollector, graph);
252 /// Helper function to reduce duplication in *Layer::CreateWorkload
253 template <typename QueueDescriptor>
254 WorkloadInfo PrepInfoAndDesc(QueueDescriptor& descriptor, const Graph& graph) const
257 CollectQueueDescriptorInputs(descriptor, info, graph);
258 CollectQueueDescriptorOutputs(descriptor, info, graph);
262 template <typename LayerType, typename ... Params>
263 LayerType* CloneBase(Graph& graph, Params&& ... params) const;
266 void CollectWorkloadInputs(WorkloadDataCollector& dataCollector, const Graph& graph) const;
267 void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const Graph& graph) const;
270 std::vector<OutputHandler> m_OutputHandlers;
273 const std::string m_LayerName;
275 std::vector<InputSlot> m_InputSlots;
276 std::vector<OutputSlot> m_OutputSlots;
278 const LayerType m_Type;
279 Compute m_ComputeDevice;
282 mutable LayerPriority m_Priority = 0;
283 mutable bool m_Visiting = false;
// A layer to which user-provided data can be bound (e.g. inputs, outputs).
287 class BindableLayer : public Layer
290 BindableLayer(unsigned int numInputSlots,
291 unsigned int numOutputSlots,
295 : Layer(numInputSlots, numOutputSlots, type, name)
300 LayerBindingId GetBindingId() const { return m_Id; };
303 ~BindableLayer() = default;