Release 18.08
[platform/upstream/armnn.git] / src / armnn / Runtime.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "LoadedNetwork.hpp"
8 #include "DeviceSpec.hpp"
9 #include "armnn/INetwork.hpp"
10 #include "armnn/IRuntime.hpp"
11 #include "armnn/Tensor.hpp"
12 #include "backends/ClContextControl.hpp"
13
14 #include <mutex>
15 #include <unordered_map>
16
17 namespace armnn
18 {
19
20 class Runtime final : public IRuntime
21 {
22 public:
23     /// Loads a complete network into the Runtime.
24     /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
25     /// @param [in] network - Complete network to load into the Runtime.
26     /// The runtime takes ownership of the network once passed in.
27     /// @return armnn::Status
28     virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
29
30     /// Load a complete network into the IRuntime.
31     /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32     /// @param [in] network Complete network to load into the IRuntime.
33     /// @param [out] errorMessage Error message if there were any errors.
34     /// The runtime takes ownership of the network once passed in.
35     /// @return armnn::Status
36     virtual Status LoadNetwork(NetworkId& networkIdOut,
37                                IOptimizedNetworkPtr network,
38                                std::string & errorMessage) override;
39
40     virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
41     virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
42
43     // Evaluates network using input in inputTensors, outputs filled into outputTensors.
44     virtual Status EnqueueWorkload(NetworkId networkId,
45         const InputTensors& inputTensors,
46         const OutputTensors& outputTensors) override;
47
48     /// Unloads a network from the Runtime.
49     /// At the moment this only removes the network from the m_Impl->m_Network.
50     /// This might need more work in the future to be AndroidNN compliant.
51     /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
52     /// @return armnn::Status
53     virtual Status UnloadNetwork(NetworkId networkId) override;
54
55     virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
56
57     /// Gets the profiler corresponding to the given network id.
58     /// @param networkId The id of the network for which to get the profile.
59     /// @return A pointer to the requested profiler, or nullptr if not found.
60     virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
61
62     /// Creates a runtime for workload execution.
63     /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
64     /// it cannot be setup for some reason.
65     Runtime(const CreationOptions& options);
66
67     ~Runtime();
68
69 private:
70     friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
71
72     int GenerateNetworkId();
73
74     LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
75
76     mutable std::mutex m_Mutex;
77
78     std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
79
80     ClContextControl m_ClContextControl;
81
82     int m_NetworkIdCounter;
83
84     DeviceSpec m_DeviceSpec;
85 };
86
87 }