2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "LoadedNetwork.hpp"
8 #include "DeviceSpec.hpp"
9 #include <armnn/INetwork.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/BackendId.hpp>
15 #include <unordered_map>
/// Concrete implementation of the IRuntime interface; declared final, so it cannot be derived from.
/// Networks that have been loaded are kept in m_LoadedNetworks, keyed by NetworkId.
20 class Runtime final : public IRuntime
23 /// Loads a complete network into the Runtime.
/// This overload has no error-message parameter; the overload taking errorMessage reports one to the caller.
24 /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
25 /// @param [in] network - Complete network to load into the Runtime.
26 /// The runtime takes ownership of the network once passed in.
27 /// @return armnn::Status
28 virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
30 /// Loads a complete network into the Runtime.
31 /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32 /// @param [in] network Complete network to load into the Runtime.
33 /// The runtime takes ownership of the network once passed in.
34 /// @param [out] errorMessage Error message if there were any errors.
35 /// @return armnn::Status
36 virtual Status LoadNetwork(NetworkId& networkIdOut,
37 IOptimizedNetworkPtr network,
38 std::string & errorMessage) override;
/// Returns the TensorInfo for the input identified by layerId within the network networkId.
/// @param networkId Identifier of the network to query.
/// @param layerId Binding id of the input to query.
/// @return The TensorInfo, by value.
/// NOTE(review): behaviour for an unknown networkId or layerId is not visible in this header - confirm in the implementation.
40 virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
/// Returns the TensorInfo for the output identified by layerId within the network networkId.
/// @param networkId Identifier of the network to query.
/// @param layerId Binding id of the output to query.
/// @return The TensorInfo, by value.
/// NOTE(review): behaviour for an unknown networkId or layerId is not visible in this header - confirm in the implementation.
41 virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
43 /// Evaluates network using input in inputTensors, outputs filled into outputTensors.
/// @param [in] networkId Identifier of the network to evaluate.
/// @param [in] inputTensors Input tensors the network is evaluated with.
/// @param [in] outputTensors Tensors that receive the outputs (the container itself is passed by const reference).
/// @return armnn::Status
44 virtual Status EnqueueWorkload(NetworkId networkId,
45 const InputTensors& inputTensors,
46 const OutputTensors& outputTensors) override;
48 /// Unloads a network from the Runtime.
49 /// At the moment this only removes the network from the runtime's set of loaded networks.
/// NOTE(review): this comment previously named m_Impl->m_Network, which this class does not declare;
/// m_LoadedNetworks is the only network container visible here - confirm against the implementation.
50 /// This might need more work in the future to be AndroidNN compliant.
51 /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
52 /// @return armnn::Status
53 virtual Status UnloadNetwork(NetworkId networkId) override;
/// Returns the device specification held by this runtime.
/// @return A const reference to the m_DeviceSpec member.
55 virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
57 /// Gets the profiler corresponding to the given network id.
58 /// @param networkId The id of the network for which to get the profiler.
59 /// @return A shared pointer to the requested profiler, or nullptr if not found.
60 virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
62 /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
63 /// @param networkId The id of the network to register the callback with.
64 /// @param func Callback function to pass to the debug layer.
65 virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;
67 /// Creates a runtime for workload execution.
/// @param options Creation options for the runtime.
68 /// May throw a ClRuntimeUnavailableException if the requested configuration requires a CL runtime but
69 /// it cannot be set up for some reason.
/// NOTE(review): this comment previously referred to "@a defaultComputeDevice", which is not a parameter of
/// this constructor - confirm which part of options triggers the CL requirement.
/// NOTE(review): this single-argument constructor is not marked explicit, so a CreationOptions converts implicitly.
70 Runtime(const CreationOptions& options);
/// Grants the named free function (a test helper, per the trailing comment) access to this class's non-public members.
75 friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
/// NOTE(review): presumably produces the next network id from m_NetworkIdCounter - confirm in the implementation.
77 int GenerateNetworkId();
/// Looks up a loaded network by its id.
/// NOTE(review): the returned raw pointer is presumably non-owning (m_LoadedNetworks stores unique_ptr) and
/// null when the id is unknown - confirm in the implementation.
79 LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
/// Invokes f on the loaded network registered under networkId while m_Mutex is held.
/// @tparam Func Callable type; it is invoked with a single LoadedNetwork* argument.
/// @param networkId Key looked up in m_LoadedNetworks.
/// @param f Callable to invoke (taken by value). It is not invoked when networkId is absent from m_LoadedNetworks.
81 template<typename Func>
82 void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
84 std::lock_guard<std::mutex> lockGuard(m_Mutex); // Released when lockGuard is destroyed, so the lookup and the call to f run under the lock.
85 auto iter = m_LoadedNetworks.find(networkId);
86 if (iter != m_LoadedNetworks.end())
88 f(iter->second.get()); // Hands f the raw pointer; ownership stays with the unique_ptr stored in the map.
/// Locked by LoadedNetworkFuncSafe before it accesses m_LoadedNetworks; mutable, so it can also be locked through a const Runtime.
92 mutable std::mutex m_Mutex;
/// Networks loaded into this runtime, owned through unique_ptr and keyed by NetworkId.
94 std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
/// Backend context objects, keyed by BackendId.
95 std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;
/// NOTE(review): presumably the counter behind GenerateNetworkId() - confirm in the implementation.
97 int m_NetworkIdCounter;
/// Device specification returned by GetDeviceSpec().
99 DeviceSpec m_DeviceSpec;