src/armnn/LoadedNetwork.hpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "Network.hpp"
#include "LayerFwd.hpp"
#include "Profiling.hpp"

#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/TensorHandleFactoryRegistry.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <ProfilingService.hpp>
#include <TimelineUtilityMethods.hpp>

#include <mutex>
#include <unordered_map>

namespace cl
{
    class Context;
    class CommandQueue;
    class Device;
}

namespace armnn
{

class LoadedNetwork
{
public:
    using WorkloadQueue = std::vector<std::unique_ptr<IWorkload>>;
    ~LoadedNetwork() { FreeWorkingMemory(); }

    TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
    TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
    Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
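
    // A minimal usage sketch (illustrative only: the binding id 0 and the float
    // buffers are assumptions, not guaranteed by this header). InputTensors and
    // OutputTensors are vectors of {LayerBindingId, ConstTensor} and
    // {LayerBindingId, Tensor} pairs respectively:
    //
    //     TensorInfo inputInfo  = loadedNetwork->GetInputTensorInfo(0);
    //     TensorInfo outputInfo = loadedNetwork->GetOutputTensorInfo(0);
    //     std::vector<float> inData(inputInfo.GetNumElements());
    //     std::vector<float> outData(outputInfo.GetNumElements());
    //     InputTensors  inputs  = { { 0, ConstTensor(inputInfo, inData.data()) } };
    //     OutputTensors outputs = { { 0, Tensor(outputInfo, outData.data()) } };
    //     Status result = loadedNetwork->EnqueueWorkload(inputs, outputs);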

    static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
                                                            std::string& errorMessage,
                                                            const INetworkProperties& networkProperties,
                                                            profiling::ProfilingService& profilingService);
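
    // A minimal creation sketch, assuming an already-optimized network (in
    // practice Runtime::LoadNetwork drives this; optimizedNet and
    // profilingService below are assumed to exist in the caller's scope):
    //
    //     std::string errorMessage;
    //     std::unique_ptr<LoadedNetwork> loadedNetwork =
    //         LoadedNetwork::MakeLoadedNetwork(std::move(optimizedNet),
    //                                          errorMessage,
    //                                          INetworkProperties(),
    //                                          profilingService);
    //     if (!loadedNetwork) { /* errorMessage describes what went wrong */ }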

    // NOTE: we return by reference because the purpose of this method is only to
    // provide access to the private m_Profiler; in theory we should not need to
    // increment the shared_ptr's reference count.
    const std::shared_ptr<Profiler>& GetProfiler() const { return m_Profiler; }
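
    // Illustrative: because GetProfiler() returns a reference, a caller can
    // observe the profiler without copying the shared_ptr, e.g.
    //     Profiler* profiler = loadedNetwork->GetProfiler().get();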

    void FreeWorkingMemory();

    void RegisterDebugCallback(const DebugCallbackFunction& func);

    void SendNetworkStructure();

private:
    void AllocateWorkingMemory();

    LoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
                  const INetworkProperties& networkProperties,
                  profiling::ProfilingService& profilingService);

    void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);

    void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);

    bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils,
                 profiling::ProfilingGuid inferenceGuid);

    const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;

    using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;

    using WorkloadFactoryWithMemoryManager =
        std::pair<IBackendInternal::IWorkloadFactoryPtr, IBackendInternal::IMemoryManagerSharedPtr>;

    using WorkloadFactoryMap = std::unordered_map<BackendId, WorkloadFactoryWithMemoryManager>;

    BackendPtrMap       m_Backends;
    WorkloadFactoryMap  m_WorkloadFactories;

    std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
    WorkloadQueue m_InputQueue;
    WorkloadQueue m_WorkloadQueue;
    WorkloadQueue m_OutputQueue;
    std::shared_ptr<Profiler> m_Profiler;

    mutable std::mutex m_WorkingMemMutex;

    bool m_IsWorkingMemAllocated = false;
    bool m_IsImportEnabled = false;
    bool m_IsExportEnabled = false;

    TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;

    profiling::ProfilingService& m_ProfilingService;
};

} // namespace armnn