IVGCVSW-4666 Call EnableProfiling when state switches to active
authorFinn Williams <Finn.Williams@arm.com>
Thu, 9 Apr 2020 15:05:28 +0000 (16:05 +0100)
committerFinn Williams <Finn.Williams@arm.com>
Fri, 10 Apr 2020 13:48:56 +0000 (14:48 +0100)
 * Move the call to EnableProfiling() into ConnectionAcknowledgedHandler
 * Fix an issue with MockGatord forcing some command handlers to be quiet
 * Add some small unrelated improvements and typo fixes to the
   periodic counter command handlers

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I9e6066b78d1f782cfaf27c11571c0ec5cb5d126f

12 files changed:
src/armnn/Runtime.cpp
src/backends/backendsCommon/test/MockBackend.hpp
src/profiling/ConnectionAcknowledgedCommandHandler.cpp
src/profiling/ConnectionAcknowledgedCommandHandler.hpp
src/profiling/CounterDirectory.cpp
src/profiling/PeriodicCounterCapture.cpp
src/profiling/PeriodicCounterCapture.hpp
src/profiling/PeriodicCounterSelectionCommandHandler.cpp
src/profiling/PeriodicCounterSelectionCommandHandler.hpp
src/profiling/ProfilingService.hpp
tests/profiling/gatordmock/GatordMockService.hpp
tests/profiling/gatordmock/tests/GatordMockTests.cpp

index 32c7c39..483eea7 100644 (file)
@@ -178,9 +178,6 @@ Runtime::Runtime(const CreationOptions& options)
         throw RuntimeException("It is not possible to enable timeline reporting without profiling being enabled");
     }
 
-    // pass configuration info to the profiling service
-    m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions);
-
     // Load any available/compatible dynamic backend before the runtime
     // goes through the backend registry
     LoadDynamicBackends(options.m_DynamicBackendsPath);
@@ -213,24 +210,19 @@ Runtime::Runtime(const CreationOptions& options)
             // Backends that don't support profiling will return a null profiling context.
             if (profilingContext)
             {
-                // Enable profiling on the backend and assert that it returns true
-                if(profilingContext->EnableProfiling(true))
-                {
-                    // Pass the context onto the profiling service.
-                    m_ProfilingService.AddBackendProfilingContext(id, profilingContext);
-                }
-                else
-                {
-                    throw BackendProfilingException("Unable to enable profiling on Backend Id: " + id.Get());
-                }
+                // Pass the context onto the profiling service.
+                m_ProfilingService.AddBackendProfilingContext(id, profilingContext);
             }
         }
         catch (const BackendUnavailableException&)
         {
             // Ignore backends which are unavailable
         }
-
     }
+
+    // pass configuration info to the profiling service
+    m_ProfilingService.ConfigureProfilingService(options.m_ProfilingOptions);
+
     m_DeviceSpec.AddSupportedBackends(supportedBackends);
 }
 
@@ -273,7 +265,6 @@ Runtime::~Runtime()
         }
     }
 
-
     // Clear all dynamic backends.
     DynamicBackendUtils::DeregisterDynamicBackends(m_DeviceSpec.GetDynamicBackends());
     m_DeviceSpec.ClearDynamicBackends();
index e1570ff..d90ad79 100644 (file)
@@ -45,18 +45,19 @@ public:
     uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
     {
         std::unique_ptr<profiling::IRegisterBackendCounters> counterRegistrar =
-            m_BackendProfiling->GetCounterRegistrationInterface(currentMaxGlobalCounterId);
+            m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));
 
         std::string categoryName("MockCounters");
         counterRegistrar->RegisterCategory(categoryName);
-        uint16_t nextMaxGlobalCounterId =
-            counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
 
-        nextMaxGlobalCounterId = counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
+        counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
+
+        counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
                                                                    "Another notional counter");
 
         std::string units("microseconds");
-        nextMaxGlobalCounterId = counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
+        uint16_t nextMaxGlobalCounterId =
+                counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
                                                                    "A dummy four core counter", units, 4);
         return nextMaxGlobalCounterId;
     }
@@ -91,6 +92,9 @@ public:
 
     bool EnableProfiling(bool)
     {
+        auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
+        sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
+        sendTimelinePacket->Commit();
         return true;
     }
 
index 0071bfc..995562f 100644 (file)
@@ -41,9 +41,21 @@ void ConnectionAcknowledgedCommandHandler::operator()(const Packet& packet)
         // Send the counter directory packet.
         m_SendCounterPacket.SendCounterDirectoryPacket(m_CounterDirectory);
         m_SendTimelinePacket.SendTimelineMessageDirectoryPackage();
-
         TimelineUtilityMethods::SendWellKnownLabelsAndEventClasses(m_SendTimelinePacket);
 
+        if(m_BackendProfilingContext.has_value())
+        {
+            for (auto backendContext : m_BackendProfilingContext.value())
+            {
+                // Enable profiling on the backend and assert that it returns true
+                if(!backendContext.second->EnableProfiling(true))
+                {
+                    throw BackendProfilingException(
+                            "Unable to enable profiling on Backend Id: " + backendContext.first.Get());
+                }
+            }
+        }
+
         break;
     case ProfilingState::Active:
         return; // NOP
index 6054306..e2bdff8 100644 (file)
@@ -5,11 +5,13 @@
 
 #pragma once
 
+#include <armnn/backends/profiling/IBackendProfilingContext.hpp>
 #include "CommandHandlerFunctor.hpp"
 #include "ISendCounterPacket.hpp"
 #include "armnn/profiling/ISendTimelinePacket.hpp"
 #include "Packet.hpp"
 #include "ProfilingStateMachine.hpp"
+#include <future>
 
 namespace armnn
 {
@@ -20,6 +22,9 @@ namespace profiling
 class ConnectionAcknowledgedCommandHandler final : public CommandHandlerFunctor
 {
 
+typedef const std::unordered_map<BackendId, std::shared_ptr<armnn::profiling::IBackendProfilingContext>>&
+    BackendProfilingContexts;
+
 public:
     ConnectionAcknowledgedCommandHandler(uint32_t familyId,
                                          uint32_t packetId,
@@ -27,12 +32,14 @@ public:
                                          ICounterDirectory& counterDirectory,
                                          ISendCounterPacket& sendCounterPacket,
                                          ISendTimelinePacket& sendTimelinePacket,
-                                         ProfilingStateMachine& profilingStateMachine)
+                                         ProfilingStateMachine& profilingStateMachine,
+                                         Optional<BackendProfilingContexts> backendProfilingContexts = EmptyOptional())
         : CommandHandlerFunctor(familyId, packetId, version)
         , m_CounterDirectory(counterDirectory)
         , m_SendCounterPacket(sendCounterPacket)
         , m_SendTimelinePacket(sendTimelinePacket)
         , m_StateMachine(profilingStateMachine)
+        , m_BackendProfilingContext(backendProfilingContexts)
     {}
 
     void operator()(const Packet& packet) override;
@@ -42,7 +49,7 @@ private:
     ISendCounterPacket&      m_SendCounterPacket;
     ISendTimelinePacket&     m_SendTimelinePacket;
     ProfilingStateMachine&   m_StateMachine;
-
+    Optional<BackendProfilingContexts> m_BackendProfilingContext;
 };
 
 } // namespace profiling
index 415a660..ae1c497 100644 (file)
@@ -498,7 +498,7 @@ CountersIt CounterDirectory::FindCounter(const std::string& counterName) const
     return std::find_if(m_Counters.begin(), m_Counters.end(), [&counterName](const auto& pair)
     {
         ARMNN_ASSERT(pair.second);
-        ARMNN_ASSERT(pair.second->m_Uid == pair.first);
+        ARMNN_ASSERT(pair.first >= pair.second->m_Uid && pair.first <= pair.second->m_MaxCounterUid);
 
         return pair.second->m_Name == counterName;
     });
index b143295..4ad1d11 100644 (file)
@@ -125,7 +125,7 @@ void PeriodicCounterCapture::Capture(const IReadCounterValues& readCounterValues
         for_each(activeBackends.begin(), activeBackends.end(), [&](const armnn::BackendId& backendId)
         {
             DispatchPeriodicCounterCapturePacket(
-                backendId, m_BackendProfilingContext.at(backendId)->ReportCounterValues());
+                backendId, m_BackendProfilingContexts.at(backendId)->ReportCounterValues());
         });
 
         // Wait the indicated capture period (microseconds)
index ff05623..51ac273 100644 (file)
@@ -39,7 +39,7 @@ public:
             , m_ReadCounterValues(readCounterValue)
             , m_SendCounterPacket(packet)
             , m_CounterIdMap(counterIdMap)
-            , m_BackendProfilingContext(backendProfilingContexts)
+            , m_BackendProfilingContexts(backendProfilingContexts)
     {}
     ~PeriodicCounterCapture() { Stop(); }
 
@@ -61,7 +61,7 @@ private:
     ISendCounterPacket&       m_SendCounterPacket;
     const ICounterMappings&   m_CounterIdMap;
     const std::unordered_map<armnn::BackendId,
-            std::shared_ptr<armnn::profiling::IBackendProfilingContext>>& m_BackendProfilingContext;
+            std::shared_ptr<armnn::profiling::IBackendProfilingContext>>& m_BackendProfilingContexts;
 };
 
 } // namespace profiling
index d218433..4e3e6e5 100644 (file)
@@ -140,7 +140,6 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
         // save the new backend counter ids for next time
         m_PrevBackendCounterIds = backendCounterIds;
 
-
         // Set the capture data with only the valid armnn counter UIDs
         m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
 
@@ -168,8 +167,8 @@ void PeriodicCounterSelectionCommandHandler::operator()(const Packet& packet)
 
 std::set<armnn::BackendId> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
                                                                       const u_int32_t capturePeriod,
-                                                                      std::set<uint16_t> newCounterIds,
-                                                                      std::set<uint16_t> unusedCounterIds)
+                                                                      const std::set<uint16_t> newCounterIds,
+                                                                      const std::set<uint16_t> unusedCounterIds)
 {
     std::set<armnn::BackendId> changedBackends;
     std::set<armnn::BackendId> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
index 437d712..b59d84c 100644 (file)
@@ -37,7 +37,7 @@ public:
                                            uint32_t version,
                                            const std::unordered_map<BackendId,
                                                    std::shared_ptr<armnn::profiling::IBackendProfilingContext>>&
-                                           backendProfilingContext,
+                                                   backendProfilingContexts,
                                            const ICounterMappings& counterIdMap,
                                            Holder& captureDataHolder,
                                            const uint16_t maxArmnnCounterId,
@@ -46,7 +46,7 @@ public:
                                            ISendCounterPacket& sendCounterPacket,
                                            const ProfilingStateMachine& profilingStateMachine)
         : CommandHandlerFunctor(familyId, packetId, version)
-        , m_BackendProfilingContext(backendProfilingContext)
+        , m_BackendProfilingContexts(backendProfilingContexts)
         , m_CounterIdMap(counterIdMap)
         , m_CaptureDataHolder(captureDataHolder)
         , m_MaxArmCounterId(maxArmnnCounterId)
@@ -66,7 +66,7 @@ private:
 
     std::unordered_map<armnn::BackendId, std::vector<uint16_t>> m_BackendCounterMap;
     const std::unordered_map<BackendId,
-          std::shared_ptr<armnn::profiling::IBackendProfilingContext>>& m_BackendProfilingContext;
+          std::shared_ptr<armnn::profiling::IBackendProfilingContext>>& m_BackendProfilingContexts;
     const ICounterMappings& m_CounterIdMap;
     Holder& m_CaptureDataHolder;
     const uint16_t m_MaxArmCounterId;
@@ -82,7 +82,7 @@ private:
                                 const std::vector<uint16_t> counterIds)
     {
         Optional<std::string> errorMsg =
-                m_BackendProfilingContext.at(backendId)->ActivateCounters(capturePeriod, counterIds);
+                m_BackendProfilingContexts.at(backendId)->ActivateCounters(capturePeriod, counterIds);
 
         if(errorMsg.has_value())
         {
@@ -92,8 +92,8 @@ private:
     }
     void ParseData(const Packet& packet, CaptureData& captureData);
     std::set<armnn::BackendId> ProcessBackendCounterIds(const u_int32_t capturePeriod,
-                                                        std::set<uint16_t> newCounterIds,
-                                                        std::set<uint16_t> unusedCounterIds);
+                                                        const std::set<uint16_t> newCounterIds,
+                                                        const std::set<uint16_t> unusedCounterIds);
 
 };
 
index a6c5e29..f3d10e7 100644 (file)
@@ -80,7 +80,8 @@ public:
                                                  m_CounterDirectory,
                                                  m_SendCounterPacket,
                                                  m_SendTimelinePacket,
-                                                 m_StateMachine)
+                                                 m_StateMachine,
+                                                 m_BackendProfilingContexts)
         , m_RequestCounterDirectoryCommandHandler(0,
                                                   3,
                                                   m_PacketVersionResolver.ResolvePacketVersion(0, 3).GetEncodedValue(),
index 2ff93c9..9b72f72 100644 (file)
@@ -57,16 +57,16 @@ public:
             , m_HandlerRegistry()
             , m_TimelineDecoder()
             , m_StreamMetadataCommandHandler(
-                    0, 0, m_PacketVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), true)
+                    0, 0, m_PacketVersionResolver.ResolvePacketVersion(0, 0).GetEncodedValue(), !echoPackets)
             , m_CounterCaptureCommandHandler(
-                    0, 4, m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), true)
+                    0, 4, m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), !echoPackets)
             , m_DirectoryCaptureCommandHandler(
-                    0, 2, m_PacketVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), true)
+                    0, 2, m_PacketVersionResolver.ResolvePacketVersion(0, 2).GetEncodedValue(), !echoPackets)
             , m_TimelineCaptureCommandHandler(
                     1, 1, m_PacketVersionResolver.ResolvePacketVersion(1, 1).GetEncodedValue(), m_TimelineDecoder)
             , m_TimelineDirectoryCaptureCommandHandler(
                     1, 0, m_PacketVersionResolver.ResolvePacketVersion(1, 0).GetEncodedValue(),
-                    m_TimelineCaptureCommandHandler, true)
+                    m_TimelineCaptureCommandHandler, !echoPackets)
     {
         m_TimelineDecoder.SetDefaultCallbacks();
 
index f8b42df..11a96fd 100644 (file)
@@ -443,8 +443,11 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation)
     WaitFor([&](){return timelineDecoder.GetModel().m_EventClasses.size() >= 2;},
             "MockGatord did not receive well known timeline labels");
 
+    WaitFor([&](){return timelineDecoder.GetModel().m_Entities.size() >= 1;},
+            "MockGatord did not receive mock backend test entity");
+
     // Packets we expect from SendWellKnownLabelsAndEventClassesTest
-    BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 0);
+    BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 1);
     BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size()  == 2);
     BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size()  == 10);
     BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size()  == 0);
@@ -471,7 +474,7 @@ BOOST_AUTO_TEST_CASE(GatorDMockTimeLineActivation)
             "MockGatord did not receive well known timeline labels");
 
     // Packets we expect from SendWellKnownLabelsAndEventClassesTest * 2 and the loaded model
-    BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 5);
+    BOOST_CHECK(timelineDecoder.GetModel().m_Entities.size() == 6);
     BOOST_CHECK(timelineDecoder.GetModel().m_EventClasses.size()  == 4);
     BOOST_CHECK(timelineDecoder.GetModel().m_Labels.size()  == 24);
     BOOST_CHECK(timelineDecoder.GetModel().m_Relationships.size()  == 28);