2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include "LoadedNetwork.hpp"
8 #include "armnn/INetwork.hpp"
9 #include "armnn/IRuntime.hpp"
10 #include "armnn/Tensor.hpp"
11 #include "backends/ClContextControl.hpp"
14 #include <unordered_map>
19 class Runtime final : public IRuntime
22 /// Load a complete (optimized) network into the Runtime. Overrides IRuntime::LoadNetwork.
23 /// @param [out] networkIdOut Unique identifier for the loaded network is written to this reference.
24 /// @param [in] network Complete optimized network to load into the Runtime.
25 /// The runtime takes ownership of the network once passed in (it is received by value as an IOptimizedNetworkPtr).
26 /// @return armnn::Status
27 virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
/// Query the TensorInfo of one input binding of a loaded network. Overrides IRuntime.
/// @param [in] networkId Identifier of a network previously loaded with LoadNetwork().
/// @param [in] layerId Binding id of the network input to query.
/// @return TensorInfo describing that input.
/// NOTE(review): behaviour for an unknown networkId/layerId is not visible in this header - confirm in Runtime.cpp.
29 virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
/// Query the TensorInfo of one output binding of a loaded network. Overrides IRuntime.
/// @param [in] networkId Identifier of a network previously loaded with LoadNetwork().
/// @param [in] layerId Binding id of the network output to query.
/// @return TensorInfo describing that output.
30 virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
32 /// Evaluate a loaded network: input data is taken from inputTensors and results are filled into outputTensors.
/// Overrides IRuntime::EnqueueWorkload.
/// @param [in] networkId Identifier of a network previously loaded with LoadNetwork().
/// @param [in] inputTensors Input data for the network.
/// @param [in] outputTensors Destination for the network outputs.
///        NOTE(review): taken by const reference yet "filled", so presumably each entry refers to
///        caller-owned memory that is written through - confirm against armnn/Tensor.hpp.
/// @return armnn::Status
33 virtual Status EnqueueWorkload(NetworkId networkId,
34 const InputTensors& inputTensors,
35 const OutputTensors& outputTensors) override;
37 /// Unload a network from the Runtime. Overrides IRuntime::UnloadNetwork.
38 /// At the moment this only removes the network from m_LoadedNetworks, which owns each LoadedNetwork via std::unique_ptr.
39 /// This might need more work in the future to be Android NN compliant.
40 /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
41 /// @return armnn::Status
42 virtual Status UnloadNetwork(NetworkId networkId) override;
/// @return The device specification held by this runtime (m_DeviceSpec). The reference refers to a
///         data member, so it is only valid while this Runtime object is alive.
44 virtual const DeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
46 /// Creates a runtime for workload execution.
47 /// May throw a ClRuntimeUnavailableException if the default compute device selected in @a options requires
48 /// a CL runtime but it cannot be set up for some reason.
/// @param [in] options Creation options for the runtime.
/// NOTE(review): this single-argument constructor is not `explicit`, so CreationOptions converts implicitly
/// to Runtime; consider marking it explicit if no caller relies on that conversion.
49 Runtime(const CreationOptions& options);
/// Test hook: grants this free function access to Runtime's private members.
/// NOTE(review): judging by its name it presumably reserves capacity in m_LoadedNetworks - confirm in RuntimeTests.cpp.
54 friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // see RuntimeTests.cpp
/// Produces the identifier handed out for a network loaded through LoadNetwork().
/// NOTE(review): presumably derived from (and advances) m_NetworkIdCounter - confirm in Runtime.cpp.
56 int GenerateNetworkId();
/// Looks up the loaded network identified by @a networkId.
/// Returns a non-owning pointer; ownership stays with the std::unique_ptr stored in m_LoadedNetworks.
/// NOTE(review): behaviour when networkId is not found (nullptr vs. throw) is not visible here - confirm in Runtime.cpp.
58 LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
/// Declared mutable so it can be locked from const member functions (such as GetLoadedNetworkPtr);
/// presumably guards m_LoadedNetworks - confirm in Runtime.cpp.
60 mutable std::mutex m_Mutex;
/// Networks currently loaded into this runtime, keyed by NetworkId; the map owns each LoadedNetwork.
62 std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
/// CL context management object (see backends/ClContextControl.hpp); semantics assumed from the type name - confirm.
64 ClContextControl m_ClContextControl;
/// Counter presumably used by GenerateNetworkId() to produce new NetworkIds.
/// NOTE(review): no in-class initializer here - confirm the constructor initializes it.
66 int m_NetworkIdCounter;
/// Flag named for falling back to the CpuRef backend; presumably set from CreationOptions - confirm in Runtime.cpp.
68 bool m_UseCpuRefAsFallback;
/// Device specification returned by GetDeviceSpec().
70 DeviceSpec m_DeviceSpec;