From 6e9482013f41725ccca0767c0c5db9b29f77d981 Mon Sep 17 00:00:00 2001 From: Nattapat Chaimanowong Date: Fri, 22 Mar 2019 14:01:46 +0000 Subject: [PATCH] IVGCVSW-2865 Extend IRuntime to add a new method RegisterDebugCallback(...) * Made changes to LoadedNetwork and IWorkload to pass on the registered callback function Change-Id: I6ea10f2a299d6de8bf681c8ff36d3fbed1d6d887 Signed-off-by: Nattapat Chaimanowong --- include/armnn/IRuntime.hpp | 5 +++++ include/armnn/Types.hpp | 5 +++++ src/armnn/LoadedNetwork.cpp | 8 ++++++++ src/armnn/LoadedNetwork.hpp | 2 ++ src/armnn/Runtime.cpp | 6 ++++++ src/armnn/Runtime.hpp | 5 +++++ src/backends/backendsCommon/Workload.hpp | 2 ++ 7 files changed, 33 insertions(+) diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index b977afe..44864ce 100644 --- a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -83,6 +83,11 @@ public: /// @return A pointer to the requested profiler, or nullptr if not found. virtual const std::shared_ptr GetProfiler(NetworkId networkId) const = 0; + /// Registers a callback function to debug layers performing custom computations on intermediate tensors. + /// @param networkId The id of the network to register the callback. + /// @param func callback function to pass to the debug layer. + virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) = 0; + protected: ~IRuntime() {} }; diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 36e3c5b..693a050 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -180,4 +180,9 @@ private: /// Define LayerGuid type. using LayerGuid = unsigned int; +class ITensorHandle; + +/// Define the type of callback for the debug layer to call +using DebugCallbackFunction = std::function; + } // namespace armnn diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 4221d36..9263f1a 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -485,4 +485,12 @@ bool LoadedNetwork::Execute() return success; } +void LoadedNetwork::RegisterDebugCallback(const DebugCallbackFunction& func) +{ + for (auto&& workloadPtr: m_WorkloadQueue) + { + workloadPtr.get()->RegisterDebugCallback(func); + } +} + } diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 9c0fe0b..75af4a4 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -49,6 +49,8 @@ public: void FreeWorkingMemory(); + void RegisterDebugCallback(const DebugCallbackFunction& func); + private: void AllocateWorkingMemory(); diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp index 09be92c..f8b2462 100644 --- a/src/armnn/Runtime.cpp +++ b/src/armnn/Runtime.cpp @@ -231,4 +231,10 @@ Status Runtime::EnqueueWorkload(NetworkId networkId, return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors); } +void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) +{ + LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId); + loadedNetwork->RegisterDebugCallback(func); +} + } diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp index a3f4a39..10383bc 100644 --- a/src/armnn/Runtime.hpp +++ b/src/armnn/Runtime.hpp @@ -59,6 +59,11 @@ public: /// @return A pointer to the requested profiler, or nullptr if not found. virtual const std::shared_ptr GetProfiler(NetworkId networkId) const override; + /// Registers a callback function to debug layers performing custom computations on intermediate tensors. + /// @param networkId The id of the network to register the callback. + /// @param func callback function to pass to the debug layer. + virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override; + /// Creates a runtime for workload execution. /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but /// it cannot be setup for some reason. diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp index 7fb26f8..447ec1b 100644 --- a/src/backends/backendsCommon/Workload.hpp +++ b/src/backends/backendsCommon/Workload.hpp @@ -21,6 +21,8 @@ public: virtual ~IWorkload() {} virtual void Execute() const = 0; + + virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {} }; // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template -- 2.7.4