10383bc97032172a6ec45413383fdc99d71b7709
[platform/upstream/armnn.git] / src / armnn / Runtime.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LoadedNetwork.hpp"
8 #include "DeviceSpec.hpp"
9 #include <armnn/INetwork.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/BackendId.hpp>
13
14 #include <mutex>
15 #include <unordered_map>
16
17 namespace armnn
18 {
19
20 class Runtime final : public IRuntime
21 {
22 public:
23     /// Loads a complete network into the Runtime.
24     /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
25     /// @param [in] network - Complete network to load into the Runtime.
26     /// The runtime takes ownership of the network once passed in.
27     /// @return armnn::Status
28     virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
29
30     /// Load a complete network into the IRuntime.
31     /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32     /// @param [in] network Complete network to load into the IRuntime.
33     /// @param [out] errorMessage Error message if there were any errors.
34     /// The runtime takes ownership of the network once passed in.
35     /// @return armnn::Status
36     virtual Status LoadNetwork(NetworkId& networkIdOut,
37                                IOptimizedNetworkPtr network,
38                                std::string & errorMessage) override;
39
40     virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
41     virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
42
43     // Evaluates network using input in inputTensors, outputs filled into outputTensors.
44     virtual Status EnqueueWorkload(NetworkId networkId,
45         const InputTensors& inputTensors,
46         const OutputTensors& outputTensors) override;
47
48     /// Unloads a network from the Runtime.
49     /// At the moment this only removes the network from the m_Impl->m_Network.
50     /// This might need more work in the future to be AndroidNN compliant.
51     /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
52     /// @return armnn::Status
53     virtual Status UnloadNetwork(NetworkId networkId) override;
54
55     virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
56
57     /// Gets the profiler corresponding to the given network id.
58     /// @param networkId The id of the network for which to get the profile.
59     /// @return A pointer to the requested profiler, or nullptr if not found.
60     virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
61
62     /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
63     /// @param networkId The id of the network to register the callback.
64     /// @param func callback function to pass to the debug layer.
65     virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;
66
67     /// Creates a runtime for workload execution.
68     /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
69     /// it cannot be setup for some reason.
70     Runtime(const CreationOptions& options);
71
72     ~Runtime();
73
74 private:
75     friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
76
77     int GenerateNetworkId();
78
79     LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
80
81     template<typename Func>
82     void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
83     {
84         std::lock_guard<std::mutex> lockGuard(m_Mutex);
85         auto iter = m_LoadedNetworks.find(networkId);
86         if (iter != m_LoadedNetworks.end())
87         {
88             f(iter->second.get());
89         }
90     }
91
92     mutable std::mutex m_Mutex;
93
94     std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
95     std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;
96
97     int m_NetworkIdCounter;
98
99     DeviceSpec m_DeviceSpec;
100 };
101
102 }