/// @return A pointer to the requested profiler, or nullptr if not found.
virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const = 0;
+ /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
+ /// @param networkId The id of the network to register the callback.
+ /// @param func callback function to pass to the debug layer.
+ virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) = 0;
+
protected:
~IRuntime() {}
};
/// Define LayerGuid type.
using LayerGuid = unsigned int;
+class ITensorHandle;
+
+/// Define the type of callback for the debug layer to call
+using DebugCallbackFunction = std::function<void(LayerGuid, unsigned int, ITensorHandle*)>;
+
} // namespace armnn
return success;
}
+void LoadedNetwork::RegisterDebugCallback(const DebugCallbackFunction& func)
+{
+ for (auto&& workloadPtr: m_WorkloadQueue)
+ {
+ workloadPtr.get()->RegisterDebugCallback(func);
+ }
+}
+
}
void FreeWorkingMemory();
+ void RegisterDebugCallback(const DebugCallbackFunction& func);
+
private:
void AllocateWorkingMemory();
return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
+void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
+{
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+ loadedNetwork->RegisterDebugCallback(func);
+}
+
}
/// @return A pointer to the requested profiler, or nullptr if not found.
virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
+ /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
+ /// @param networkId The id of the network to register the callback.
+ /// @param func callback function to pass to the debug layer.
+ virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;
+
/// Creates a runtime for workload execution.
/// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
/// it cannot be setup for some reason.
virtual ~IWorkload() {}
virtual void Execute() const = 0;
+
+ virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
};
// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template