2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include "LoadedNetwork.hpp"
8 #include "armnn/INetwork.hpp"
9 #include "armnn/IRuntime.hpp"
10 #include "armnn/Tensor.hpp"
11 #include "backends/RefWorkloadFactory.hpp"
12 #include "backends/NeonWorkloadFactory.hpp"
13 #include "backends/ClWorkloadFactory.hpp"
15 #include <unordered_map>
/// Bundle of the workload factories held by the runtime, one per compute backend
/// (reference CPU, Neon CPU acceleration, CL GPU acceleration).
/// Each factory is held through a std::shared_ptr, i.e. ownership may be shared.
20 struct WorkloadFactories
22 std::shared_ptr<RefWorkloadFactory> m_CpuRef; // CpuRef backend factory
23 std::shared_ptr<NeonWorkloadFactory> m_CpuAcc; // CpuAcc (Neon) backend factory
24 std::shared_ptr<ClWorkloadFactory> m_GpuAcc; // GpuAcc (CL) backend factory
/// Concrete implementation of the IRuntime interface (declared final, so it cannot be
/// derived from). Holds the networks loaded into it (m_LoadedNetworks) together with one
/// workload factory per compute backend (m_WorkloadFactories).
27 class Runtime final : public IRuntime
30 /// Load a complete network into the Runtime.
31 /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32 /// @param [in] network Complete network to load into the Runtime.
33 /// The runtime takes ownership of the network once passed in.
34 /// @return armnn::Status
35 virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
37 virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
38 virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
40 // Evaluate network using input in inputTensors, outputs filled into outputTensors
41 virtual Status EnqueueWorkload(NetworkId networkId,
42 const InputTensors& inputTensors,
43 const OutputTensors& outputTensors) override;
45 /// Unload a network from the Runtime.
46 /// At the moment this only removes the network from the m_Impl->m_Network.
47 /// This might need more work in the future to be AndroidNN compliant.
48 /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
49 /// @return armnn::Status
50 virtual Status UnloadNetwork(NetworkId networkId) override;
52 virtual const DeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
54 /// Creates a runtime for workload execution.
55 /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
56 /// it cannot be setup for some reason.
57 Runtime(const CreationOptions& options);
/// Grants RuntimeLoadedNetworksReserve access to Runtime's non-public members
/// (used from RuntimeTests.cpp, per the trailing comment).
62 friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // see RuntimeTests.cpp
/// Generates an identifier for a network being loaded.
/// NOTE(review): presumably advances m_NetworkIdCounter — confirm in the implementation file.
64 int GenerateNetworkId();
/// Networks currently loaded into this Runtime, keyed by NetworkId and owned via unique_ptr.
66 std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
/// Workload factories, one per compute backend.
68 WorkloadFactories m_WorkloadFactories;
/// Counter backing network-id generation (assumed from its name; see GenerateNetworkId).
70 int m_NetworkIdCounter;
/// Device specification returned by GetDeviceSpec().
72 DeviceSpec m_DeviceSpec;