Release 18.03
[platform/upstream/armnn.git] / src / armnn / Runtime.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "LoadedNetwork.hpp"
8 #include "armnn/INetwork.hpp"
9 #include "armnn/IRuntime.hpp"
10 #include "armnn/Tensor.hpp"
11 #include "backends/RefWorkloadFactory.hpp"
12 #include "backends/NeonWorkloadFactory.hpp"
13 #include "backends/ClWorkloadFactory.hpp"
14
15 #include <unordered_map>
16
17 namespace armnn
18 {
19
20 struct WorkloadFactories
21 {
22     std::shared_ptr<RefWorkloadFactory> m_CpuRef;
23     std::shared_ptr<NeonWorkloadFactory> m_CpuAcc;
24     std::shared_ptr<ClWorkloadFactory> m_GpuAcc;
25 };
26
27 class Runtime final : public IRuntime
28 {
29 public:
30     /// Load a complete network into the Runtime.
31     /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32     /// @param [in] network Complete network to load into the Runtime.
33     /// The runtime takes ownership of the network once passed in.
34     /// @return armnn::Status
35     virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
36
37     virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
38     virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
39
40     // Evaluate network using input in inputTensors, outputs filled into outputTensors
41     virtual Status EnqueueWorkload(NetworkId networkId,
42         const InputTensors& inputTensors,
43         const OutputTensors& outputTensors) override;
44
45     /// Unload a network from the Runtime.
46     /// At the moment this only removes the network from the m_Impl->m_Network.
47     /// This might need more work in the future to be AndroidNN compliant.
48     /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
49     /// @return armnn::Status
50     virtual Status UnloadNetwork(NetworkId networkId) override;
51
52     virtual const DeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
53
54     /// Creates a runtime for workload execution.
55     /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
56     /// it cannot be setup for some reason.
57     Runtime(const CreationOptions& options);
58
59     ~Runtime();
60
61 private:
62     friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // see RuntimeTests.cpp
63
64     int GenerateNetworkId();
65
66     std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
67
68     WorkloadFactories m_WorkloadFactories;
69
70     int m_NetworkIdCounter;
71
72     DeviceSpec m_DeviceSpec;
73 };
74
75 }