2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "ITensorHandle.hpp"
9 #include <armnn/Descriptors.hpp>
10 #include <armnn/INetwork.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/Types.hpp>
14 #include <backendsCommon/WorkloadDataFwd.hpp>
21 #include <boost/assert.hpp>
27 class IWorkloadFactory;
29 class WorkloadDataCollector;
/// @brief - Sets the TensorInfo used by this output handler.
/// @param tensorInfo - TensorInfo describing the output tensor.
/// Once this has been called, IsTensorInfoSet() reports true.
void SetTensorInfo(const TensorInfo& tensorInfo);
/// @brief - Creates the tensor handles used by the intermediate tensors. Does not allocate memory.
/// @param factory - Factory to be used for handle creation.
/// NOTE(review): presumably the created handle is stored in m_TensorHandle and exposed via GetData() — confirm in the .cpp.
void CreateTensorHandles(const IWorkloadFactory& factory);
/// @brief - Creates the tensor handles used by the intermediate tensors, using an explicit data layout.
///          Does not allocate memory.
/// @param factory - Factory to be used for handle creation.
/// @param dataLayout - Data layout to be used for handle creation.
/// NOTE(review): presumably the created handle is stored in m_TensorHandle and exposed via GetData() — confirm in the .cpp.
void CreateTensorHandles(const IWorkloadFactory& factory, DataLayout dataLayout);
47 /// @brief - Gets the matching TensorInfo for the output.
48 /// @return - References to the output TensorInfo.
49 const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
51 /// @brief - Gets the allocated tensor memory.
52 /// @return - Pointer to the tensor memory.
53 ITensorHandle* GetData() const { return m_TensorHandle.get(); }
/// @brief - Fills the outputs for a given queue descriptor.
/// @param dataCollector - Collector through which the outputs are filled.
void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const;
58 void SetData(std::unique_ptr<ITensorHandle> data) { m_TensorHandle = std::move(data); }
60 /// @brief Returns true if SetTensorInfo() has been called at least once on this.
61 bool IsTensorInfoSet() const { return m_bTensorInfoSet; }
63 std::unique_ptr<ITensorHandle> m_TensorHandle;
64 TensorInfo m_TensorInfo;
65 bool m_bTensorInfoSet = false;