IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / armnn / Runtime.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LoadedNetwork.hpp"
8 #include "DeviceSpec.hpp"
9 #include <armnn/INetwork.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/BackendId.hpp>
13 #include <backendsCommon/IBackendContext.hpp>
14
15 #include <mutex>
16 #include <unordered_map>
17
18 namespace armnn
19 {
20
21 class Runtime final : public IRuntime
22 {
23 public:
24     /// Loads a complete network into the Runtime.
25     /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
26     /// @param [in] network - Complete network to load into the Runtime.
27     /// The runtime takes ownership of the network once passed in.
28     /// @return armnn::Status
29     virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
30
31     /// Load a complete network into the IRuntime.
32     /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
33     /// @param [in] network Complete network to load into the IRuntime.
34     /// @param [out] errorMessage Error message if there were any errors.
35     /// The runtime takes ownership of the network once passed in.
36     /// @return armnn::Status
37     virtual Status LoadNetwork(NetworkId& networkIdOut,
38                                IOptimizedNetworkPtr network,
39                                std::string & errorMessage) override;
40
41     virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
42     virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
43
44     // Evaluates network using input in inputTensors, outputs filled into outputTensors.
45     virtual Status EnqueueWorkload(NetworkId networkId,
46         const InputTensors& inputTensors,
47         const OutputTensors& outputTensors) override;
48
49     /// Unloads a network from the Runtime.
50     /// At the moment this only removes the network from the m_Impl->m_Network.
51     /// This might need more work in the future to be AndroidNN compliant.
52     /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
53     /// @return armnn::Status
54     virtual Status UnloadNetwork(NetworkId networkId) override;
55
56     virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
57
58     /// Gets the profiler corresponding to the given network id.
59     /// @param networkId The id of the network for which to get the profile.
60     /// @return A pointer to the requested profiler, or nullptr if not found.
61     virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
62
63     /// Creates a runtime for workload execution.
64     /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
65     /// it cannot be setup for some reason.
66     Runtime(const CreationOptions& options);
67
68     ~Runtime();
69
70 private:
71     friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
72
73     int GenerateNetworkId();
74
75     LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
76
77     template<typename Func>
78     void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
79     {
80         std::lock_guard<std::mutex> lockGuard(m_Mutex);
81         auto iter = m_LoadedNetworks.find(networkId);
82         if (iter != m_LoadedNetworks.end())
83         {
84             f(iter->second.get());
85         }
86     }
87
88     mutable std::mutex m_Mutex;
89     std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
90     CreationOptions m_Options;
91     int m_NetworkIdCounter;
92     DeviceSpec m_DeviceSpec;
93
94     using BackendContextMap = std::unordered_map<BackendId, IBackendContextUniquePtr>;
95     BackendContextMap m_BackendContexts;
96 };
97
98 }