IVGCVSW-3937 Refactor the command thread
authorMatteo Martincigh <matteo.martincigh@arm.com>
Fri, 4 Oct 2019 13:40:04 +0000 (14:40 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Mon, 7 Oct 2019 10:08:27 +0000 (10:08 +0000)
 * Integrated the Join method into Stop
 * Updated the unit tests accordingly
 * General code refactoring

Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Change-Id: If8537e77b3d3ff2b780f58a07df01191a91d83d2

src/profiling/CommandThread.cpp
src/profiling/CommandThread.hpp
src/profiling/test/ProfilingTests.cpp

index bd4aa96c7ca0dbb53fbef54e07eb0f23403e6775..320e4bcf5c21debc1a31833d2129415097d1311f 100644 (file)
@@ -12,86 +12,76 @@ namespace armnn
 namespace profiling
 {
 
-CommandThread::CommandThread(uint32_t timeout,
-                             bool stopAfterTimeout,
-                             CommandHandlerRegistry& commandHandlerRegistry,
-                             PacketVersionResolver& packetVersionResolver,
-                             IProfilingConnection& socketProfilingConnection)
-    : m_Timeout(timeout)
-    , m_StopAfterTimeout(stopAfterTimeout)
-    , m_IsRunning(false)
-    , m_CommandHandlerRegistry(commandHandlerRegistry)
-    , m_PacketVersionResolver(packetVersionResolver)
-    , m_SocketProfilingConnection(socketProfilingConnection)
-{};
-
-void CommandThread::WaitForPacket()
+void CommandThread::Start()
 {
-    do {
-        try
-        {
-            Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
-            Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
-
-            CommandHandlerFunctor* commandHandlerFunctor =
-                m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
-            commandHandlerFunctor->operator()(packet);
-        }
-        catch(const armnn::TimeoutException&)
-        {
-            if(m_StopAfterTimeout)
-            {
-                m_IsRunning.store(false, std::memory_order_relaxed);
-                return;
-            }
-        }
-        catch(...)
-        {
-            //might want to differentiate the errors more
-            m_IsRunning.store(false, std::memory_order_relaxed);
-            return;
-        }
-
-    } while(m_KeepRunning.load(std::memory_order_relaxed));
+    if (IsRunning())
+    {
+        return;
+    }
 
-    m_IsRunning.store(false, std::memory_order_relaxed);
+    m_IsRunning.store(true, std::memory_order_relaxed);
+    m_KeepRunning.store(true, std::memory_order_relaxed);
+    m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
 }
 
-void CommandThread::Start()
+void CommandThread::Stop()
 {
-    if (!m_CommandThread.joinable() && !IsRunning())
+    m_KeepRunning.store(false, std::memory_order_relaxed);
+
+    if (m_CommandThread.joinable())
     {
-        m_IsRunning.store(true, std::memory_order_relaxed);
-        m_KeepRunning.store(true, std::memory_order_relaxed);
-        m_CommandThread = std::thread(&CommandThread::WaitForPacket, this);
+        m_CommandThread.join();
     }
 }
 
-void CommandThread::Stop()
+bool CommandThread::IsRunning() const
 {
-    m_KeepRunning.store(false, std::memory_order_relaxed);
+    return m_IsRunning.load(std::memory_order_relaxed);
 }
 
-void CommandThread::Join()
+void CommandThread::SetTimeout(uint32_t timeout)
 {
-    m_CommandThread.join();
+    m_Timeout.store(timeout, std::memory_order_relaxed);
 }
 
-bool CommandThread::IsRunning() const
+void CommandThread::SetStopAfterTimeout(bool stopAfterTimeout)
 {
-    return m_IsRunning.load(std::memory_order_relaxed);
+    m_StopAfterTimeout.store(stopAfterTimeout, std::memory_order_relaxed);
 }
 
-bool CommandThread::StopAfterTimeout(bool stopAfterTimeout)
+void CommandThread::WaitForPacket()
 {
-    if (!IsRunning())
+    do
     {
-        m_StopAfterTimeout = stopAfterTimeout;
-        return true;
+        try
+        {
+            Packet packet = m_SocketProfilingConnection.ReadPacket(m_Timeout);
+            Version version = m_PacketVersionResolver.ResolvePacketVersion(packet.GetPacketId());
+
+            CommandHandlerFunctor* commandHandlerFunctor =
+                m_CommandHandlerRegistry.GetFunctor(packet.GetPacketId(), version.GetEncodedValue());
+            BOOST_ASSERT(commandHandlerFunctor);
+            commandHandlerFunctor->operator()(packet);
+        }
+        catch (const armnn::TimeoutException&)
+        {
+            if (m_StopAfterTimeout)
+            {
+                m_KeepRunning.store(false, std::memory_order_relaxed);
+            }
+        }
+        catch (...)
+        {
+            // Might want to differentiate the errors more
+            m_KeepRunning.store(false, std::memory_order_relaxed);
+        }
+
     }
-    return false;
+    while (m_KeepRunning.load(std::memory_order_relaxed));
+
+    m_IsRunning.store(false, std::memory_order_relaxed);
 }
 
-}//namespace profiling
+} // namespace profiling
 
-}//namespace armnn
+} // namespace armnn
index 6237cd29142ed32a8fcbb65bf56a82d212b9979f..0456ba43720325c7ec23f29b73e536cfa03c5217 100644 (file)
@@ -26,19 +26,31 @@ public:
                   bool stopAfterTimeout,
                   CommandHandlerRegistry& commandHandlerRegistry,
                   PacketVersionResolver& packetVersionResolver,
-                  IProfilingConnection& socketProfilingConnection);
+                  IProfilingConnection& socketProfilingConnection)
+        : m_Timeout(timeout)
+        , m_StopAfterTimeout(stopAfterTimeout)
+        , m_IsRunning(false)
+        , m_KeepRunning(false)
+        , m_CommandThread()
+        , m_CommandHandlerRegistry(commandHandlerRegistry)
+        , m_PacketVersionResolver(packetVersionResolver)
+        , m_SocketProfilingConnection(socketProfilingConnection)
+    {}
+    ~CommandThread() { Stop(); }
 
     void Start();
     void Stop();
-    void Join();
+
     bool IsRunning() const;
-    bool StopAfterTimeout(bool StopAfterTimeout);
+
+    void SetTimeout(uint32_t timeout);
+    void SetStopAfterTimeout(bool stopAfterTimeout);
 
 private:
     void WaitForPacket();
 
-    uint32_t m_Timeout;
-    bool m_StopAfterTimeout;
+    std::atomic<uint32_t> m_Timeout;
+    std::atomic<bool> m_StopAfterTimeout;
     std::atomic<bool> m_IsRunning;
     std::atomic<bool> m_KeepRunning;
     std::thread m_CommandThread;
@@ -48,6 +60,6 @@ private:
     IProfilingConnection& m_SocketProfilingConnection;
 };
 
-}//namespace profiling
+} // namespace profiling
 
-}//namespace armnn
\ No newline at end of file
+} // namespace armnn
index d14791c43d9fbb5c8f3a27275cdea550ec9150c7..9dd7cd3d646dc098e817c5b696cd1dc90ff03982 100644 (file)
@@ -174,7 +174,6 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
         commandThread0.Start();
 
         commandThread0.Stop();
-        commandThread0.Join();
 
         BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
 
@@ -188,11 +187,15 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
                                      testProfilingConnectionTimeOutError);
 
         commandThread1.Start();
-        commandThread1.Join();
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        BOOST_CHECK(!commandThread1.IsRunning());
+        commandThread1.Stop();
 
         BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
         //now commandThread1 should persist after a timeout
-        commandThread1.StopAfterTimeout(false);
+        commandThread1.SetStopAfterTimeout(false);
         commandThread1.Start();
 
         for (int i = 0; i < 100; i++)
@@ -208,11 +211,9 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
         }
 
         commandThread1.Stop();
-        commandThread1.Join();
 
         BOOST_CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
 
-
         CommandThread commandThread2(1,
                                      false,
                                      commandHandlerRegistry,
@@ -226,13 +227,13 @@ BOOST_AUTO_TEST_CASE(CheckCommandThread)
             if (!commandThread2.IsRunning())
             {
                 //commandThread2 should stop once it encounters a non timing error
-                commandThread2.Join();
                 return;
             }
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
 
         BOOST_ERROR("commandThread2 has failed to stop");
+        commandThread2.Stop();
 }
 
 BOOST_AUTO_TEST_CASE(CheckEncodeVersion)