IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / backendsCommon / OutputHandler.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "ITensorHandle.hpp"
8
9 #include <armnn/Descriptors.hpp>
10 #include <armnn/INetwork.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/Types.hpp>
13
14 #include <backendsCommon/WorkloadDataFwd.hpp>
15
16 #include <memory>
17 #include <set>
18 #include <string>
19 #include <vector>
20
21 #include <boost/assert.hpp>
22
23 namespace armnn
24 {
25
26 class ITensorHandle;
27 class IWorkloadFactory;
28 class OutputSlot;
29 class WorkloadDataCollector;
30
31 class OutputHandler
32 {
33 public:
34     /// @brief - Sets the TensorInfo used by this output handler.
35     /// @param tensorInfo - TensorInfo for the output.
36     void SetTensorInfo(const TensorInfo& tensorInfo);
37
38     /// @brief - Creates tensor handlers used by the intermediate tensors. Does not allocate memory.
39     /// @param factory - Factory to be used for handler creation.
40     void CreateTensorHandles(const IWorkloadFactory& factory);
41
42     /// @brief - Creates tensor handlers used by the intermediate tensors. Does not allocate memory.
43     /// @param factory - Factory to be used for handler creation.
44     /// @param dataLayout - Data Layout to be used for handler creation.
45     void CreateTensorHandles(const IWorkloadFactory& factory, DataLayout dataLayout);
46
47     /// @brief - Gets the matching TensorInfo for the output.
48     /// @return - References to the output TensorInfo.
49     const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
50
51     /// @brief - Gets the allocated tensor memory.
52     /// @return - Pointer to the tensor memory.
53     ITensorHandle* GetData() const { return m_TensorHandle.get(); }
54
55     /// Fill the outputs for a given queue descriptor.
56     void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const;
57
58     void SetData(std::unique_ptr<ITensorHandle> data) { m_TensorHandle = std::move(data); }
59
60     /// @brief Returns true if SetTensorInfo() has been called at least once on this.
61     bool IsTensorInfoSet() const { return m_bTensorInfoSet; }
62 private:
63     std::unique_ptr<ITensorHandle> m_TensorHandle;
64     TensorInfo m_TensorInfo;
65     bool m_bTensorInfoSet = false;
66 };
67
68 } //namespace armnn