IVGCVSW-2865 Extend IRuntime to add a new method RegisterDebugCallback(...)
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Fri, 22 Mar 2019 14:01:46 +0000 (14:01 +0000)
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>
Fri, 22 Mar 2019 14:38:31 +0000 (14:38 +0000)
* Made changes to LoadedNetwork and IWorkload to pass on the registered
callback function

Change-Id: I6ea10f2a299d6de8bf681c8ff36d3fbed1d6d887
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
include/armnn/IRuntime.hpp
include/armnn/Types.hpp
src/armnn/LoadedNetwork.cpp
src/armnn/LoadedNetwork.hpp
src/armnn/Runtime.cpp
src/armnn/Runtime.hpp
src/backends/backendsCommon/Workload.hpp

index b977afe..44864ce 100644 (file)
@@ -83,6 +83,11 @@ public:
     /// @return A pointer to the requested profiler, or nullptr if not found.
     virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const = 0;
 
+    /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
+    /// @param networkId The id of the network to register the callback.
+    /// @param func callback function to pass to the debug layer.
+    virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) = 0;
+
 protected:
     ~IRuntime() {}
 };
index 36e3c5b..693a050 100644 (file)
@@ -180,4 +180,9 @@ private:
 /// Define LayerGuid type.
 using LayerGuid = unsigned int;
 
+class ITensorHandle;
+
+/// Define the type of callback for the debug layer to call
+using DebugCallbackFunction = std::function<void(LayerGuid, unsigned int, ITensorHandle*)>;
+
 } // namespace armnn
index 4221d36..9263f1a 100644 (file)
@@ -485,4 +485,12 @@ bool LoadedNetwork::Execute()
     return success;
 }
 
+void LoadedNetwork::RegisterDebugCallback(const DebugCallbackFunction& func)
+{
+    for (auto&& workloadPtr: m_WorkloadQueue)
+    {
+        workloadPtr.get()->RegisterDebugCallback(func);
+    }
+}
+
 }
index 9c0fe0b..75af4a4 100644 (file)
@@ -49,6 +49,8 @@ public:
 
     void FreeWorkingMemory();
 
+    void RegisterDebugCallback(const DebugCallbackFunction& func);
+
 private:
     void AllocateWorkingMemory();
 
index 09be92c..f8b2462 100644 (file)
@@ -231,4 +231,10 @@ Status Runtime::EnqueueWorkload(NetworkId networkId,
     return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
 }
 
+void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
+{
+    LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+    loadedNetwork->RegisterDebugCallback(func);
+}
+
 }
index a3f4a39..10383bc 100644 (file)
@@ -59,6 +59,11 @@ public:
     /// @return A pointer to the requested profiler, or nullptr if not found.
     virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
 
+    /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
+    /// @param networkId The id of the network to register the callback.
+    /// @param func callback function to pass to the debug layer.
+    virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;
+
     /// Creates a runtime for workload execution.
     /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
     /// it cannot be setup for some reason.
index 7fb26f8..447ec1b 100644 (file)
@@ -21,6 +21,8 @@ public:
     virtual ~IWorkload() {}
 
     virtual void Execute() const = 0;
+
+    virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
 };
 
 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template