IVGCVSW-4338 Implement the Activation of Counters in backends
authorFinn Williams <Finn.Williams@arm.com>
Wed, 12 Feb 2020 11:02:34 +0000 (11:02 +0000)
committerFinn Williams <Finn.Williams@arm.com>
Fri, 14 Feb 2020 10:20:00 +0000 (10:20 +0000)
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I4a2465f06e046f78242ff0a246c651638b205498

17 files changed:
include/armnn/backends/profiling/IBackendProfilingContext.hpp
src/backends/backendsCommon/test/BackendProfilingTests.cpp
src/backends/backendsCommon/test/MockBackend.hpp
src/profiling/Holder.cpp
src/profiling/Holder.hpp
src/profiling/ISendCounterPacket.hpp
src/profiling/PeriodicCounterCapture.cpp
src/profiling/PeriodicCounterCapture.hpp
src/profiling/PeriodicCounterSelectionCommandHandler.cpp
src/profiling/PeriodicCounterSelectionCommandHandler.hpp
src/profiling/ProfilingService.cpp
src/profiling/ProfilingService.hpp
src/profiling/SendCounterPacket.cpp
src/profiling/SendCounterPacket.hpp
src/profiling/test/ProfilingTests.cpp
src/profiling/test/SendCounterPacketTests.cpp
src/profiling/test/SendCounterPacketTests.hpp

index d7f062b..3f54d31 100644 (file)
@@ -19,7 +19,7 @@ public:
     virtual ~IBackendProfilingContext()
     {}
     virtual uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterID) = 0;
-    virtual void ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) = 0;
+    virtual Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) = 0;
     virtual std::vector<Timestamp> ReportCounterValues() = 0;
     virtual void EnableProfiling(bool flag) = 0;
 };
index fc21730..6e4a020 100644 (file)
 // SPDX-License-Identifier: MIT
 //
 
+#include "CounterDirectory.hpp"
+#include "CounterIdMap.hpp"
+#include "Holder.hpp"
 #include "MockBackend.hpp"
 #include "MockBackendId.hpp"
-#include "Runtime.hpp"
+#include "PeriodicCounterCapture.hpp"
+#include "PeriodicCounterSelectionCommandHandler.hpp"
+#include "ProfilingStateMachine.hpp"
+#include "ProfilingUtils.hpp"
+#include "RequestCounterDirectoryCommandHandler.hpp"
 
 #include <armnn/BackendId.hpp>
+#include <armnn/Logging.hpp>
+
+#include <boost/algorithm/string.hpp>
+#include <boost/numeric/conversion/cast.hpp>
 #include <boost/test/unit_test.hpp>
 #include <vector>
 
+#include <cstdint>
+#include <limits>
+#include <backends/BackendProfiling.hpp>
+
+using namespace armnn::profiling;
+
+class ReadCounterVals : public IReadCounterValues
+{
+    virtual bool IsCounterRegistered(uint16_t counterUid) const override
+    {
+        return (counterUid > 4 && counterUid < 11);
+    }
+    virtual uint16_t GetCounterCount() const override
+    {
+        return 1;
+    }
+    virtual uint32_t GetCounterValue(uint16_t counterUid) const override
+    {
+        return counterUid;
+    }
+};
+
+class MockBackendSendCounterPacket : public ISendCounterPacket
+{
+public:
+    using IndexValuePairsVector = std::vector<CounterValue>;
+
+    /// Create and write a StreamMetaDataPacket in the buffer
+    virtual void SendStreamMetaDataPacket() {}
+
+    /// Create and write a CounterDirectoryPacket from the parameters to the buffer.
+    virtual void SendCounterDirectoryPacket(const ICounterDirectory& counterDirectory)
+    {
+        boost::ignore_unused(counterDirectory);
+    }
+
+    /// Create and write a PeriodicCounterCapturePacket from the parameters to the buffer.
+    virtual void SendPeriodicCounterCapturePacket(uint64_t timestamp, const IndexValuePairsVector& values)
+    {
+        m_timestamps.emplace_back(Timestamp{timestamp, values});
+    }
+
+    /// Create and write a PeriodicCounterSelectionPacket from the parameters to the buffer.
+    virtual void SendPeriodicCounterSelectionPacket(uint32_t capturePeriod,
+                                                    const std::vector<uint16_t>& selectedCounterIds)
+    {
+        boost::ignore_unused(capturePeriod);
+        boost::ignore_unused(selectedCounterIds);
+    }
+
+    std::vector<Timestamp> GetTimestamps()
+    {
+        return  m_timestamps;
+    }
+
+    void ClearTimestamps()
+    {
+        m_timestamps.clear();
+    }
+
+private:
+    std::vector<Timestamp> m_timestamps;
+};
+
+Packet PacketWriter(uint32_t period, std::vector<uint16_t> countervalues)
+{
+    const uint32_t packetId = 0x40000;
+    uint32_t offset = 0;
+    uint32_t dataLength = static_cast<uint32_t>(4 + countervalues.size() * 2);
+    std::unique_ptr<unsigned char[]> uniqueData = std::make_unique<unsigned char[]>(dataLength);
+    unsigned char* data1                        = reinterpret_cast<unsigned char*>(uniqueData.get());
+
+    WriteUint32(data1, offset, period);
+    offset += 4;
+    for (auto countervalue : countervalues)
+    {
+        WriteUint16(data1, offset, countervalue);
+        offset += 2;
+    }
+
+    return {packetId, dataLength, uniqueData};
+}
+
 BOOST_AUTO_TEST_SUITE(BackendProfilingTestSuite)
 
 BOOST_AUTO_TEST_CASE(BackendProfilingCounterRegisterMockBackendTest)
@@ -38,4 +132,314 @@ BOOST_AUTO_TEST_CASE(BackendProfilingCounterRegisterMockBackendTest)
     profilingService.ResetExternalProfilingOptions(options.m_ProfilingOptions, true);
 }
 
+BOOST_AUTO_TEST_CASE(TestBackendCounters)
+{
+    Holder holder;
+    PacketVersionResolver packetVersionResolver;
+    ProfilingStateMachine stateMachine;
+    ReadCounterVals readCounterVals;
+    CounterIdMap counterIdMap;
+    MockBackendSendCounterPacket sendCounterPacket;
+
+    const armnn::BackendId cpuAccId(armnn::Compute::CpuAcc);
+    const armnn::BackendId gpuAccId(armnn::Compute::GpuAcc);
+
+    armnn::IRuntime::CreationOptions options;
+    options.m_ProfilingOptions.m_EnableProfiling = true;
+
+    armnn::profiling::ProfilingService& profilingService = armnn::profiling::ProfilingService::Instance();
+
+    std::unique_ptr<armnn::profiling::IBackendProfiling> cpuBackendProfilingPtr =
+            std::make_unique<BackendProfiling>(options, profilingService, cpuAccId);
+    std::unique_ptr<armnn::profiling::IBackendProfiling> gpuBackendProfilingPtr =
+            std::make_unique<BackendProfiling>(options, profilingService, gpuAccId);
+
+    std::shared_ptr<armnn::profiling::IBackendProfilingContext> cpuProfilingContextPtr =
+            std::make_shared<armnn::MockBackendProfilingContext>(cpuBackendProfilingPtr);
+    std::shared_ptr<armnn::profiling::IBackendProfilingContext> gpuProfilingContextPtr =
+            std::make_shared<armnn::MockBackendProfilingContext>(gpuBackendProfilingPtr);
+
+    std::unordered_map<armnn::BackendId,
+            std::shared_ptr<armnn::profiling::IBackendProfilingContext>> backendProfilingContexts;
+
+    backendProfilingContexts[cpuAccId] = cpuProfilingContextPtr;
+    backendProfilingContexts[gpuAccId] = gpuProfilingContextPtr;
+
+    uint16_t globalId = 5;
+
+    counterIdMap.RegisterMapping(globalId++, 0, cpuAccId);
+    counterIdMap.RegisterMapping(globalId++, 1, cpuAccId);
+    counterIdMap.RegisterMapping(globalId++, 2, cpuAccId);
+
+    counterIdMap.RegisterMapping(globalId++, 0, gpuAccId);
+    counterIdMap.RegisterMapping(globalId++, 1, gpuAccId);
+    counterIdMap.RegisterMapping(globalId++, 2, gpuAccId);
+
+    backendProfilingContexts[cpuAccId] = cpuProfilingContextPtr;
+    backendProfilingContexts[gpuAccId] = gpuProfilingContextPtr;
+
+    PeriodicCounterCapture periodicCounterCapture(holder, sendCounterPacket, readCounterVals,
+                                                  counterIdMap, backendProfilingContexts);
+
+    uint16_t maxArmnnCounterId = 4;
+
+    PeriodicCounterSelectionCommandHandler periodicCounterSelectionCommandHandler(0,
+                                                  4,
+                                                  packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(),
+                                                  backendProfilingContexts,
+                                                  counterIdMap,
+                                                  holder,
+                                                  maxArmnnCounterId,
+                                                  periodicCounterCapture,
+                                                  readCounterVals,
+                                                  sendCounterPacket,
+                                                  stateMachine);
+
+    stateMachine.TransitionToState(ProfilingState::NotConnected);
+    stateMachine.TransitionToState(ProfilingState::WaitingForAck);
+    stateMachine.TransitionToState(ProfilingState::Active);
+
+    uint32_t period = 12345u;
+
+    std::vector<uint16_t> cpuCounters{5, 6, 7};
+    std::vector<uint16_t> gpuCounters{8, 9, 10};
+
+    // Request only gpu counters
+    periodicCounterSelectionCommandHandler(PacketWriter(period, gpuCounters));
+    periodicCounterCapture.Stop();
+
+    std::set<armnn::BackendId> activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 1);
+    BOOST_CHECK(activeIds.find(gpuAccId) != activeIds.end());
+
+    std::vector<Timestamp> recievedTimestamp = sendCounterPacket.GetTimestamps();
+
+    BOOST_CHECK(recievedTimestamp[0].timestamp == period);
+    BOOST_CHECK(recievedTimestamp.size() == 1);
+    BOOST_CHECK(recievedTimestamp[0].counterValues.size() == gpuCounters.size());
+    for (unsigned long i=0; i< gpuCounters.size(); ++i)
+    {
+        BOOST_CHECK(recievedTimestamp[0].counterValues[i].counterId == gpuCounters[i]);
+        BOOST_CHECK(recievedTimestamp[0].counterValues[i].counterValue == i + 1u);
+    }
+    sendCounterPacket.ClearTimestamps();
+
+    // Request only cpu counters
+    periodicCounterSelectionCommandHandler(PacketWriter(period, cpuCounters));
+    periodicCounterCapture.Stop();
+
+    activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 1);
+    BOOST_CHECK(activeIds.find(cpuAccId) != activeIds.end());
+
+    recievedTimestamp = sendCounterPacket.GetTimestamps();
+
+    BOOST_CHECK(recievedTimestamp[0].timestamp == period);
+    BOOST_CHECK(recievedTimestamp.size() == 1);
+    BOOST_CHECK(recievedTimestamp[0].counterValues.size() == cpuCounters.size());
+    for (unsigned long i=0; i< cpuCounters.size(); ++i)
+    {
+        BOOST_CHECK(recievedTimestamp[0].counterValues[i].counterId == cpuCounters[i]);
+        BOOST_CHECK(recievedTimestamp[0].counterValues[i].counterValue == i + 1u);
+    }
+    sendCounterPacket.ClearTimestamps();
+
+    // Request combination of cpu & gpu counters with new period
+    period = 12222u;
+    periodicCounterSelectionCommandHandler(PacketWriter(period, {cpuCounters[0], gpuCounters[2],
+                                                                 gpuCounters[1], cpuCounters[1], gpuCounters[0]}));
+    periodicCounterCapture.Stop();
+
+    activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 2);
+    BOOST_CHECK(activeIds.find(cpuAccId) != activeIds.end());
+    BOOST_CHECK(activeIds.find(gpuAccId) != activeIds.end());
+
+    recievedTimestamp = sendCounterPacket.GetTimestamps();
+
+    BOOST_CHECK(recievedTimestamp[0].timestamp == period);
+    BOOST_CHECK(recievedTimestamp[1].timestamp == period);
+
+    BOOST_CHECK(recievedTimestamp.size() == 2);
+    BOOST_CHECK(recievedTimestamp[0].counterValues.size() == 2);
+    BOOST_CHECK(recievedTimestamp[1].counterValues.size() == gpuCounters.size());
+
+    BOOST_CHECK(recievedTimestamp[0].counterValues[0].counterId == cpuCounters[0]);
+    BOOST_CHECK(recievedTimestamp[0].counterValues[0].counterValue == 1u);
+    BOOST_CHECK(recievedTimestamp[0].counterValues[1].counterId == cpuCounters[1]);
+    BOOST_CHECK(recievedTimestamp[0].counterValues[1].counterValue == 2u);
+
+    for (unsigned long i=0; i< gpuCounters.size(); ++i)
+    {
+        BOOST_CHECK(recievedTimestamp[1].counterValues[i].counterId == gpuCounters[i]);
+        BOOST_CHECK(recievedTimestamp[1].counterValues[i].counterValue == i + 1u);
+    }
+
+    sendCounterPacket.ClearTimestamps();
+
+    // Request all counters
+    std::vector<uint16_t> counterValues;
+    counterValues.insert(counterValues.begin(), cpuCounters.begin(), cpuCounters.end());
+    counterValues.insert(counterValues.begin(), gpuCounters.begin(), gpuCounters.end());
+
+    periodicCounterSelectionCommandHandler(PacketWriter(period, counterValues));
+    periodicCounterCapture.Stop();
+
+    activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 2);
+    BOOST_CHECK(activeIds.find(cpuAccId) != activeIds.end());
+    BOOST_CHECK(activeIds.find(gpuAccId) != activeIds.end());
+
+    recievedTimestamp = sendCounterPacket.GetTimestamps();
+
+    BOOST_CHECK(recievedTimestamp[0].counterValues.size() == cpuCounters.size());
+    for (unsigned long i=0; i< cpuCounters.size(); ++i)
+    {
+        BOOST_CHECK(recievedTimestamp[0].counterValues[i].counterId == cpuCounters[i]);
+        BOOST_CHECK(recievedTimestamp[0].counterValues[i].counterValue == i + 1u);
+    }
+
+    BOOST_CHECK(recievedTimestamp[1].counterValues.size() == gpuCounters.size());
+    for (unsigned long i=0; i< gpuCounters.size(); ++i)
+    {
+        BOOST_CHECK(recievedTimestamp[1].counterValues[i].counterId == gpuCounters[i]);
+        BOOST_CHECK(recievedTimestamp[1].counterValues[i].counterValue == i + 1u);
+    }
+    sendCounterPacket.ClearTimestamps();
+
+    // Request random counters with duplicates and invalid counters
+    counterValues = {0, 0, 200, cpuCounters[2], gpuCounters[0],3 ,30, cpuCounters[0],cpuCounters[2], gpuCounters[1], 3,
+                     90, 0, 30, gpuCounters[0], gpuCounters[0]};
+
+    periodicCounterSelectionCommandHandler(PacketWriter(period, counterValues));
+    periodicCounterCapture.Stop();
+
+    activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 2);
+    BOOST_CHECK(activeIds.find(cpuAccId) != activeIds.end());
+    BOOST_CHECK(activeIds.find(gpuAccId) != activeIds.end());
+
+    recievedTimestamp = sendCounterPacket.GetTimestamps();
+
+    BOOST_CHECK(recievedTimestamp.size() == 2);
+
+    BOOST_CHECK(recievedTimestamp[0].counterValues.size() == 2);
+
+    BOOST_CHECK(recievedTimestamp[0].counterValues[0].counterId == cpuCounters[0]);
+    BOOST_CHECK(recievedTimestamp[0].counterValues[0].counterValue == 1u);
+    BOOST_CHECK(recievedTimestamp[0].counterValues[1].counterId == cpuCounters[2]);
+    BOOST_CHECK(recievedTimestamp[0].counterValues[1].counterValue == 3u);
+
+    BOOST_CHECK(recievedTimestamp[1].counterValues.size() == 2);
+
+    BOOST_CHECK(recievedTimestamp[1].counterValues[0].counterId == gpuCounters[0]);
+    BOOST_CHECK(recievedTimestamp[1].counterValues[0].counterValue == 1u);
+    BOOST_CHECK(recievedTimestamp[1].counterValues[1].counterId == gpuCounters[1]);
+    BOOST_CHECK(recievedTimestamp[1].counterValues[1].counterValue == 2u);
+
+    sendCounterPacket.ClearTimestamps();
+
+    // Request no counters
+    periodicCounterSelectionCommandHandler(PacketWriter(period, {}));
+    periodicCounterCapture.Stop();
+
+    activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 0);
+
+    recievedTimestamp = sendCounterPacket.GetTimestamps();
+    BOOST_CHECK(recievedTimestamp.size() == 0);
+
+    sendCounterPacket.ClearTimestamps();
+
+    // Request period of zero
+    periodicCounterSelectionCommandHandler(PacketWriter(0, counterValues));
+    periodicCounterCapture.Stop();
+
+    activeIds = holder.GetCaptureData().GetActiveBackends();
+    BOOST_CHECK(activeIds.size() == 0);
+
+    recievedTimestamp = sendCounterPacket.GetTimestamps();
+    BOOST_CHECK(recievedTimestamp.size() == 0);
+}
+
+BOOST_AUTO_TEST_CASE(TestBackendCounterLogging)
+{
+    std::stringstream ss;
+
+    struct StreamRedirector
+    {
+    public:
+        StreamRedirector(std::ostream &stream, std::streambuf *newStreamBuffer)
+                : m_Stream(stream), m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer))
+        {}
+
+        ~StreamRedirector()
+        { m_Stream.rdbuf(m_BackupBuffer); }
+
+    private:
+        std::ostream &m_Stream;
+        std::streambuf *m_BackupBuffer;
+    };
+
+    Holder holder;
+    PacketVersionResolver packetVersionResolver;
+    ProfilingStateMachine stateMachine;
+    ReadCounterVals readCounterVals;
+    StreamRedirector redirect(std::cout, ss.rdbuf());
+    CounterIdMap counterIdMap;
+    MockBackendSendCounterPacket sendCounterPacket;
+
+    const armnn::BackendId cpuAccId(armnn::Compute::CpuAcc);
+    const armnn::BackendId gpuAccId(armnn::Compute::GpuAcc);
+
+    armnn::IRuntime::CreationOptions options;
+    options.m_ProfilingOptions.m_EnableProfiling = true;
+
+    armnn::profiling::ProfilingService& profilingService = armnn::profiling::ProfilingService::Instance();
+
+    std::unique_ptr<armnn::profiling::IBackendProfiling> cpuBackendProfilingPtr =
+            std::make_unique<BackendProfiling>(options, profilingService, cpuAccId);
+
+    std::shared_ptr<armnn::profiling::IBackendProfilingContext> cpuProfilingContextPtr =
+            std::make_shared<armnn::MockBackendProfilingContext>(cpuBackendProfilingPtr);
+
+    std::unordered_map<armnn::BackendId,
+            std::shared_ptr<armnn::profiling::IBackendProfilingContext>> backendProfilingContexts;
+
+    uint16_t globalId = 5;
+    counterIdMap.RegisterMapping(globalId, 0, cpuAccId);
+    backendProfilingContexts[cpuAccId] = cpuProfilingContextPtr;
+
+    PeriodicCounterCapture periodicCounterCapture(holder, sendCounterPacket, readCounterVals,
+                                                  counterIdMap, backendProfilingContexts);
+
+    uint16_t maxArmnnCounterId = 4;
+
+    PeriodicCounterSelectionCommandHandler periodicCounterSelectionCommandHandler(0,
+                                                  4,
+                                                  packetVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(),
+                                                  backendProfilingContexts,
+                                                  counterIdMap,
+                                                  holder,
+                                                  maxArmnnCounterId,
+                                                  periodicCounterCapture,
+                                                  readCounterVals,
+                                                  sendCounterPacket,
+                                                  stateMachine);
+
+    stateMachine.TransitionToState(ProfilingState::NotConnected);
+    stateMachine.TransitionToState(ProfilingState::WaitingForAck);
+    stateMachine.TransitionToState(ProfilingState::Active);
+
+    uint32_t period = 15939u;
+
+    armnn::SetAllLoggingSinks(true, false, false);
+    SetLogFilter(armnn::LogSeverity::Warning);
+    periodicCounterSelectionCommandHandler(PacketWriter(period, {5}));
+    periodicCounterCapture.Stop();
+    SetLogFilter(armnn::LogSeverity::Fatal);
+
+    BOOST_CHECK(boost::contains(ss.str(), "ActivateCounters example test error"));
+}
+
 BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file
index 21ce7ab..641d67f 100644 (file)
@@ -13,6 +13,8 @@
 #include <armnn/backends/IBackendInternal.hpp>
 #include <armnn/backends/OptimizationViews.hpp>
 #include <backendsCommon/LayerSupportBase.hpp>
+#include <armnn/backends/profiling/IBackendProfiling.hpp>
+#include <backends/BackendProfiling.hpp>
 
 namespace armnn
 {
@@ -53,6 +55,7 @@ class MockBackendProfilingContext : public profiling::IBackendProfilingContext
 public:
     MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
         : m_BackendProfiling(backendProfiling)
+        , m_CapturePeriod(0)
     {}
 
     ~MockBackendProfilingContext() = default;
@@ -81,19 +84,42 @@ public:
             return nextMaxGlobalCounterId;
     }
 
-    void ActivateCounters(uint32_t, const std::vector<uint16_t>&)
-    {}
+    Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
+    {
+        if ( capturePeriod == 0 || counterIds.size() == 0)
+        {
+            m_ActiveCounters.clear();
+        }
+        else if (capturePeriod == 15939u)
+        {
+            return armnn::Optional<std::string>("ActivateCounters example test error");
+        }
+        m_CapturePeriod = capturePeriod;
+        m_ActiveCounters = counterIds;
+        return armnn::Optional<std::string>();
+    }
 
     std::vector<profiling::Timestamp> ReportCounterValues()
     {
-        return std::vector<profiling::Timestamp>();
+        std::vector<profiling::CounterValue> counterValues;
+
+        for(auto counterId : m_ActiveCounters)
+        {
+            counterValues.emplace_back(profiling::CounterValue{counterId, counterId+1u});
+        }
+
+        uint64_t timestamp = m_CapturePeriod;
+        return  {profiling::Timestamp{timestamp, counterValues}};
     }
 
     void EnableProfiling(bool)
     {}
 
 private:
+
     IBackendInternal::IBackendProfilingPtr& m_BackendProfiling;
+    uint32_t m_CapturePeriod;
+    std::vector<uint16_t> m_ActiveCounters;
 };
 
 class MockBackend : public IBackendInternal
index 41c2993..a366898 100644 (file)
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 
+#include <armnn/BackendId.hpp>
 #include "Holder.hpp"
 
 namespace armnn
@@ -13,12 +14,18 @@ namespace profiling
 
 CaptureData& CaptureData::operator=(const CaptureData& other)
 {
-    m_CapturePeriod = other.m_CapturePeriod;
-    m_CounterIds    = other.m_CounterIds;
+    m_CapturePeriod  = other.m_CapturePeriod;
+    m_CounterIds     = other.m_CounterIds;
+    m_ActiveBackends = other.m_ActiveBackends;
 
     return *this;
 }
 
+void CaptureData::SetActiveBackends(const std::set<armnn::BackendId>& activeBackends)
+{
+    m_ActiveBackends = activeBackends;
+}
+
 void CaptureData::SetCapturePeriod(uint32_t capturePeriod)
 {
     m_CapturePeriod = capturePeriod;
@@ -29,6 +36,11 @@ void CaptureData::SetCounterIds(const std::vector<uint16_t>& counterIds)
     m_CounterIds = counterIds;
 }
 
+const std::set<armnn::BackendId>& CaptureData::GetActiveBackends() const
+{
+    return m_ActiveBackends;
+}
+
 uint32_t CaptureData::GetCapturePeriod() const
 {
     return m_CapturePeriod;
@@ -59,12 +71,16 @@ bool CaptureData::IsCounterIdInCaptureData(uint16_t counterId)
     return false;
 }
 
-void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
+void Holder::SetCaptureData(uint32_t capturePeriod,
+                            const std::vector<uint16_t>& counterIds,
+                            const std::set<armnn::BackendId>& activeBackends)
 {
     std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
 
     m_CaptureData.SetCapturePeriod(capturePeriod);
     m_CaptureData.SetCounterIds(counterIds);
+    m_CaptureData.SetActiveBackends(activeBackends);
+
 }
 
 } // namespace profiling
index 9785b98..8a89cda 100644 (file)
@@ -7,6 +7,8 @@
 
 #include <mutex>
 #include <vector>
+#include <set>
+#include "ProfilingUtils.hpp"
 
 namespace armnn
 {
@@ -19,25 +21,31 @@ class CaptureData
 public:
     CaptureData()
         : m_CapturePeriod(0)
-        , m_CounterIds() {}
-    CaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds)
+        , m_CounterIds()
+        , m_ActiveBackends(){}
+    CaptureData(uint32_t capturePeriod, std::vector<uint16_t>& counterIds, std::set<armnn::BackendId> activeBackends)
         : m_CapturePeriod(capturePeriod)
-        , m_CounterIds(counterIds) {}
+        , m_CounterIds(counterIds)
+        , m_ActiveBackends(activeBackends){}
     CaptureData(const CaptureData& captureData)
         : m_CapturePeriod(captureData.m_CapturePeriod)
-        , m_CounterIds(captureData.m_CounterIds) {}
+        , m_CounterIds(captureData.m_CounterIds)
+        , m_ActiveBackends(captureData.m_ActiveBackends){}
 
     CaptureData& operator=(const CaptureData& other);
 
+    void SetActiveBackends(const std::set<armnn::BackendId>& activeBackends);
     void SetCapturePeriod(uint32_t capturePeriod);
     void SetCounterIds(const std::vector<uint16_t>& counterIds);
     uint32_t GetCapturePeriod() const;
     const std::vector<uint16_t>& GetCounterIds() const;
+    const std::set<armnn::BackendId>& GetActiveBackends() const;
     bool IsCounterIdInCaptureData(uint16_t counterId);
 
 private:
     uint32_t m_CapturePeriod;
     std::vector<uint16_t> m_CounterIds;
+    std::set<armnn::BackendId> m_ActiveBackends;
 };
 
 class Holder
@@ -46,7 +54,9 @@ public:
     Holder()
         : m_CaptureData() {}
     CaptureData GetCaptureData() const;
-    void SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds);
+    void SetCaptureData(uint32_t capturePeriod,
+                        const std::vector<uint16_t>& counterIds,
+                        const std::set<armnn::BackendId>& activeBackends);
 
 private:
     mutable std::mutex m_CaptureThreadMutex;
index 5c8e6b8..d87a042 100644 (file)
@@ -5,6 +5,7 @@
 
 #pragma once
 
+#include <armnn/backends/profiling/IBackendProfiling.hpp>
 #include "ICounterDirectory.hpp"
 
 namespace armnn
@@ -16,7 +17,7 @@ namespace profiling
 class ISendCounterPacket
 {
 public:
-    using IndexValuePairsVector = std::vector<std::pair<uint16_t, uint32_t>>;
+    using IndexValuePairsVector = std::vector<CounterValue>;
 
     virtual ~ISendCounterPacket() {}
 
index d60cbd7..b143295 100644 (file)
@@ -55,6 +55,24 @@ CaptureData PeriodicCounterCapture::ReadCaptureData()
     return m_CaptureDataHolder.GetCaptureData();
 }
 
+void PeriodicCounterCapture::DispatchPeriodicCounterCapturePacket(
+    const armnn::BackendId& backendId, const std::vector<Timestamp>& timestampValues)
+{
+    // Report counter values
+    for (const auto timestampInfo : timestampValues)
+    {
+        std::vector<CounterValue> backendCounterValues = timestampInfo.counterValues;
+        for_each(backendCounterValues.begin(), backendCounterValues.end(), [&](CounterValue& backendCounterValue)
+        {
+            // translate the counterId to globalCounterId
+            backendCounterValue.counterId = m_CounterIdMap.GetGlobalId(backendCounterValue.counterId, backendId);
+        });
+
+        // Send Periodic Counter Capture Packet for the Timestamp
+        m_SendCounterPacket.SendPeriodicCounterCapturePacket(timestampInfo.timestamp, backendCounterValues);
+    }
+}
+
 void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues)
 {
     do
@@ -62,50 +80,60 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
         // Check if the current capture data indicates that there's data capture
         auto currentCaptureData = ReadCaptureData();
         const std::vector<uint16_t>& counterIds = currentCaptureData.GetCounterIds();
+        const uint32_t capturePeriod = currentCaptureData.GetCapturePeriod();
 
-        if (currentCaptureData.GetCapturePeriod() == 0 || counterIds.empty())
+        if (capturePeriod == 0)
         {
-            // No data capture, wait the indicated capture period (milliseconds)
-            std::this_thread::sleep_for(std::chrono::milliseconds(5));
+            // No data capture, wait the indicated capture period (milliseconds), if it is not zero
+            std::this_thread::sleep_for(std::chrono::milliseconds(50u));
             continue;
         }
 
-        std::vector<std::pair<uint16_t, uint32_t>> values;
-        auto numCounters = counterIds.size();
-        values.reserve(numCounters);
-
-        // Create a vector of pairs of CounterIndexes and Values
-        for (uint16_t index = 0; index < numCounters; ++index)
+        if(counterIds.size() != 0)
         {
-            auto requestedId = counterIds[index];
-            uint32_t counterValue = 0;
-            try
-            {
-                counterValue = readCounterValues.GetCounterValue(requestedId);
-            }
-            catch (const Exception& e)
+            std::vector<CounterValue> counterValues;
+
+            auto numCounters = counterIds.size();
+            counterValues.reserve(numCounters);
+
+            // Create a vector of pairs of CounterIndexes and Values
+            for (uint16_t index = 0; index < numCounters; ++index)
             {
-                // Report the error and continue
-                ARMNN_LOG(warning) << "An error has occurred when getting a counter value: "
-                                           << e.what();
-                continue;
+                auto requestedId = counterIds[index];
+                uint32_t counterValue = 0;
+                try
+                {
+                    counterValue = readCounterValues.GetCounterValue(requestedId);
+                }
+                catch (const Exception& e)
+                {
+                    // Report the error and continue
+                    ARMNN_LOG(warning) << "An error has occurred when getting a counter value: "
+                                       << e.what();
+                    continue;
+                }
+
+                counterValues.emplace_back(CounterValue {requestedId, counterValue });
             }
-            values.emplace_back(std::make_pair(requestedId, counterValue));
-        }
 
-        // Take a timestamp
-        uint64_t timestamp = GetTimestamp();
+            // Send Periodic Counter Capture Packet for the Timestamp
+            m_SendCounterPacket.SendPeriodicCounterCapturePacket(GetTimestamp(), counterValues);
+        }
 
-        // Write a Periodic Counter Capture packet to the Counter Stream Buffer
-        m_SendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, values);
+        // Report counter values for each active backend
+        auto activeBackends = currentCaptureData.GetActiveBackends();
+        for_each(activeBackends.begin(), activeBackends.end(), [&](const armnn::BackendId& backendId)
+        {
+            DispatchPeriodicCounterCapturePacket(
+                backendId, m_BackendProfilingContext.at(backendId)->ReportCounterValues());
+        });
 
         // Wait the indicated capture period (microseconds)
-        std::this_thread::sleep_for(std::chrono::microseconds(currentCaptureData.GetCapturePeriod()));
-
+        std::this_thread::sleep_for(std::chrono::microseconds(capturePeriod));
     }
     while (m_KeepRunning.load());
 }
 
 } // namespace profiling
 
-} // namespace armnn
+} // namespace armnn
\ No newline at end of file
index 9229a49..ff05623 100644 (file)
 #include "Packet.hpp"
 #include "SendCounterPacket.hpp"
 #include "ICounterValues.hpp"
+#include "CounterIdMap.hpp"
 
 #include <atomic>
 #include <mutex>
 #include <thread>
+#include <armnn/backends/profiling/IBackendProfilingContext.hpp>
 
 namespace armnn
 {
@@ -24,12 +26,20 @@ namespace profiling
 class PeriodicCounterCapture final : public IPeriodicCounterCapture
 {
 public:
-    PeriodicCounterCapture(const Holder& data, ISendCounterPacket& packet, const IReadCounterValues& readCounterValue)
-        : m_CaptureDataHolder(data)
-        , m_IsRunning(false)
-        , m_KeepRunning(false)
-        , m_ReadCounterValues(readCounterValue)
-        , m_SendCounterPacket(packet)
+    PeriodicCounterCapture(const Holder& data,
+                           ISendCounterPacket& packet,
+                           const IReadCounterValues& readCounterValue,
+                           const ICounterMappings& counterIdMap,
+                           const std::unordered_map<armnn::BackendId,
+                                   std::shared_ptr<armnn::profiling::IBackendProfilingContext>>&
+                           backendProfilingContexts)
+            : m_CaptureDataHolder(data)
+            , m_IsRunning(false)
+            , m_KeepRunning(false)
+            , m_ReadCounterValues(readCounterValue)
+            , m_SendCounterPacket(packet)
+            , m_CounterIdMap(counterIdMap)
+            , m_BackendProfilingContext(backendProfilingContexts)
     {}
     ~PeriodicCounterCapture() { Stop(); }
 
@@ -40,6 +50,8 @@ public:
 private:
     CaptureData ReadCaptureData();
     void Capture(const IReadCounterValues& readCounterValues);
+    void DispatchPeriodicCounterCapturePacket(
+            const armnn::BackendId& backendId, const std::vector<Timestamp>& timestampValues);
 
     const Holder&             m_CaptureDataHolder;
     bool                      m_IsRunning;
@@ -47,6 +59,9 @@ private:
     std::thread               m_PeriodCaptureThread;
     const IReadCounterValues& m_ReadCounterValues;
     ISendCounterPacket&       m_SendCounterPacket;
+    const ICounterMappings&   m_CounterIdMap;
+    const std::unordered_map<armnn::BackendId,
+            std::shared_ptr<armnn::profiling::IBackendProfilingContext>>& m_BackendProfilingContext;
 };
 
 } // namespace profiling
index 4a051b8..d218433 100644 (file)
@@ -101,13 +101,48 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
                 // Invalid counter UID, ignore it and continue
                 continue;
             }
-
             // The counter is valid
-            validCounterIds.push_back(counterId);
+            validCounterIds.emplace_back(counterId);
+        }
+
+        std::sort(validCounterIds.begin(), validCounterIds.end());
+
+        auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId)
+        {
+            return counterId > m_MaxArmCounterId;
+        });
+
+        std::set<armnn::BackendId> activeBackends;
+        std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end());
+
+        if (m_BackendCounterMap.size() != 0)
+        {
+            std::set<uint16_t> newCounterIds;
+            std::set<uint16_t> unusedCounterIds;
+
+            // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds
+            std::set_difference(backendCounterIds.begin(), backendCounterIds.end(),
+                                m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
+                                std::inserter(newCounterIds, newCounterIds.begin()));
+
+            // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds
+            std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
+                                backendCounterIds.begin(), backendCounterIds.end(),
+                                std::inserter(unusedCounterIds, unusedCounterIds.begin()));
+
+            activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds);
+        }
+        else
+        {
+            activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {});
         }
 
-        // Set the capture data with only the valid counter UIDs
-        m_CaptureDataHolder.SetCaptureData(capturePeriod, validCounterIds);
+        // save the new backend counter ids for next time
+        m_PrevBackendCounterIds = backendCounterIds;
+
+
+        // Set the capture data with only the valid armnn counter UIDs
+        m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
 
         // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
         m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
@@ -131,6 +166,68 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
     }
 }
 
+std::set<armnn::BackendId> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
+                                                                      const u_int32_t capturePeriod,
+                                                                      std::set<uint16_t> newCounterIds,
+                                                                      std::set<uint16_t> unusedCounterIds)
+{
+    std::set<armnn::BackendId> changedBackends;
+    std::set<armnn::BackendId> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
+
+    for (uint16_t counterId : newCounterIds)
+    {
+        auto backendId = m_CounterIdMap.GetBackendId(counterId);
+        m_BackendCounterMap[backendId.second].emplace_back(backendId.first);
+        changedBackends.insert(backendId.second);
+    }
+    // Add any new backends to active backends
+    activeBackends.insert(changedBackends.begin(), changedBackends.end());
+
+    for (uint16_t counterId : unusedCounterIds)
+    {
+        auto backendId = m_CounterIdMap.GetBackendId(counterId);
+        std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second];
+
+        backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first));
+
+        if(backendCounters.size() == 0)
+        {
+            // If a backend has no counters associated with it we remove it from active backends and
+            // send a capture period of zero with an empty vector, this will deactivate all the backends counters
+            activeBackends.erase(backendId.second);
+            ActivateBackedCounters(backendId.second, 0, {});
+        }
+        else
+        {
+            changedBackends.insert(backendId.second);
+        }
+    }
+
+    // If the capture period remains the same we only need to update the backends who's counters have changed
+    if(capturePeriod == m_PrevCapturePeriod)
+    {
+        for (auto backend : changedBackends)
+        {
+            ActivateBackedCounters(backend, capturePeriod, m_BackendCounterMap[backend]);
+        }
+    }
+    // Otherwise update all the backends with the new capture period and any new/unused counters
+    else
+    {
+        for (auto backend : m_BackendCounterMap)
+        {
+            ActivateBackedCounters(backend.first, capturePeriod, backend.second);
+        }
+        if(capturePeriod == 0)
+        {
+            activeBackends = {};
+        }
+        m_PrevCapturePeriod = capturePeriod;
+    }
+
+    return activeBackends;
+}
+
 } // namespace profiling
 
 } // namespace armnn
index c974747..437d712 100644 (file)
@@ -5,6 +5,7 @@
 
 #pragma once
 
+#include "CounterIdMap.hpp"
 #include "Packet.hpp"
 #include "CommandHandlerFunctor.hpp"
 #include "Holder.hpp"
 #include "IPeriodicCounterCapture.hpp"
 #include "ICounterValues.hpp"
 
+#include "armnn/backends/profiling/IBackendProfilingContext.hpp"
+#include "armnn/Logging.hpp"
+#include "armnn/BackendRegistry.hpp"
+
+#include <set>
+
+
 namespace armnn
 {
 
 namespace profiling
 {
 
+
 class PeriodicCounterSelectionCommandHandler : public CommandHandlerFunctor
 {
 
@@ -26,29 +35,66 @@ public:
     PeriodicCounterSelectionCommandHandler(uint32_t familyId,
                                            uint32_t packetId,
                                            uint32_t version,
+                                           const std::unordered_map<BackendId,
+                                                   std::shared_ptr<armnn::profiling::IBackendProfilingContext>>&
+                                           backendProfilingContext,
+                                           const ICounterMappings& counterIdMap,
                                            Holder& captureDataHolder,
+                                           const uint16_t maxArmnnCounterId,
                                            IPeriodicCounterCapture& periodicCounterCapture,
                                            const IReadCounterValues& readCounterValue,
                                            ISendCounterPacket& sendCounterPacket,
                                            const ProfilingStateMachine& profilingStateMachine)
         : CommandHandlerFunctor(familyId, packetId, version)
+        , m_BackendProfilingContext(backendProfilingContext)
+        , m_CounterIdMap(counterIdMap)
         , m_CaptureDataHolder(captureDataHolder)
+        , m_MaxArmCounterId(maxArmnnCounterId)
         , m_PeriodicCounterCapture(periodicCounterCapture)
+        , m_PrevCapturePeriod(0)
         , m_ReadCounterValues(readCounterValue)
         , m_SendCounterPacket(sendCounterPacket)
         , m_StateMachine(profilingStateMachine)
-    {}
+
+    {
+
+    }
 
     void operator()(const Packet& packet) override;
 
 private:
+
+    std::unordered_map<armnn::BackendId, std::vector<uint16_t>> m_BackendCounterMap;
+    const std::unordered_map<BackendId,
+          std::shared_ptr<armnn::profiling::IBackendProfilingContext>>& m_BackendProfilingContext;
+    const ICounterMappings& m_CounterIdMap;
     Holder& m_CaptureDataHolder;
+    const uint16_t m_MaxArmCounterId;
     IPeriodicCounterCapture& m_PeriodicCounterCapture;
+    uint32_t m_PrevCapturePeriod;
+    std::set<uint16_t> m_PrevBackendCounterIds;
     const IReadCounterValues& m_ReadCounterValues;
     ISendCounterPacket& m_SendCounterPacket;
     const ProfilingStateMachine& m_StateMachine;
 
+    void ActivateBackedCounters(const armnn::BackendId backendId,
+                                const uint32_t capturePeriod,
+                                const std::vector<uint16_t> counterIds)
+    {
+        Optional<std::string> errorMsg =
+                m_BackendProfilingContext.at(backendId)->ActivateCounters(capturePeriod, counterIds);
+
+        if(errorMsg.has_value())
+        {
+            ARMNN_LOG(warning) << "An error has occurred when activating counters of " << backendId << ": "
+                               << errorMsg.value();
+        }
+    }
     void ParseData(const Packet& packet, CaptureData& captureData);
+    std::set<armnn::BackendId> ProcessBackendCounterIds(const u_int32_t capturePeriod,
+                                                        std::set<uint16_t> newCounterIds,
+                                                        std::set<uint16_t> unusedCounterIds);
+
 };
 
 } // namespace profiling
index 27b05a6..b07465f 100644 (file)
@@ -238,9 +238,11 @@ CaptureData ProfilingService::GetCaptureData()
     return m_Holder.GetCaptureData();
 }
 
-void ProfilingService::SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
+void ProfilingService::SetCaptureData(uint32_t capturePeriod,
+                                      const std::vector<uint16_t>& counterIds,
+                                      const std::set<BackendId>& activeBackends)
 {
-    m_Holder.SetCaptureData(capturePeriod, counterIds);
+    m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
 }
 
 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
index 54c6540..2584c76 100644 (file)
@@ -36,6 +36,7 @@ static const uint16_t NETWORK_UNLOADS       =   1;
 static const uint16_t REGISTERED_BACKENDS   =   2;
 static const uint16_t UNREGISTERED_BACKENDS =   3;
 static const uint16_t INFERENCES_RUN        =   4;
+static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN;
 
 class ProfilingService : public IReadWriteCounterValues, public IProfilingGuidGenerator
 {
@@ -83,7 +84,9 @@ public:
     bool IsProfilingEnabled();
 
     CaptureData GetCaptureData();
-    void SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds);
+    void SetCaptureData(uint32_t capturePeriod,
+                        const std::vector<uint16_t>& counterIds,
+                        const std::set<BackendId>& activeBackends);
 
     // Setters for the profiling service state
     void SetCounterValue(uint16_t counterUid, uint32_t value) override;
@@ -143,7 +146,7 @@ private:
     ProfilingGuidGenerator m_GuidGenerator;
     TimelinePacketWriterFactory m_TimelinePacketWriterFactory;
     std::unordered_map<BackendId,
-        std::shared_ptr<armnn::profiling::IBackendProfilingContext>> m_BackendProfilingContexts;
+    std::shared_ptr<armnn::profiling::IBackendProfilingContext>> m_BackendProfilingContexts;
     uint16_t m_MaxGlobalCounterId;
 
 protected:
@@ -166,7 +169,7 @@ protected:
         , m_SendCounterPacket(m_BufferManager)
         , m_SendThread(m_StateMachine, m_BufferManager, m_SendCounterPacket)
         , m_SendTimelinePacket(m_BufferManager)
-        , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this)
+        , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this, m_CounterIdMap, m_BackendProfilingContexts)
         , m_ConnectionAcknowledgedCommandHandler(0,
                                                  1,
                                                  m_PacketVersionResolver.ResolvePacketVersion(0, 1).GetEncodedValue(),
@@ -184,7 +187,10 @@ protected:
         , m_PeriodicCounterSelectionCommandHandler(0,
                                                    4,
                                                    m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(),
+                                                   m_BackendProfilingContexts,
+                                                   m_CounterIdMap,
                                                    m_Holder,
+                                                   MAX_ARMNN_COUNTER,
                                                    m_PeriodicCounterCapture,
                                                    *this,
                                                    m_SendCounterPacket,
index 942caec..f60586e 100644 (file)
@@ -850,9 +850,9 @@ void SendCounterPacket::SendPeriodicCounterCapturePacket(uint64_t timestamp, con
     offset += uint64_t_size;
     for (const auto& pair: values)
     {
-        WriteUint16(writeBuffer, offset, pair.first);
+        WriteUint16(writeBuffer, offset, pair.counterId);
         offset += uint16_t_size;
-        WriteUint32(writeBuffer, offset, pair.second);
+        WriteUint32(writeBuffer, offset, pair.counterValue);
         offset += uint32_t_size;
     }
 
index 5a10711..1880a2a 100644 (file)
@@ -25,7 +25,7 @@ public:
     using DeviceRecord          = std::vector<uint32_t>;
     using CounterSetRecord      = std::vector<uint32_t>;
     using EventRecord           = std::vector<uint32_t>;
-    using IndexValuePairsVector = std::vector<std::pair<uint16_t, uint32_t>>;
+    using IndexValuePairsVector = std::vector<CounterValue>;
 
     SendCounterPacket(IBufferManager& buffer)
         : m_BufferManager(buffer)
index 0bad66f..af9f1b8 100644 (file)
@@ -536,7 +536,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingStateMachine)
 
 void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
 {
-    holder.SetCaptureData(capturePeriod, counterIds);
+    holder.SetCaptureData(capturePeriod, counterIds, {});
 }
 
 void CaptureDataReadThreadImpl(const Holder& holder, CaptureData& captureData)
@@ -1764,6 +1764,9 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
     const uint32_t packetId = 0x40000;
 
     uint32_t version = 1;
+    const std::unordered_map<armnn::BackendId,
+            std::shared_ptr<armnn::profiling::IBackendProfilingContext>> backendProfilingContext;
+    CounterIdMap counterIdMap;
     Holder holder;
     TestCaptureThread captureThread;
     TestReadCounterValues readCounterValues;
@@ -1790,7 +1793,8 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
 
     Packet packetA(packetId, dataLength1, uniqueData1);
 
-    PeriodicCounterSelectionCommandHandler commandHandler(familyId, packetId, version, holder, captureThread,
+    PeriodicCounterSelectionCommandHandler commandHandler(familyId, packetId, version, backendProfilingContext,
+                                                          counterIdMap, holder, 10000u, captureThread,
                                                           readCounterValues, sendCounterPacket, profilingStateMachine);
 
     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
@@ -2157,6 +2161,9 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
 
     ProfilingStateMachine profilingStateMachine;
 
+    const std::unordered_map<armnn::BackendId,
+            std::shared_ptr<armnn::profiling::IBackendProfilingContext>> backendProfilingContext;
+    CounterIdMap counterIdMap;
     Holder data;
     std::vector<uint16_t> captureIds1 = { 0, 1 };
     std::vector<uint16_t> captureIds2;
@@ -2172,11 +2179,12 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
     unsigned int valueB   = 15;
     unsigned int numSteps = 5;
 
-    PeriodicCounterCapture periodicCounterCapture(std::ref(data), std::ref(sendCounterPacket), captureReader);
+    PeriodicCounterCapture periodicCounterCapture(std::ref(data), std::ref(sendCounterPacket), captureReader,
+                                                  counterIdMap, backendProfilingContext);
 
     for (unsigned int i = 0; i < numSteps; ++i)
     {
-        data.SetCaptureData(1, captureIds1);
+        data.SetCaptureData(1, captureIds1, {});
         captureReader.SetCounterValue(0, valueA * (i + 1));
         captureReader.SetCounterValue(1, valueB * (i + 1));
 
@@ -3344,7 +3352,7 @@ BOOST_AUTO_TEST_CASE(CheckCounterStatusQuery)
     const uint32_t newCapturePeriod = 100;
 
     // Set capture period and active counters in CaptureData
-    profilingService.SetCaptureData(capturePeriod, activeGlobalCounterIds);
+    profilingService.SetCaptureData(capturePeriod, activeGlobalCounterIds, {});
 
     // Get vector of active counters for CpuRef and CpuAcc backends
     std::vector<CounterStatus> cpuRefCounterStatus = backendProfilingCpuRef.GetActiveCounters();
@@ -3373,7 +3381,7 @@ BOOST_AUTO_TEST_CASE(CheckCounterStatusQuery)
     BOOST_CHECK_EQUAL(inactiveCpuAccCounter.m_Enabled, false);
 
     // Set new capture period and new active counters in CaptureData
-    profilingService.SetCaptureData(newCapturePeriod, newActiveGlobalCounterIds);
+    profilingService.SetCaptureData(newCapturePeriod, newActiveGlobalCounterIds, {});
 
     // Get vector of active counters for CpuRef and CpuAcc backends
     cpuRefCounterStatus = backendProfilingCpuRef.GetActiveCounters();
index 9ec24e5..b87583c 100644 (file)
@@ -121,7 +121,7 @@ BOOST_AUTO_TEST_CASE(MockSendCounterPacketTest)
     mockBuffer.MarkRead(packetBuffer);
 
     uint64_t timestamp = 0;
-    std::vector<std::pair<uint16_t, uint32_t>> indexValuePairs;
+    std::vector<CounterValue> indexValuePairs;
 
     mockSendCounterPacket.SendPeriodicCounterCapturePacket(timestamp, indexValuePairs);
 
@@ -215,7 +215,7 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest)
 
     auto captureTimestamp = std::chrono::steady_clock::now();
     uint64_t time =  static_cast<uint64_t >(captureTimestamp.time_since_epoch().count());
-    std::vector<std::pair<uint16_t, uint32_t>> indexValuePairs;
+    std::vector<CounterValue> indexValuePairs;
 
     BOOST_CHECK_THROW(sendPacket1.SendPeriodicCounterCapturePacket(time, indexValuePairs),
                       BufferExhaustion);
@@ -242,11 +242,11 @@ BOOST_AUTO_TEST_CASE(SendPeriodicCounterCapturePacketTest)
     SendCounterPacket sendPacket3(mockBuffer3);
 
     indexValuePairs.reserve(5);
-    indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(0, 100));
-    indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(1, 200));
-    indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(2, 300));
-    indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(3, 400));
-    indexValuePairs.emplace_back(std::make_pair<uint16_t, uint32_t >(4, 500));
+    indexValuePairs.emplace_back(CounterValue{0, 100});
+    indexValuePairs.emplace_back(CounterValue{1, 200});
+    indexValuePairs.emplace_back(CounterValue{2, 300});
+    indexValuePairs.emplace_back(CounterValue{3, 400});
+    indexValuePairs.emplace_back(CounterValue{4, 500});
     sendPacket3.SendPeriodicCounterCapturePacket(time, indexValuePairs);
     auto readBuffer3 = mockBuffer3.GetReadableBuffer();
 
index c7fc7b8..4118989 100644 (file)
@@ -406,7 +406,7 @@ public:
     }
 
     void SendPeriodicCounterCapturePacket(uint64_t timestamp,
-                                          const std::vector<std::pair<uint16_t, uint32_t>>& values) override
+                                          const std::vector<CounterValue>& values) override
     {
         boost::ignore_unused(timestamp, values);