IVGCVSW-3964 Implement the Periodic Counter Selection command handler
authorMatteo Martincigh <matteo.martincigh@arm.com>
Thu, 10 Oct 2019 13:08:21 +0000 (14:08 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Fri, 11 Oct 2019 15:33:29 +0000 (16:33 +0100)
 * Improved the PeriodicCounterPacket class to handle errors properly
 * Improved the PeriodicCounterSelectionCommandHandler to handle
   invalid counter UIDs in the selection packet
 * Added the Periodic Counter Selection command handler to the
   ProfilingService class
 * Code refactoring and added comments
 * Added WaitForPacketSent method to the SendCounterPacket class
   to allow waiting for the packets to be sent (useful in the
   unit tests)
 * Added unit tests and updated the old ones accordingly
 * Fixed threading issues with a number of unit tests

Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: I271b7b0bfa801d88fe1725b934d24e30cd839ed7

15 files changed:
src/profiling/ConnectionAcknowledgedCommandHandler.cpp
src/profiling/Holder.cpp
src/profiling/Holder.hpp
src/profiling/ICounterValues.hpp
src/profiling/PeriodicCounterCapture.cpp
src/profiling/PeriodicCounterSelectionCommandHandler.cpp
src/profiling/PeriodicCounterSelectionCommandHandler.hpp
src/profiling/ProfilingService.cpp
src/profiling/ProfilingService.hpp
src/profiling/RequestCounterDirectoryCommandHandler.cpp
src/profiling/SendCounterPacket.cpp
src/profiling/SendCounterPacket.hpp
src/profiling/test/ProfilingTests.cpp
src/profiling/test/ProfilingTests.hpp
src/profiling/test/SendCounterPacketTests.hpp

index 9d2d1a2bd2c8d2ae4bfa0af62427456181185adf..deffd1414b62b50f64c1a3c699bd854c67f39493 100644 (file)
@@ -22,7 +22,7 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet)
     {
     case ProfilingState::Uninitialised:
     case ProfilingState::NotConnected:
-        throw RuntimeException(boost::str(boost::format("Connection Acknowledged Handler invoked while in an "
+        throw RuntimeException(boost::str(boost::format("Connection Acknowledged Command Handler invoked while in an "
                                                         "wrong state: %1%")
                                           % GetProfilingStateName(currentState)));
     case ProfilingState::WaitingForAck:
index 5916017eb69aaccde93644f1e28524406d8ea483..750be7ec74b3dda613dccbddca109346107b2e52 100644 (file)
@@ -11,10 +11,10 @@ namespace armnn
 namespace profiling
 {
 
-CaptureData& CaptureData::operator= (const CaptureData& captureData)
+CaptureData& CaptureData::operator=(const CaptureData& other)
 {
-    m_CapturePeriod = captureData.m_CapturePeriod;
-    m_CounterIds    = captureData.m_CounterIds;
+    m_CapturePeriod = other.m_CapturePeriod;
+    m_CounterIds    = other.m_CounterIds;
 
     return *this;
 }
@@ -29,12 +29,12 @@ void CaptureData::SetCounterIds(const std::vector<uint16_t>& counterIds)
     m_CounterIds = counterIds;
 }
 
-std::uint32_t CaptureData::GetCapturePeriod() const
+uint32_t CaptureData::GetCapturePeriod() const
 {
     return m_CapturePeriod;
 }
 
-std::vector<uint16_t> CaptureData::GetCounterIds() const
+const std::vector<uint16_t>& CaptureData::GetCounterIds() const
 {
     return m_CounterIds;
 }
@@ -42,12 +42,14 @@ std::vector<uint16_t> CaptureData::GetCounterIds() const
 CaptureData Holder::GetCaptureData() const
 {
     std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+
     return m_CaptureData;
 }
 
 void Holder::SetCaptureData(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
 {
     std::lock_guard<std::mutex> lockGuard(m_CaptureThreadMutex);
+
     m_CaptureData.SetCapturePeriod(capturePeriod);
     m_CaptureData.SetCounterIds(counterIds);
 }
index 72ca0914a921e46d56e43823d3f8fc2890a2b158..3143105ab400b753abbfbacd1804d17329b2de40 100644 (file)
@@ -27,12 +27,12 @@ public:
         : m_CapturePeriod(captureData.m_CapturePeriod)
         , m_CounterIds(captureData.m_CounterIds) {}
 
-    CaptureData& operator= (const CaptureData& captureData);
+    CaptureData& operator=(const CaptureData& other);
 
     void SetCapturePeriod(uint32_t capturePeriod);
     void SetCounterIds(const std::vector<uint16_t>& counterIds);
     uint32_t GetCapturePeriod() const;
-    std::vector<uint16_t> GetCounterIds() const;
+    const std::vector<uint16_t>& GetCounterIds() const;
 
 private:
     uint32_t m_CapturePeriod;
index 5e32ca2b3778b0d683ef221210d4a5a5a73eac23..18e34b67479152c7399b31ee18b1e6895ae468e8 100644 (file)
@@ -18,6 +18,7 @@ class IReadCounterValues
 public:
     virtual ~IReadCounterValues() {}
 
+    virtual bool IsCounterRegistered(uint16_t counterUid) const = 0;
     virtual uint16_t GetCounterCount() const = 0;
     virtual uint32_t GetCounterValue(uint16_t counterUid) const = 0;
 };
index 9002bfc0654d4afff0111eb1473954d9b6926ff0..0ccb516ae22156f621928f598f890fb838ada466 100644 (file)
@@ -5,6 +5,8 @@
 
 #include "PeriodicCounterCapture.hpp"
 
+#include <boost/log/trivial.hpp>
+
 namespace armnn
 {
 
@@ -34,10 +36,13 @@ void PeriodicCounterCapture::Start()
 
 void PeriodicCounterCapture::Stop()
 {
+    // Signal the capture thread to stop
     m_KeepRunning.store(false);
 
+    // Check that the capture thread is running
     if (m_PeriodCaptureThread.joinable())
     {
+        // Wait for the capture thread to complete operations
         m_PeriodCaptureThread.join();
     }
 }
@@ -51,10 +56,12 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
 {
     while (m_KeepRunning.load())
     {
+        // Check if the current capture data indicates that there's data capture
         auto currentCaptureData = ReadCaptureData();
-        std::vector<uint16_t> counterIds = currentCaptureData.GetCounterIds();
+        const std::vector<uint16_t>& counterIds = currentCaptureData.GetCounterIds();
         if (currentCaptureData.GetCapturePeriod() == 0 || counterIds.empty())
         {
+            // No data capture, terminate the thread
             m_KeepRunning.store(false);
             break;
         }
@@ -63,12 +70,22 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
         auto numCounters = counterIds.size();
         values.reserve(numCounters);
 
-        // Create vector of pairs of CounterIndexes and Values
-        uint32_t counterValue = 0;
+        // Create a vector of pairs of CounterIndexes and Values
         for (uint16_t index = 0; index < numCounters; ++index)
         {
             auto requestedId = counterIds[index];
-            counterValue = readCounterValues.GetCounterValue(requestedId);
+            uint32_t counterValue = 0;
+            try
+            {
+                counterValue = readCounterValues.GetCounterValue(requestedId);
+            }
+            catch (const Exception& e)
+            {
+                // Report the error and continue
+                BOOST_LOG_TRIVIAL(warning) << "An error has occurred when getting a counter value: "
+                                           << e.what() << std::endl;
+                continue;
+            }
             values.emplace_back(std::make_pair(requestedId, counterValue));
         }
 
@@ -81,9 +98,15 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
         // Take a timestamp
         auto timestamp = clock::now();
 
+        // Write a Periodic Counter Capture packet to the Counter Stream Buffer
         m_SendCounterPacket.SendPeriodicCounterCapturePacket(
                     static_cast<uint64_t>(timestamp.time_since_epoch().count()), values);
-        std::this_thread::sleep_for(std::chrono::milliseconds(currentCaptureData.GetCapturePeriod()));
+
+        // Notify the Send Thread that new data is available in the Counter Stream Buffer
+        m_SendCounterPacket.SetReadyToRead();
+
+        // Wait the indicated capture period (microseconds)
+        std::this_thread::sleep_for(std::chrono::microseconds(currentCaptureData.GetCapturePeriod()));
     }
 
     m_IsRunning.store(false);
index 9be37fcfd2554a6cd89eef01bf7b9f8af3b2a1ba..db09856dae2e2a715f070639dec1ab2de9cbd963 100644 (file)
@@ -7,6 +7,9 @@
 #include "ProfilingUtils.hpp"
 
 #include <boost/numeric/conversion/cast.hpp>
+#include <boost/format.hpp>
+
+#include <vector>
 
 namespace armnn
 {
@@ -14,57 +17,109 @@ namespace armnn
 namespace profiling
 {
 
-using namespace std;
-using boost::numeric_cast;
-
 void PeriodicCounterSelectionCommandHandler::ParseData(const Packet& packet, CaptureData& captureData)
 {
     std::vector<uint16_t> counterIds;
-    uint32_t sizeOfUint32 = numeric_cast<uint32_t>(sizeof(uint32_t));
-    uint32_t sizeOfUint16 = numeric_cast<uint32_t>(sizeof(uint16_t));
+    uint32_t sizeOfUint32 = boost::numeric_cast<uint32_t>(sizeof(uint32_t));
+    uint32_t sizeOfUint16 = boost::numeric_cast<uint32_t>(sizeof(uint16_t));
     uint32_t offset = 0;
 
-    if (packet.GetLength() > 0)
+    if (packet.GetLength() < 4)
     {
-        if (packet.GetLength() >= 4)
-        {
-            captureData.SetCapturePeriod(ReadUint32(reinterpret_cast<const unsigned char*>(packet.GetData()), offset));
+        // Insufficient packet size
+        return;
+    }
 
-            unsigned int counters = (packet.GetLength() - 4) / 2;
+    // Parse the capture period
+    uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
 
-            if (counters > 0)
-            {
-                counterIds.reserve(counters);
-                offset += sizeOfUint32;
-                for(unsigned int pos = 0; pos < counters; ++pos)
-                {
-                    counterIds.emplace_back(ReadUint16(reinterpret_cast<const unsigned char*>(packet.GetData()),
-                                            offset));
-                    offset += sizeOfUint16;
-                }
-            }
+    // Set the capture period
+    captureData.SetCapturePeriod(capturePeriod);
 
-            captureData.SetCounterIds(counterIds);
+    // Parse the counter ids
+    unsigned int counters = (packet.GetLength() - 4) / 2;
+    if (counters > 0)
+    {
+        counterIds.reserve(counters);
+        offset += sizeOfUint32;
+        for (unsigned int i = 0; i < counters; ++i)
+        {
+            // Parse the counter id
+            uint16_t counterId = ReadUint16(packet.GetData(), offset);
+            counterIds.emplace_back(counterId);
+            offset += sizeOfUint16;
         }
     }
+
+    // Set the counter ids
+    captureData.SetCounterIds(counterIds);
 }
 
 void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
 {
-    CaptureData captureData;
+    ProfilingState currentState = m_StateMachine.GetCurrentState();
+    switch (currentState)
+    {
+    case ProfilingState::Uninitialised:
+    case ProfilingState::NotConnected:
+    case ProfilingState::WaitingForAck:
+        throw RuntimeException(boost::str(boost::format("Periodic Counter Selection Command Handler invoked while in "
+                                                        "an wrong state: %1%")
+                                          % GetProfilingStateName(currentState)));
+    case ProfilingState::Active:
+    {
+        // Process the packet
+        if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
+        {
+            throw armnn::InvalidArgumentException(boost::str(boost::format("Expected Packet family = 0, id = 4 but "
+                                                                           "received family = %1%, id = %2%")
+                                                  % packet.GetPacketFamily()
+                                                  % packet.GetPacketId()));
+        }
+
+        // Parse the packet to get the capture period and counter UIDs
+        CaptureData captureData;
+        ParseData(packet, captureData);
 
-    ParseData(packet, captureData);
+        // Get the capture data
+        const uint32_t capturePeriod = captureData.GetCapturePeriod();
+        const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
 
-    vector<uint16_t> counterIds = captureData.GetCounterIds();
+        // Check whether the selected counter UIDs are valid
+        std::vector<uint16_t> validCounterIds;
+        for (uint16_t counterId : counterIds)
+        {
+            // Check whether the counter is registered
+            if (!m_ReadCounterValues.IsCounterRegistered(counterId))
+            {
+                // Invalid counter UID, ignore it and continue
+                continue;
+            }
 
-    m_CaptureDataHolder.SetCaptureData(captureData.GetCapturePeriod(), counterIds);
+            // The counter is valid
+            validCounterIds.push_back(counterId);
+        }
 
-    m_CaptureThread.Start();
+        // Set the capture data with only the valid counter UIDs
+        m_CaptureDataHolder.SetCaptureData(capturePeriod, validCounterIds);
 
-    // Write packet to Counter Stream Buffer
-    m_SendCounterPacket.SendPeriodicCounterSelectionPacket(captureData.GetCapturePeriod(), captureData.GetCounterIds());
+        // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
+        m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
+
+        // Notify the Send Thread that new data is available in the Counter Stream Buffer
+        m_SendCounterPacket.SetReadyToRead();
+
+        // Start the Period Counter Capture thread (if not running already)
+        m_PeriodicCounterCapture.Start();
+
+        break;
+    }
+    default:
+        throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1%")
+                                          % static_cast<int>(currentState)));
+    }
 }
 
 } // namespace profiling
 
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
index e247e7773ffd30f443f21bd4625ebcfc70673366..1da08e3c7a1ea8611f13e9ba962500b05d346b1e 100644 (file)
 #include "Holder.hpp"
 #include "SendCounterPacket.hpp"
 #include "IPeriodicCounterCapture.hpp"
-
-#include <vector>
-#include <thread>
-#include <atomic>
+#include "ICounterValues.hpp"
 
 namespace armnn
 {
@@ -25,22 +22,30 @@ class PeriodicCounterSelectionCommandHandler : public CommandHandlerFunctor
 {
 
 public:
-    PeriodicCounterSelectionCommandHandler(uint32_t packetId, uint32_t version, Holder& captureDataHolder,
-                                           IPeriodicCounterCapture& captureThread,
-                                           ISendCounterPacket& sendCounterPacket)
-    : CommandHandlerFunctor(packetId, version),
-    m_CaptureDataHolder(captureDataHolder),
-    m_CaptureThread(captureThread),
-    m_SendCounterPacket(sendCounterPacket)
+    PeriodicCounterSelectionCommandHandler(uint32_t packetId,
+                                           uint32_t version,
+                                           Holder& captureDataHolder,
+                                           IPeriodicCounterCapture& periodicCounterCapture,
+                                           const IReadCounterValues& readCounterValue,
+                                           ISendCounterPacket& sendCounterPacket,
+                                           const ProfilingStateMachine& profilingStateMachine)
+        : CommandHandlerFunctor(packetId, version)
+        , m_CaptureDataHolder(captureDataHolder)
+        , m_PeriodicCounterCapture(periodicCounterCapture)
+        , m_ReadCounterValues(readCounterValue)
+        , m_SendCounterPacket(sendCounterPacket)
+        , m_StateMachine(profilingStateMachine)
     {}
 
     void operator()(const Packet& packet) override;
 
-
 private:
     Holder& m_CaptureDataHolder;
-    IPeriodicCounterCapture& m_CaptureThread;
+    IPeriodicCounterCapture& m_PeriodicCounterCapture;
+    const IReadCounterValues& m_ReadCounterValues;
     ISendCounterPacket& m_SendCounterPacket;
+    const ProfilingStateMachine& m_StateMachine;
+
     void ParseData(const Packet& packet, CaptureData& captureData);
 };
 
index 693f8337db5cdfdfb606cff4dae7278714f06cd0..79184416cd2407a6a5929f32f216c0bec088582d 100644 (file)
@@ -53,6 +53,9 @@ void ProfilingService::Update()
         // Stop the send thread (if running)
         m_SendCounterPacket.Stop(false);
 
+        // Stop the periodic counter capture thread (if running)
+        m_PeriodicCounterCapture.Stop();
+
         // Reset any existing profiling connection
         m_ProfilingConnection.reset();
 
@@ -90,6 +93,9 @@ void ProfilingService::Update()
         break;
     case ProfilingState::Active:
 
+        // The period counter capture thread is started by the Periodic Counter Selection command handler upon
+        // request by an external profiling service
+
         break;
     default:
         throw RuntimeException(boost::str(boost::format("Unknown profiling service state: %1")
@@ -112,9 +118,14 @@ uint16_t ProfilingService::GetCounterCount() const
     return m_CounterDirectory.GetCounterCount();
 }
 
+bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
+{
+    return counterUid < m_CounterIndex.size();
+}
+
 uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
 {
-    BOOST_ASSERT(counterUid < m_CounterIndex.size());
+    CheckCounterUid(counterUid);
     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
     BOOST_ASSERT(counterValuePtr);
     return counterValuePtr->load(std::memory_order::memory_order_relaxed);
@@ -122,7 +133,7 @@ uint32_t ProfilingService::GetCounterValue(uint16_t counterUid) const
 
 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
 {
-    BOOST_ASSERT(counterUid < m_CounterIndex.size());
+    CheckCounterUid(counterUid);
     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
     BOOST_ASSERT(counterValuePtr);
     counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
@@ -130,7 +141,7 @@ void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
 
 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
 {
-    BOOST_ASSERT(counterUid < m_CounterIndex.size());
+    CheckCounterUid(counterUid);
     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
     BOOST_ASSERT(counterValuePtr);
     return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
@@ -138,7 +149,7 @@ uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
 
 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
 {
-    BOOST_ASSERT(counterUid < m_CounterIndex.size());
+    CheckCounterUid(counterUid);
     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
     BOOST_ASSERT(counterValuePtr);
     return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
@@ -146,7 +157,7 @@ uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t va
 
 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
 {
-    BOOST_ASSERT(counterUid < m_CounterIndex.size());
+    CheckCounterUid(counterUid);
     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
     BOOST_ASSERT(counterValuePtr);
     return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
@@ -154,7 +165,7 @@ uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
 
 uint32_t ProfilingService::DecrementCounterValue(uint16_t counterUid)
 {
-    BOOST_ASSERT(counterUid < m_CounterIndex.size());
+    CheckCounterUid(counterUid);
     std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
     BOOST_ASSERT(counterValuePtr);
     return counterValuePtr->operator--(std::memory_order::memory_order_relaxed);
@@ -239,6 +250,7 @@ void ProfilingService::Reset()
     // First stop the threads (Command Handler first)...
     m_CommandHandler.Stop();
     m_SendCounterPacket.Stop(false);
+    m_PeriodicCounterCapture.Stop();
 
     // ...then destroy the profiling connection...
     m_ProfilingConnection.reset();
@@ -252,6 +264,14 @@ void ProfilingService::Reset()
     m_StateMachine.Reset();
 }
 
+inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
+{
+    if (!IsCounterRegistered(counterUid))
+    {
+        throw InvalidArgumentException(boost::str(boost::format("Counter UID %1% is not registered") % counterUid));
+    }
+}
+
 } // namespace profiling
 
 } // namespace armnn
index 0e66924267fbe917fe092f948d0d8322d3943a4c..dd70af4b39351cfa104cb11b6392ae61b6048700 100644 (file)
 #include "CommandHandler.hpp"
 #include "BufferManager.hpp"
 #include "SendCounterPacket.hpp"
+#include "PeriodicCounterCapture.hpp"
 #include "ConnectionAcknowledgedCommandHandler.hpp"
 #include "RequestCounterDirectoryCommandHandler.hpp"
+#include "PeriodicCounterSelectionCommandHandler.hpp"
 
 namespace armnn
 {
@@ -46,6 +48,7 @@ public:
     // Getters for the profiling service state
     const ICounterDirectory& GetCounterDirectory() const;
     ProfilingState GetCurrentState() const;
+    bool IsCounterRegistered(uint16_t counterUid) const override;
     uint16_t GetCounterCount() const override;
     uint32_t GetCounterValue(uint16_t counterUid) const override;
 
@@ -68,6 +71,9 @@ private:
     void InitializeCounterValue(uint16_t counterUid);
     void Reset();
 
+    // Helper function
+    void CheckCounterUid(uint16_t counterUid) const;
+
     // Profiling service components
     ExternalProfilingOptions m_Options;
     CounterDirectory m_CounterDirectory;
@@ -81,8 +87,11 @@ private:
     CommandHandler m_CommandHandler;
     BufferManager m_BufferManager;
     SendCounterPacket m_SendCounterPacket;
+    Holder m_Holder;
+    PeriodicCounterCapture m_PeriodicCounterCapture;
     ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler;
     RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler;
+    PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler;
 
 protected:
     // Default constructor/destructor kept protected for testing
@@ -102,6 +111,7 @@ protected:
                            m_PacketVersionResolver)
         , m_BufferManager()
         , m_SendCounterPacket(m_StateMachine, m_BufferManager)
+        , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this)
         , m_ConnectionAcknowledgedCommandHandler(1,
                                                  m_PacketVersionResolver.ResolvePacketVersion(1).GetEncodedValue(),
                                                  m_StateMachine)
@@ -110,12 +120,22 @@ protected:
                                                   m_CounterDirectory,
                                                   m_SendCounterPacket,
                                                   m_StateMachine)
+        , m_PeriodicCounterSelectionCommandHandler(4,
+                                                   m_PacketVersionResolver.ResolvePacketVersion(4).GetEncodedValue(),
+                                                   m_Holder,
+                                                   m_PeriodicCounterCapture,
+                                                   *this,
+                                                   m_SendCounterPacket,
+                                                   m_StateMachine)
     {
         // Register the "Connection Acknowledged" command handler
         m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
 
         // Register the "Request Counter Directory" command handler
         m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
+
+        // Register the "Periodic Counter Selection" command handler
+        m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler);
     }
     ~ProfilingService() = default;
 
@@ -138,6 +158,10 @@ protected:
     {
         instance.m_StateMachine.TransitionToState(newState);
     }
+    void WaitForPacketSent(ProfilingService& instance)
+    {
+        return instance.m_SendCounterPacket.WaitForPacketSent();
+    }
 };
 
 } // namespace profiling
index e85acb4215f4bfa63c76c8f7afdf36b32acc9506..b8ac9d94261585c61563c32c2c0f2aca7e850eb6 100644 (file)
@@ -21,7 +21,7 @@ void RequestCounterDirectoryCommandHandler::operator()(const Packet& packet)
     case ProfilingState::Uninitialised:
     case ProfilingState::NotConnected:
     case ProfilingState::WaitingForAck:
-        throw RuntimeException(boost::str(boost::format("Request Counter Directory Handler invoked while in an "
+        throw RuntimeException(boost::str(boost::format("Request Counter Directory Comand Handler invoked while in an "
                                                         "wrong state: %1%")
                                           % GetProfilingStateName(currentState)));
     case ProfilingState::Active:
index e48da3ed7c273cdef3bd4ae057370268a5c688c0..41adf3724414645f0a75a866c599e28844bc1086 100644 (file)
@@ -1035,17 +1035,21 @@ void SendCounterPacket::Send(IProfilingConnection& profilingConnection)
     }
 
     // Ensure that all readable data got written to the profiling connection before the thread is stopped
-    FlushBuffer(profilingConnection);
+    // (do not notify any watcher in this case, as this is just to wrap up things before shutting down the send thread)
+    FlushBuffer(profilingConnection, false);
 
     // Mark the send thread as not running
     m_IsRunning.store(false);
 }
 
-void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
+void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers)
 {
     // Get the first available readable buffer
     std::unique_ptr<IPacketBuffer> packetBuffer = m_BufferManager.GetReadableBuffer();
 
+    // Initialize the flag that indicates whether at least a packet has been sent
+    bool packetsSent = false;
+
     while (packetBuffer != nullptr)
     {
         // Get the data to send from the buffer
@@ -1066,6 +1070,9 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
         {
             // Write a packet to the profiling connection. Silently ignore any write error and continue
             profilingConnection.WritePacket(readBuffer, boost::numeric_cast<uint32_t>(readBufferSize));
+
+            // Set the flag that indicates whether at least a packet has been sent
+            packetsSent = true;
         }
 
         // Mark the packet buffer as read
@@ -1074,6 +1081,13 @@ void SendCounterPacket::FlushBuffer(IProfilingConnection& profilingConnection)
         // Get the next available readable buffer
         packetBuffer = m_BufferManager.GetReadableBuffer();
     }
+
+    // Check whether at least a packet has been sent
+    if (packetsSent && notifyWatchers)
+    {
+        // Notify to any watcher that something has been sent
+        m_PacketSentWaitCondition.notify_one();
+    }
 }
 
 } // namespace profiling
index 9361efbc745325c187722bef1118f5f138d024d4..e1a42aa496f710c0797658e79905ace6a246e355 100644 (file)
@@ -65,6 +65,14 @@ public:
     void Stop(bool rethrowSendThreadExceptions = true);
     bool IsRunning() { return m_IsRunning.load(); }
 
+    void WaitForPacketSent()
+    {
+        std::unique_lock<std::mutex> lock(m_PacketSentWaitMutex);
+
+        // Blocks until notified that at least a packet has been sent
+        m_PacketSentWaitCondition.wait(lock);
+    }
+
 private:
     void Send(IProfilingConnection& profilingConnection);
 
@@ -93,7 +101,7 @@ private:
         throw ExceptionType(errorMessage);
     }
 
-    void FlushBuffer(IProfilingConnection& profilingConnection);
+    void FlushBuffer(IProfilingConnection& profilingConnection, bool notifyWatchers = true);
 
     ProfilingStateMachine& m_StateMachine;
     IBufferManager& m_BufferManager;
@@ -104,6 +112,8 @@ private:
     std::atomic<bool> m_IsRunning;
     std::atomic<bool> m_KeepRunning;
     std::exception_ptr m_SendThreadException;
+    std::mutex m_PacketSentWaitMutex;
+    std::condition_variable m_PacketSentWaitCondition;
 
 protected:
     // Helper methods, protected for testing
index 27bacf71454e66d7e42e04b78f6b413244c777de..554b7e1936f7369a2a8442bc37aa305d76132511 100644 (file)
@@ -35,6 +35,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <iostream>
 
 using namespace armnn::profiling;
 
@@ -1691,11 +1692,19 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
         void Stop() override {}
     };
 
+    class TestReadCounterValues : public IReadCounterValues
+    {
+        bool IsCounterRegistered(uint16_t counterUid) const override { return true; }
+        uint16_t GetCounterCount() const override { return 0; }
+        uint32_t GetCounterValue(uint16_t counterUid) const override { return 0; }
+    };
+
     const uint32_t packetId = 0x40000;
 
     uint32_t version = 1;
     Holder holder;
     TestCaptureThread captureThread;
+    TestReadCounterValues readCounterValues;
     MockBufferManager mockBuffer(512);
     SendCounterPacket sendCounterPacket(profilingStateMachine, mockBuffer);
 
@@ -1718,16 +1727,29 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
 
     Packet packetA(packetId, dataLength1, uniqueData1);
 
-    PeriodicCounterSelectionCommandHandler commandHandler(packetId, version, holder, captureThread,
-                                                          sendCounterPacket);
-    commandHandler(packetA);
+    PeriodicCounterSelectionCommandHandler commandHandler(packetId,
+                                                          version,
+                                                          holder,
+                                                          captureThread,
+                                                          readCounterValues,
+                                                          sendCounterPacket,
+                                                          profilingStateMachine);
 
-    std::vector<uint16_t> counterIds = holder.GetCaptureData().GetCounterIds();
+    profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
+    BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
+    profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
+    BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
+    profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
+    BOOST_CHECK_THROW(commandHandler(packetA), armnn::RuntimeException);
+    profilingStateMachine.TransitionToState(ProfilingState::Active);
+    BOOST_CHECK_NO_THROW(commandHandler(packetA));
+
+    const std::vector<uint16_t> counterIdsA = holder.GetCaptureData().GetCounterIds();
 
     BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period1);
-    BOOST_TEST(counterIds.size() == 2);
-    BOOST_TEST(counterIds[0] == 4000);
-    BOOST_TEST(counterIds[1] == 5000);
+    BOOST_TEST(counterIdsA.size() == 2);
+    BOOST_TEST(counterIdsA[0] == 4000);
+    BOOST_TEST(counterIdsA[1] == 5000);
 
     auto readBuffer = mockBuffer.GetReadableBuffer();
 
@@ -1766,10 +1788,10 @@ BOOST_AUTO_TEST_CASE(CounterSelectionCommandHandlerParseData)
 
     commandHandler(packetB);
 
-    counterIds = holder.GetCaptureData().GetCounterIds();
+    const std::vector<uint16_t> counterIdsB = holder.GetCaptureData().GetCounterIds();
 
     BOOST_TEST(holder.GetCaptureData().GetCapturePeriod() == period2);
-    BOOST_TEST(counterIds.size() == 0);
+    BOOST_TEST(counterIdsB.size() == 0);
 
     readBuffer = mockBuffer.GetReadableBuffer();
 
@@ -2024,35 +2046,40 @@ BOOST_AUTO_TEST_CASE(CheckPeriodicCounterCaptureThread)
     public:
         CaptureReader() {}
 
+        bool IsCounterRegistered(uint16_t counterUid) const override
+        {
+            return m_Data.find(counterUid) != m_Data.end();
+        }
+
         uint16_t GetCounterCount() const override
         {
             return boost::numeric_cast<uint16_t>(m_Data.size());
         }
 
-        uint32_t GetCounterValue(uint16_t index) const override
+        uint32_t GetCounterValue(uint16_t counterUid) const override
         {
-            if (m_Data.find(index) == m_Data.end())
+            if (m_Data.find(counterUid) == m_Data.end())
             {
                 return 0;
             }
 
-            return m_Data.at(index);
+            return m_Data.at(counterUid).load();
         }
 
-        void SetCounterValue(uint16_t index, uint32_t value)
+        void SetCounterValue(uint16_t counterUid, uint32_t value)
         {
-            if (m_Data.find(index) == m_Data.end())
+            if (m_Data.find(counterUid) == m_Data.end())
             {
-                m_Data.insert(std::pair<uint16_t, uint32_t>(index, value));
+                m_Data.insert(std::make_pair(counterUid, value));
             }
             else
             {
-                m_Data.at(index) = value;
+                m_Data.at(counterUid).store(value);
             }
         }
 
     private:
-        std::unordered_map<uint16_t, uint32_t> m_Data;
+        std::unordered_map<uint16_t, std::atomic<uint32_t>> m_Data;
     };
 
     ProfilingStateMachine profilingStateMachine;
@@ -2261,19 +2288,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadConnectionAcknowledgedPacket)
 
     // Bring the profiling service to the "WaitingForAck" state
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
-    profilingService.Update();
+    profilingService.Update(); // Initialize the counter directory
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
-    profilingService.Update();
-    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
-    profilingService.Update();
-
-    // Wait for a bit to make sure that we get the packet
-    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    profilingService.Update();// Create the profiling connection
 
     // Get the mock profiling connection
     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
     BOOST_CHECK(mockProfilingConnection);
 
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update();
+
+    // Wait for the Stream Metadata packet to be sent
+    helper.WaitForProfilingPacketsSent();
+
     // Check that the mock profiling connection contains one Stream Metadata packet
     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
     BOOST_TEST(writtenData.size() == 1);
@@ -2330,19 +2361,23 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodConnectionAcknowledgedPacket)
 
     // Bring the profiling service to the "WaitingForAck" state
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
-    profilingService.Update();
+    profilingService.Update(); // Initialize the counter directory
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
-    profilingService.Update();
-    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
-    profilingService.Update();
-
-    // Wait for a bit to make sure that we get the packet
-    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    profilingService.Update(); // Create the profiling connection
 
     // Get the mock profiling connection
     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
     BOOST_CHECK(mockProfilingConnection);
 
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet to be sent
+    helper.WaitForProfilingPacketsSent();
+
     // Check that the mock profiling connection contains one Stream Metadata packet
     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
     BOOST_TEST(writtenData.size() == 1);
@@ -2403,7 +2438,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
     profilingService.Update(); // Create the profiling connection
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
-    profilingService.Update(); // Start the threads
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
     helper.ForceTransitionToState(ProfilingState::Active);
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
 
@@ -2411,6 +2452,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
     BOOST_CHECK(mockProfilingConnection);
 
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
     // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
     // reply from an external profiling service
 
@@ -2437,7 +2481,7 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadRequestCounterDirectoryPacket)
     // Check that the expected error has occurred and logged to the standard output
     BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=123 and Version=4194304 does not exist"));
 
-    // The Connection Acknowledged Command Handler should not have updated the profiling state
+    // The Request Counter Directory Command Handler should not have updated the profiling state
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
 
     // Reset the profiling service to stop any running thread
@@ -2462,7 +2506,13 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
     profilingService.Update(); // Create the profiling connection
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
-    profilingService.Update(); // Start the threads
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
     helper.ForceTransitionToState(ProfilingState::Active);
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
 
@@ -2470,6 +2520,9 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
     BOOST_CHECK(mockProfilingConnection);
 
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
     // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
     // reply from an external profiling service
 
@@ -2489,17 +2542,470 @@ BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodRequestCounterDirectoryPacket)
     // Write the packet to the mock profiling connection
     mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
 
+    // Wait for the Counter Directory packet to be sent
+    helper.WaitForProfilingPacketsSent();
+
+    // Check that the mock profiling connection contains one Counter Directory packet
+    const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+    BOOST_TEST(writtenData.size() == 1);
+    BOOST_TEST(writtenData[0] == 416); // The size of the expected Counter Directory packet
+
+    // The Request Counter Directory Command Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacket)
+{
+    // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+    LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Redirect the standard output to a local stream so that we can parse the warning message
+    std::stringstream ss;
+    StreamRedirector streamRedirector(std::cout, ss.rdbuf());
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    profilingService.Update(); // Initialize the counter directory
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+    // external profiling service
+
+    // Periodic Counter Selection packet header:
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 999; // Wrong packet id!!!
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    // Create the Periodic Counter Selection packet
+    Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
-    // the Create the Request Counter packet gets processed by the profiling service
+    // the Periodic Counter Selection packet gets processed by the profiling service
     std::this_thread::sleep_for(std::chrono::seconds(2));
 
-    // The Connection Acknowledged Command Handler should not have updated the profiling state
+    // Check that the expected error has occurred and logged to the standard output
+    BOOST_CHECK(boost::contains(ss.str(), "Functor with requested PacketId=999 and Version=4194304 does not exist"));
+
+    // The Periodic Counter Selection Handler should not have updated the profiling state
     BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
 
-    // Check that the mock profiling connection contains one Counter Directory packet
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid)
+{
+    // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
+    LogLevelSwapper logLevelSwapper(armnn::LogSeverity::Warning);
+
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    profilingService.Update(); // Initialize the counter directory
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+    // external profiling service
+
+    // Periodic Counter Selection packet header:
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 4;
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    uint32_t capturePeriod = 123456; // Some capture period (microseconds)
+
+    // Get the first valid counter UID
+    const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+    const Counters& counters = counterDirectory.GetCounters();
+    BOOST_CHECK(counters.size() > 1);
+    uint16_t counterUidA = counters.begin()->first; // First valid counter UID
+    uint16_t counterUidB = 9999;                    // Second invalid counter UID
+
+    uint32_t length = 8;
+
+    auto data = std::make_unique<unsigned char[]>(length);
+    WriteUint32(data.get(), 0, capturePeriod);
+    WriteUint16(data.get(), 4, counterUidA);
+    WriteUint16(data.get(), 6, counterUidB);
+
+    // Create the Periodic Counter Selection packet
+    Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter
+                                                                 // Capture thread
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+    // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
+    int expectedPackets = 2;
+    std::vector<uint32_t> receivedPackets;
+
+    // Keep waiting until all the expected packets have been received
+    do
+    {
+        helper.WaitForProfilingPacketsSent();
+        const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+        if (writtenData.empty())
+        {
+            BOOST_ERROR("Packets should be available for reading at this point");
+            return;
+        }
+        receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
+        expectedPackets -= boost::numeric_cast<int>(writtenData.size());
+    }
+    while (expectedPackets > 0);
+    BOOST_TEST(!receivedPackets.empty());
+
+    // The size of the expected Periodic Counter Selection packet
+    BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end()));
+    // The size of the expected Periodic Counter Capture packet
+    BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end()));
+
+    // The Periodic Counter Selection Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters)
+{
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    profilingService.Update(); // Initialize the counter directory
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+    // external profiling service
+
+    // Periodic Counter Selection packet header:
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 4;
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    // Create the Periodic Counter Selection packet
+    Packet periodicCounterSelectionPacket(header); // Length == 0, this will disable the collection of counters
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+    // Wait for the Periodic Counter Selection packet to be sent
+    helper.WaitForProfilingPacketsSent();
+
+    // The Periodic Counter Selection Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Check that the mock profiling connection contains one Periodic Counter Selection
     const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
-    BOOST_TEST(writtenData.size() == 1);
-    BOOST_TEST(writtenData[0] == 416); // The size of a valid Counter Directory packet
+    BOOST_TEST(writtenData.size() == 1); // Only one packet is expected (no Periodic Counter packets)
+    BOOST_TEST(writtenData[0] == 12); // The size of the expected Periodic Counter Selection (echos the sent one)
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter)
+{
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    profilingService.Update(); // Initialize the counter directory
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+    // external profiling service
+
+    // Periodic Counter Selection packet header:
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 4;
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    uint32_t capturePeriod = 123456; // Some capture period (microseconds)
+
+    // Get the first valid counter UID
+    const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+    const Counters& counters = counterDirectory.GetCounters();
+    BOOST_CHECK(!counters.empty());
+    uint16_t counterUid = counters.begin()->first; // Valid counter UID
+
+    uint32_t length = 6;
+
+    auto data = std::make_unique<unsigned char[]>(length);
+    WriteUint32(data.get(), 0, capturePeriod);
+    WriteUint16(data.get(), 4, counterUid);
+
+    // Create the Periodic Counter Selection packet
+    Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter
+                                                                 // Capture thread
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+    // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
+    int expectedPackets = 2;
+    std::vector<uint32_t> receivedPackets;
+
+    // Keep waiting until all the expected packets have been received
+    do
+    {
+        helper.WaitForProfilingPacketsSent();
+        const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+        if (writtenData.empty())
+        {
+            BOOST_ERROR("Packets should be available for reading at this point");
+            return;
+        }
+        receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
+        expectedPackets -= boost::numeric_cast<int>(writtenData.size());
+    }
+    while (expectedPackets > 0);
+    BOOST_TEST(!receivedPackets.empty());
+
+    // The size of the expected Periodic Counter Selection packet (echos the sent one)
+    BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 14) != receivedPackets.end()));
+    // The size of the expected Periodic Counter Capture packet
+    BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 22) != receivedPackets.end()));
+
+    // The Periodic Counter Selection Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Reset the profiling service to stop any running thread
+    options.m_EnableProfiling = false;
+    profilingService.ResetExternalProfilingOptions(options, true);
+}
+
+BOOST_AUTO_TEST_CASE(CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters)
+{
+    // Swap the profiling connection factory in the profiling service instance with our mock one
+    SwapProfilingConnectionFactoryHelper helper;
+
+    // Reset the profiling service to the uninitialized state
+    armnn::Runtime::CreationOptions::ExternalProfilingOptions options;
+    options.m_EnableProfiling = true;
+    ProfilingService& profilingService = ProfilingService::Instance();
+    profilingService.ResetExternalProfilingOptions(options, true);
+
+    // Bring the profiling service to the "Active" state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
+    profilingService.Update(); // Initialize the counter directory
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
+    profilingService.Update(); // Create the profiling connection
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
+    profilingService.Update(); // Start the command handler and the send thread
+
+    // Wait for the Stream Metadata packet the be sent
+    // (we are not testing the connection acknowledgement here so it will be ignored by this test)
+    helper.WaitForProfilingPacketsSent();
+
+    // Force the profiling service to the "Active" state
+    helper.ForceTransitionToState(ProfilingState::Active);
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
+
+    // Get the mock profiling connection
+    MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
+    BOOST_CHECK(mockProfilingConnection);
+
+    // Remove the packets received so far
+    mockProfilingConnection->Clear();
+
+    // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
+    // external profiling service
+
+    // Periodic Counter Selection packet header:
+    // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
+    // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
+    // 8:15  [8]  reserved: Reserved, value 0b00000000
+    // 0:7   [8]  reserved: Reserved, value 0b00000000
+    uint32_t packetFamily = 0;
+    uint32_t packetId     = 4;
+    uint32_t header = ((packetFamily & 0x0000003F) << 26) |
+                      ((packetId     & 0x000003FF) << 16);
+
+    uint32_t capturePeriod = 123456; // Some capture period (microseconds)
+
+    // Get the first valid counter UID
+    const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
+    const Counters& counters = counterDirectory.GetCounters();
+    BOOST_CHECK(counters.size() > 1);
+    uint16_t counterUidA = counters.begin()->first;     // First valid counter UID
+    uint16_t counterUidB = (counters.begin()++)->first; // Second valid counter UID
+
+    uint32_t length = 8;
+
+    auto data = std::make_unique<unsigned char[]>(length);
+    WriteUint32(data.get(), 0, capturePeriod);
+    WriteUint16(data.get(), 4, counterUidA);
+    WriteUint16(data.get(), 6, counterUidB);
+
+    // Create the Periodic Counter Selection packet
+    Packet periodicCounterSelectionPacket(header, length, data); // Length > 0, this will start the Period Counter
+                                                                 // Capture thread
+
+    // Write the packet to the mock profiling connection
+    mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
+
+    // Expecting one Periodic Counter Selection packet and at least one Periodic Counter Capture packet
+    int expectedPackets = 2;
+    std::vector<uint32_t> receivedPackets;
+
+    // Keep waiting until all the expected packets have been received
+    do
+    {
+        helper.WaitForProfilingPacketsSent();
+        const std::vector<uint32_t> writtenData = mockProfilingConnection->GetWrittenData();
+        if (writtenData.empty())
+        {
+            BOOST_ERROR("Packets should be available for reading at this point");
+            return;
+        }
+        receivedPackets.insert(receivedPackets.end(), writtenData.begin(), writtenData.end());
+        expectedPackets -= boost::numeric_cast<int>(writtenData.size());
+    }
+    while (expectedPackets > 0);
+    BOOST_TEST(!receivedPackets.empty());
+
+    // The size of the expected Periodic Counter Selection packet (echos the sent one)
+    BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 16) != receivedPackets.end()));
+    // The size of the expected Periodic Counter Capture packet
+    BOOST_TEST((std::find(receivedPackets.begin(), receivedPackets.end(), 28) != receivedPackets.end()));
+
+    // The Periodic Counter Selection Handler should not have updated the profiling state
+    BOOST_CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
 
     // Reset the profiling service to stop any running thread
     options.m_EnableProfiling = false;
index 4d2f974344cbe4a25e0857e5560f8dd681038b8f..21c98723be93ca8896d57ce9723e4faa8b1ecf00 100644 (file)
@@ -9,14 +9,12 @@
 
 #include <CommandHandlerFunctor.hpp>
 #include <IProfilingConnection.hpp>
-#include <IProfilingConnectionFactory.hpp>
 #include <Logging.hpp>
 #include <ProfilingService.hpp>
 
 #include <boost/test/unit_test.hpp>
 
 #include <chrono>
-#include <iostream>
 #include <thread>
 
 namespace armnn
@@ -137,15 +135,6 @@ class TestFunctorC : public TestFunctorA
     using TestFunctorA::TestFunctorA;
 };
 
-class MockProfilingConnectionFactory : public IProfilingConnectionFactory
-{
-public:
-    IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
-    {
-        return std::make_unique<MockProfilingConnection>();
-    }
-};
-
 class SwapProfilingConnectionFactoryHelper : public ProfilingService
 {
 public:
@@ -182,6 +171,11 @@ public:
         TransitionToState(ProfilingService::Instance(), newState);
     }
 
+    void WaitForProfilingPacketsSent()
+    {
+        return WaitForPacketSent(ProfilingService::Instance());
+    }
+
 private:
     MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
     IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
index 871ca741247d8306fffc2c6dc961e43e79162319..73fc39b437a5549124b7ba78d3912e9b556762a2 100644 (file)
@@ -7,6 +7,7 @@
 
 #include <SendCounterPacket.hpp>
 #include <ProfilingUtils.hpp>
+#include <IProfilingConnectionFactory.hpp>
 
 #include <armnn/Exceptions.hpp>
 #include <armnn/Optional.hpp>
@@ -74,11 +75,13 @@ public:
         return std::move(m_Packet);
     }
 
-    const std::vector<uint32_t> GetWrittenData() const
+    const std::vector<uint32_t> GetWrittenData()
     {
         std::lock_guard<std::mutex> lock(m_Mutex);
 
-        return m_WrittenData;
+        std::vector<uint32_t> writtenData = m_WrittenData;
+        m_WrittenData.clear();
+        return writtenData;
     }
 
     void Clear()
@@ -95,6 +98,15 @@ private:
     mutable std::mutex m_Mutex;
 };
 
+class MockProfilingConnectionFactory : public IProfilingConnectionFactory
+{
+public:
+    IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override
+    {
+        return std::make_unique<MockProfilingConnection>();
+    }
+};
+
 class MockPacketBuffer : public IPacketBuffer
 {
 public: